Compare commits
15 Commits
a051f83050
...
0c0d0c1443
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c0d0c1443 | |||
| 21b4e60c44 | |||
| a4399037ac | |||
| 075b92f231 | |||
| 02a7fa68c6 | |||
| 98bc9739c6 | |||
| b13bb8c556 | |||
| 8936e70a12 | |||
| 1581732ef9 | |||
| 68e3663c2f | |||
| f3187ceddd | |||
| 16b052bd21 | |||
| a9e7aabed4 | |||
| d5b6cd24fc | |||
| 2dada36bc6 |
@ -4,7 +4,7 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls"] }
|
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] }
|
||||||
dotenv = "0.15"
|
dotenv = "0.15"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
@ -23,3 +23,6 @@ tracing = "0.1"
|
|||||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||||
tracing-appender = "0.2"
|
tracing-appender = "0.2"
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
|
mime_guess = "2.0"
|
||||||
|
base64 = "0.22"
|
||||||
|
tempfile = "3"
|
||||||
|
|||||||
346
IMPLEMENTATION_LOG.md
Normal file
346
IMPLEMENTATION_LOG.md
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
# Picobot 工具机制增强实现日志
|
||||||
|
|
||||||
|
## 实现记录
|
||||||
|
|
||||||
|
### 1. SchemaCleanr - 跨 Provider Schema 归一化
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `d5b6cd2`
|
||||||
|
|
||||||
|
#### 背景
|
||||||
|
不同 LLM provider 对 JSON Schema 支持差异很大:
|
||||||
|
- **Gemini**: 最严格,不支持 `minLength`, `maxLength`, `pattern`, `minimum`, `maximum` 等
|
||||||
|
- **Anthropic**: 中等,只要求解决 `$ref`
|
||||||
|
- **OpenAI**: 最宽松,支持大部分关键词
|
||||||
|
|
||||||
|
#### 实现方案
|
||||||
|
|
||||||
|
创建 `src/tools/schema.rs`,提供:
|
||||||
|
|
||||||
|
1. **`CleaningStrategy` enum**
|
||||||
|
```rust
|
||||||
|
pub enum CleaningStrategy {
|
||||||
|
Gemini, // 最严格
|
||||||
|
Anthropic, // 中等
|
||||||
|
OpenAI, // 最宽松
|
||||||
|
Conservative,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **`SchemaCleanr::clean()`** - 核心清洗函数
|
||||||
|
- 移除 provider 不支持的关键词
|
||||||
|
- 解析 `$ref` 到 `$defs`/`definitions`
|
||||||
|
- 将 `anyOf`/`oneOf` 合并为 `enum`
|
||||||
|
- 将 `const` 转换为 `enum`
|
||||||
|
- 移除 `type` 数组中的 `null`
|
||||||
|
|
||||||
|
3. **`SchemaCleanr::validate()`** - Schema 验证
|
||||||
|
|
||||||
|
#### 使用方法
|
||||||
|
```rust
|
||||||
|
use picobot::tools::{SchemaCleanr, CleaningStrategy};
|
||||||
|
|
||||||
|
// Gemini 兼容清洗(最严格)
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
// Anthropic 兼容清洗
|
||||||
|
let cleaned = SchemaCleanr::clean_for_anthropic(schema);
|
||||||
|
|
||||||
|
// OpenAI 兼容清洗(最宽松)
|
||||||
|
let cleaned = SchemaCleanr::clean_for_openai(schema);
|
||||||
|
|
||||||
|
// 自定义策略
|
||||||
|
let cleaned = SchemaCleanr::clean(schema, CleaningStrategy::Conservative);
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 工具 Trait 增强
|
||||||
|
|
||||||
|
在 `src/tools/traits.rs` 的 `Tool` trait 中新增:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub trait Tool: Send + Sync + 'static {
|
||||||
|
// ... 原有方法 ...
|
||||||
|
|
||||||
|
/// 是否只读(无副作用)
|
||||||
|
fn read_only(&self) -> bool { false }
|
||||||
|
|
||||||
|
/// 是否可以与其他工具并行执行
|
||||||
|
fn concurrency_safe(&self) -> bool {
|
||||||
|
self.read_only() && !self.exclusive()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 是否需要独占执行
|
||||||
|
fn exclusive(&self) -> bool { false }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
这些属性为后续的并行工具执行提供基础。
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 12 个单元测试覆盖所有清洗逻辑
|
||||||
|
- 运行 `cargo test --lib tools::schema` 验证
|
||||||
|
|
||||||
|
### 2. file_read 工具
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `a9e7aab`
|
||||||
|
|
||||||
|
#### 功能
|
||||||
|
- 读取文件内容(支持 offset/limit 分页)
|
||||||
|
- 返回带行号的内容,便于引用
|
||||||
|
- 自动处理二进制文件(base64 编码)
|
||||||
|
- 可选的目录限制(安全隔离)
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": { "type": "string", "description": "文件路径" },
|
||||||
|
"offset": { "type": "integer", "description": "起始行号(1-indexed)" },
|
||||||
|
"limit": { "type": "integer", "description": "最大行数" }
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用方法
|
||||||
|
```rust
|
||||||
|
use picobot::tools::FileReadTool;
|
||||||
|
|
||||||
|
// 基本用法
|
||||||
|
let tool = FileReadTool::new();
|
||||||
|
let result = tool.execute(json!({
|
||||||
|
"path": "/some/file.txt",
|
||||||
|
"offset": 1,
|
||||||
|
"limit": 100
|
||||||
|
})).await;
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 4 个单元测试
|
||||||
|
- `cargo test --lib tools::file_read`
|
||||||
|
|
||||||
|
### 3. file_write 工具
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `16b052b`
|
||||||
|
|
||||||
|
#### 功能
|
||||||
|
- 写入内容到文件
|
||||||
|
- 自动创建父目录
|
||||||
|
- 覆盖已存在文件
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": { "type": "string", "description": "文件路径" },
|
||||||
|
"content": { "type": "string", "description": "写入内容" }
|
||||||
|
},
|
||||||
|
"required": ["path", "content"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 5 个单元测试
|
||||||
|
- `cargo test --lib tools::file_write`
|
||||||
|
|
||||||
|
### 4. file_edit 工具
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `f3187ce`
|
||||||
|
|
||||||
|
#### 功能
|
||||||
|
- 编辑文件,替换 old_text 为 new_text
|
||||||
|
- 支持多行编辑
|
||||||
|
- 模糊匹配处理微小差异
|
||||||
|
- replace_all 选项批量替换
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": { "type": "string", "description": "文件路径" },
|
||||||
|
"old_text": { "type": "string", "description": "要替换的文本" },
|
||||||
|
"new_text": { "type": "string", "description": "替换后的文本" },
|
||||||
|
"replace_all": { "type": "boolean", "description": "替换所有匹配", "default": false }
|
||||||
|
},
|
||||||
|
"required": ["path", "old_text", "new_text"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 5 个单元测试
|
||||||
|
- `cargo test --lib tools::file_edit`
|
||||||
|
|
||||||
|
### 5. bash 工具
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `68e3663`
|
||||||
|
|
||||||
|
#### 功能
|
||||||
|
- 执行 shell 命令
|
||||||
|
- 超时控制
|
||||||
|
- 危险命令检测(rm -rf, fork bombs)
|
||||||
|
- 输出截断
|
||||||
|
- 工作目录支持
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": { "type": "string", "description": "Shell 命令" },
|
||||||
|
"timeout": { "type": "integer", "description": "超时秒数", "minimum": 1, "maximum": 600 }
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 7 个单元测试
|
||||||
|
- `cargo test --lib tools::bash`
|
||||||
|
|
||||||
|
### 6. http_request 工具
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `1581732`
|
||||||
|
|
||||||
|
#### 功能
|
||||||
|
- HTTP 客户端支持 GET/POST/PUT/DELETE/PATCH
|
||||||
|
- 域名白名单
|
||||||
|
- SSRF 保护(阻止私有IP、localhost)
|
||||||
|
- 响应大小限制和截断
|
||||||
|
- 超时控制
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": { "type": "string", "description": "请求 URL" },
|
||||||
|
"method": { "type": "string", "description": "HTTP 方法", "enum": ["GET", "POST", "PUT", "DELETE", "PATCH"] },
|
||||||
|
"headers": { "type": "object", "description": "请求头" },
|
||||||
|
"body": { "type": "string", "description": "请求体" }
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 8 个单元测试
|
||||||
|
- `cargo test --lib tools::http_request`
|
||||||
|
|
||||||
|
### 7. web_fetch 工具
|
||||||
|
|
||||||
|
**日期**: 2026-04-07
|
||||||
|
**Commit**: `8936e70`
|
||||||
|
|
||||||
|
#### 功能
|
||||||
|
- 获取 URL 并提取可读文本
|
||||||
|
- HTML 转纯文本
|
||||||
|
- 移除 scripts, styles, HTML 标签
|
||||||
|
- 解码 HTML 实体
|
||||||
|
- JSON 格式化输出
|
||||||
|
- SSRF 保护
|
||||||
|
|
||||||
|
#### Schema
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": { "type": "string", "description": "要获取的 URL" }
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 测试
|
||||||
|
- 6 个单元测试
|
||||||
|
- `cargo test --lib tools::web_fetch`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 工具清单
|
||||||
|
|
||||||
|
| 工具 | 名称 | 文件 | 功能 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| calculator | 计算器 | `src/tools/calculator.rs` | 25+ 数学和统计函数 |
|
||||||
|
| file_read | 文件读取 | `src/tools/file_read.rs` | 带分页的文件读取 |
|
||||||
|
| file_write | 文件写入 | `src/tools/file_write.rs` | 创建/覆盖文件 |
|
||||||
|
| file_edit | 文件编辑 | `src/tools/file_edit.rs` | 文本替换编辑 |
|
||||||
|
| bash | Shell 执行 | `src/tools/bash.rs` | 带安全保护的命令执行 |
|
||||||
|
| http_request | HTTP 请求 | `src/tools/http_request.rs` | API 请求 |
|
||||||
|
| web_fetch | 网页获取 | `src/tools/web_fetch.rs` | HTML 内容提取 |
|
||||||
|
|
||||||
|
## 工具机制增强
|
||||||
|
|
||||||
|
### SchemaCleanr
|
||||||
|
跨 LLM Provider 的 JSON Schema 归一化,支持:
|
||||||
|
- Gemini (最严格)
|
||||||
|
- Anthropic (中等)
|
||||||
|
- OpenAI (最宽松)
|
||||||
|
- Conservative (保守)
|
||||||
|
|
||||||
|
### 工具属性
|
||||||
|
```rust
|
||||||
|
fn read_only(&self) -> bool { false } // 是否只读
|
||||||
|
fn concurrency_safe(&self) -> bool { true } // 是否可并行
|
||||||
|
fn exclusive(&self) -> bool { false } // 是否独占
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo test --lib # 所有测试
|
||||||
|
cargo test --lib tools::schema # SchemaCleanr
|
||||||
|
```
|
||||||
|
|
||||||
|
## 工具注册
|
||||||
|
|
||||||
|
工具在 `src/gateway/session.rs` 的 `default_tools()` 函数中注册:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
fn default_tools() -> ToolRegistry {
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(CalculatorTool::new());
|
||||||
|
registry.register(FileReadTool::new());
|
||||||
|
registry.register(FileWriteTool::new());
|
||||||
|
registry.register(FileEditTool::new());
|
||||||
|
registry.register(BashTool::new());
|
||||||
|
registry.register(HttpRequestTool::new(
|
||||||
|
vec!["*".to_string()], // 允许所有域名
|
||||||
|
1_000_000, // max_response_size
|
||||||
|
30, // timeout_secs
|
||||||
|
false, // allow_private_hosts
|
||||||
|
));
|
||||||
|
registry.register(WebFetchTool::new(50_000, 30));
|
||||||
|
registry
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
SessionManager 使用这些工具创建 AgentLoop 实例,所有工具自动对 LLM 可用。
|
||||||
|
|
||||||
|
## 更新日志
|
||||||
|
|
||||||
|
| 日期 | Commit | 变更 |
|
||||||
|
|------|--------|------|
|
||||||
|
| 2026-04-07 | `d5b6cd2` | feat: add SchemaCleanr |
|
||||||
|
| 2026-04-07 | `a9e7aab` | feat: add file_read tool |
|
||||||
|
| 2026-04-07 | `16b052b` | feat: add file_write tool |
|
||||||
|
| 2026-04-07 | `f3187ce` | feat: add file_edit tool |
|
||||||
|
| 2026-04-07 | `68e3663` | feat: add bash tool |
|
||||||
|
| 2026-04-07 | `1581732` | feat: add http_request tool |
|
||||||
|
| 2026-04-07 | `8936e70` | feat: add web_fetch tool |
|
||||||
|
| 2026-04-07 | `b13bb8c` | docs: add implementation log |
|
||||||
|
| 2026-04-08 | `98bc973` | feat: register all tools in SessionManager |
|
||||||
|
cargo test --lib tools::file_read # file_read
|
||||||
|
cargo test --lib tools::file_write # file_write
|
||||||
|
cargo test --lib tools::file_edit # file_edit
|
||||||
|
cargo test --lib tools::bash # bash
|
||||||
|
cargo test --lib tools::http_request # http_request
|
||||||
|
cargo test --lib tools::web_fetch # web_fetch
|
||||||
|
```
|
||||||
@ -1,33 +1,72 @@
|
|||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
|
use std::io::Read;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Build content blocks from text and media paths
|
||||||
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||||
|
let mut blocks = Vec::new();
|
||||||
|
|
||||||
|
// Add text block if there's text
|
||||||
|
if !text.is_empty() {
|
||||||
|
blocks.push(ContentBlock::text(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add image blocks for media paths
|
||||||
|
for path in media_paths {
|
||||||
|
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
|
||||||
|
let url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||||
|
blocks.push(ContentBlock::image_url(url));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If nothing, add empty text block
|
||||||
|
if blocks.is_empty() {
|
||||||
|
blocks.push(ContentBlock::text(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode an image file to base64 data URL
|
||||||
|
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||||
|
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||||
|
|
||||||
|
let mut file = std::fs::File::open(path)?;
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
file.read_to_end(&mut buffer)?;
|
||||||
|
|
||||||
|
let mime = mime_guess::from_path(path)
|
||||||
|
.first_or_octet_stream()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let encoded = STANDARD.encode(&buffer);
|
||||||
|
Ok((mime, encoded))
|
||||||
|
}
|
||||||
|
|
||||||
/// Stateless AgentLoop - history is managed externally by SessionManager
|
/// Stateless AgentLoop - history is managed externally by SessionManager
|
||||||
pub struct AgentLoop {
|
pub struct AgentLoop {
|
||||||
provider: Box<dyn LLMProvider>,
|
provider: Box<dyn LLMProvider>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
|
max_iterations: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let provider = create_provider(provider_config)
|
Self::with_tools(provider_config, Arc::new(ToolRegistry::new()))
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
provider,
|
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||||
let provider = create_provider(provider_config)
|
let provider = create_provider(provider_config.clone())
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
provider,
|
provider,
|
||||||
tools,
|
tools,
|
||||||
|
max_iterations: provider_config.max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,18 +76,16 @@ impl AgentLoop {
|
|||||||
|
|
||||||
/// Process a message using the provided conversation history.
|
/// Process a message using the provided conversation history.
|
||||||
/// History management is handled externally by SessionManager.
|
/// History management is handled externally by SessionManager.
|
||||||
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
/// Returns (final_response, complete_message_history) where the history includes
|
||||||
let messages_for_llm: Vec<Message> = messages
|
/// all tool calls and results for proper session continuity.
|
||||||
.iter()
|
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<(ChatMessage, Vec<ChatMessage>), AgentError> {
|
||||||
.map(|m| Message {
|
let mut messages = messages;
|
||||||
role: m.role.clone(),
|
let mut final_content: String = String::new();
|
||||||
content: m.content.clone(),
|
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
|
||||||
name: m.tool_name.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
tracing::debug!(history_len = messages.len(), "Sending request to LLM");
|
for iteration in 0..self.max_iterations {
|
||||||
|
tracing::debug!(iteration, history_len = messages.len(), "Starting iteration");
|
||||||
|
|
||||||
|
let messages_for_llm = self.build_messages_for_llm(&messages);
|
||||||
|
|
||||||
let tools = if self.tools.has_tools() {
|
let tools = if self.tools.has_tools() {
|
||||||
Some(self.tools.get_definitions())
|
Some(self.tools.get_definitions())
|
||||||
@ -76,11 +113,10 @@ impl AgentLoop {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if !response.tool_calls.is_empty() {
|
if !response.tool_calls.is_empty() {
|
||||||
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
tracing::info!(count = response.tool_calls.len(), iteration, tools = ?response.tool_calls.iter().map(|tc| &tc.name).collect::<Vec<_>>(), "Tool calls detected, executing tools");
|
||||||
|
|
||||||
let mut updated_messages = messages.clone();
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content.clone());
|
let assistant_message = ChatMessage::assistant(response.content.clone());
|
||||||
updated_messages.push(assistant_message.clone());
|
messages.push(assistant_message);
|
||||||
|
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
|
|
||||||
@ -90,61 +126,104 @@ impl AgentLoop {
|
|||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
result.clone(),
|
result.clone(),
|
||||||
);
|
);
|
||||||
updated_messages.push(tool_message);
|
messages.push(tool_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.continue_with_tool_results(updated_messages).await;
|
tracing::debug!(iteration, "Tool execution completed, continuing to next iteration");
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
tracing::debug!(iteration, "No tool calls in response, agent loop ending");
|
||||||
Ok(assistant_message)
|
final_content = response.content;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
if final_content.is_empty() {
|
||||||
let messages_for_llm: Vec<Message> = messages
|
tracing::warn!(iterations = self.max_iterations, "Max iterations reached without final response");
|
||||||
|
final_content = format!("Error: Max iterations ({}) reached without final response", self.max_iterations);
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_message = ChatMessage::assistant(final_content);
|
||||||
|
// Return both the final message and the complete history for session persistence
|
||||||
|
Ok((final_message, messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_messages_for_llm(&self, messages: &[ChatMessage]) -> Vec<Message> {
|
||||||
|
messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| Message {
|
.map(|m| {
|
||||||
|
let content = if m.media_refs.is_empty() {
|
||||||
|
vec![ContentBlock::text(&m.content)]
|
||||||
|
} else {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
|
||||||
|
build_content_blocks(&m.content, &m.media_refs)
|
||||||
|
};
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
|
||||||
|
Message {
|
||||||
role: m.role.clone(),
|
role: m.role.clone(),
|
||||||
content: m.content.clone(),
|
content,
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
tool_call_id: m.tool_call_id.clone(),
|
||||||
name: m.tool_name.clone(),
|
name: m.tool_name.clone(),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect()
|
||||||
|
|
||||||
let tools = if self.tools.has_tools() {
|
|
||||||
Some(self.tools.get_definitions())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: messages_for_llm,
|
|
||||||
temperature: None,
|
|
||||||
max_tokens: None,
|
|
||||||
tools,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = (*self.provider).chat(request).await
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::error!(error = %e, "LLM continuation request failed");
|
|
||||||
AgentError::LlmError(e.to_string())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
|
||||||
Ok(assistant_message)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
|
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
|
||||||
|
let batches = self.partition_tool_batches(tool_calls);
|
||||||
let mut results = Vec::with_capacity(tool_calls.len());
|
let mut results = Vec::with_capacity(tool_calls.len());
|
||||||
|
|
||||||
for tool_call in tool_calls {
|
for batch in batches {
|
||||||
let result = self.execute_tool(tool_call).await;
|
if batch.len() == 1 {
|
||||||
results.push(result);
|
// Single tool — run directly (no spawn overhead)
|
||||||
|
results.push(self.execute_tool(&batch[0]).await);
|
||||||
|
} else {
|
||||||
|
// Multiple tools — run in parallel via join_all
|
||||||
|
use futures_util::future::join_all;
|
||||||
|
let futures = batch.iter().map(|tc| self.execute_tool(tc));
|
||||||
|
let batch_results = join_all(futures).await;
|
||||||
|
results.extend(batch_results);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
results
|
results
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Partition tool calls into batches based on concurrency safety.
|
||||||
|
///
|
||||||
|
/// `concurrency_safe` tools are grouped together; each `exclusive` tool
|
||||||
|
/// runs in its own batch. This matches the approach used in Nanobot's
|
||||||
|
/// `_partition_tool_batches` and Zeroclaw's `parallel_tools` config.
|
||||||
|
fn partition_tool_batches(&self, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
|
||||||
|
let mut batches: Vec<Vec<ToolCall>> = Vec::new();
|
||||||
|
let mut current: Vec<ToolCall> = Vec::new();
|
||||||
|
|
||||||
|
for tc in tool_calls {
|
||||||
|
let concurrency_safe = self
|
||||||
|
.tools
|
||||||
|
.get(&tc.name)
|
||||||
|
.map(|t| t.concurrency_safe())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if concurrency_safe {
|
||||||
|
current.push(tc.clone());
|
||||||
|
} else {
|
||||||
|
if !current.is_empty() {
|
||||||
|
batches.push(std::mem::take(&mut current));
|
||||||
|
}
|
||||||
|
batches.push(vec![tc.clone()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
batches.push(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
batches
|
||||||
|
}
|
||||||
|
|
||||||
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
|
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
|
||||||
let tool = match self.tools.get(&tool_call.name) {
|
let tool = match self.tools.get(&tool_call.name) {
|
||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
@ -188,3 +267,140 @@ impl std::fmt::Display for AgentError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for AgentError {}
|
impl std::error::Error for AgentError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
use crate::tools::ToolRegistry;
|
||||||
|
use crate::tools::CalculatorTool;
|
||||||
|
use crate::tools::BashTool;
|
||||||
|
use crate::tools::FileReadTool;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn make_tc(name: &str, args: serde_json::Value) -> ToolCall {
|
||||||
|
ToolCall {
|
||||||
|
id: format!("tc_{}", name),
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verify that partition_tool_batches groups concurrency-safe tools together
|
||||||
|
/// and isolates exclusive tools, matching the nanobot/zeroclaw approach.
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_mixes_safe_and_exclusive() {
|
||||||
|
let registry = Arc::new({
|
||||||
|
let mut r = ToolRegistry::new();
|
||||||
|
r.register(CalculatorTool::new()); // concurrency_safe = true
|
||||||
|
r.register(BashTool::new()); // concurrency_safe = false (exclusive)
|
||||||
|
r.register(FileReadTool::new()); // concurrency_safe = true
|
||||||
|
r
|
||||||
|
});
|
||||||
|
|
||||||
|
// agent_loop needs a provider to construct; test the partitioning logic directly
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
make_tc("bash", json!({"command": "ls"})),
|
||||||
|
make_tc("file_read", json!({"path": "/tmp/foo"})),
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Expected:
|
||||||
|
// batch 1: calculator (safe, first run)
|
||||||
|
// batch 2: bash (exclusive, alone)
|
||||||
|
// batch 3: file_read, calculator (both safe, run together)
|
||||||
|
let batches = partition_for_test(®istry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 3);
|
||||||
|
assert_eq!(batches[0].len(), 1);
|
||||||
|
assert_eq!(batches[0][0].name, "calculator");
|
||||||
|
assert_eq!(batches[1].len(), 1);
|
||||||
|
assert_eq!(batches[1][0].name, "bash");
|
||||||
|
assert_eq!(batches[2].len(), 2);
|
||||||
|
assert_eq!(batches[2][0].name, "file_read");
|
||||||
|
assert_eq!(batches[2][1].name, "calculator");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// All-safe tool calls should produce a single batch (parallel execution).
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_all_safe_single_batch() {
|
||||||
|
let registry = Arc::new({
|
||||||
|
let mut r = ToolRegistry::new();
|
||||||
|
r.register(CalculatorTool::new());
|
||||||
|
r.register(FileReadTool::new());
|
||||||
|
r
|
||||||
|
});
|
||||||
|
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
make_tc("file_read", json!({"path": "/tmp/foo"})),
|
||||||
|
];
|
||||||
|
|
||||||
|
let batches = partition_for_test(®istry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 1);
|
||||||
|
assert_eq!(batches[0].len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// All-exclusive tool calls should each get their own batch (sequential execution).
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_all_exclusive_separate_batches() {
|
||||||
|
let registry = Arc::new({
|
||||||
|
let mut r = ToolRegistry::new();
|
||||||
|
r.register(BashTool::new());
|
||||||
|
r
|
||||||
|
});
|
||||||
|
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("bash", json!({"command": "ls"})),
|
||||||
|
make_tc("bash", json!({"command": "pwd"})),
|
||||||
|
];
|
||||||
|
|
||||||
|
let batches = partition_for_test(®istry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 2);
|
||||||
|
assert_eq!(batches[0].len(), 1);
|
||||||
|
assert_eq!(batches[1].len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unknown tools (not in registry) default to non-concurrency-safe (single batch).
|
||||||
|
#[test]
|
||||||
|
fn test_partition_batches_unknown_tool_gets_own_batch() {
|
||||||
|
let registry = Arc::new(ToolRegistry::new());
|
||||||
|
|
||||||
|
let tcs = vec![
|
||||||
|
make_tc("calculator", json!({})),
|
||||||
|
make_tc("unknown_tool", json!({})),
|
||||||
|
];
|
||||||
|
|
||||||
|
let batches = partition_for_test(®istry, &tcs);
|
||||||
|
assert_eq!(batches.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expose partition logic for testing without needing a full AgentLoop.
|
||||||
|
fn partition_for_test(registry: &Arc<ToolRegistry>, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
|
||||||
|
let mut batches: Vec<Vec<ToolCall>> = Vec::new();
|
||||||
|
let mut current: Vec<ToolCall> = Vec::new();
|
||||||
|
|
||||||
|
for tc in tool_calls {
|
||||||
|
let concurrency_safe = registry
|
||||||
|
.get(&tc.name)
|
||||||
|
.map(|t| t.concurrency_safe())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if concurrency_safe {
|
||||||
|
current.push(tc.clone());
|
||||||
|
} else {
|
||||||
|
if !current.is_empty() {
|
||||||
|
batches.push(std::mem::take(&mut current));
|
||||||
|
}
|
||||||
|
batches.push(vec![tc.clone()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
batches.push(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
batches
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -31,6 +31,7 @@ impl OutboundDispatcher {
|
|||||||
|
|
||||||
loop {
|
loop {
|
||||||
let msg = self.bus.consume_outbound().await;
|
let msg = self.bus.consume_outbound().await;
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
channel = %msg.channel,
|
channel = %msg.channel,
|
||||||
chat_id = %msg.chat_id,
|
chat_id = %msg.chat_id,
|
||||||
|
|||||||
@ -2,7 +2,60 @@ use std::collections::HashMap;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// ChatMessage - Legacy type used by AgentLoop for LLM conversation history
|
// ContentBlock - Multimodal content representation (OpenAI-style)
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ContentBlock {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text { text: String },
|
||||||
|
#[serde(rename = "image_url")]
|
||||||
|
ImageUrl { image_url: ImageUrlBlock },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ImageUrlBlock {
|
||||||
|
pub url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContentBlock {
|
||||||
|
pub fn text(content: impl Into<String>) -> Self {
|
||||||
|
Self::Text { text: content.into() }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn image_url(url: impl Into<String>) -> Self {
|
||||||
|
Self::ImageUrl {
|
||||||
|
image_url: ImageUrlBlock { url: url.into() },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// MediaItem - Media metadata for messages
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MediaItem {
|
||||||
|
pub path: String, // Local file path
|
||||||
|
pub media_type: String, // "image", "audio", "file", "video"
|
||||||
|
pub mime_type: Option<String>,
|
||||||
|
pub original_key: Option<String>, // Feishu file_key for download
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MediaItem {
|
||||||
|
pub fn new(path: impl Into<String>, media_type: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
path: path.into(),
|
||||||
|
media_type: media_type.into(),
|
||||||
|
mime_type: None,
|
||||||
|
original_key: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// ChatMessage - Used by AgentLoop for LLM conversation history
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@ -10,6 +63,7 @@ pub struct ChatMessage {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
pub media_refs: Vec<String>, // Paths to media files for context
|
||||||
pub timestamp: i64,
|
pub timestamp: i64,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
@ -23,6 +77,19 @@ impl ChatMessage {
|
|||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
|
timestamp: current_timestamp(),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn user_with_media(content: impl Into<String>, media_refs: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: content.into(),
|
||||||
|
media_refs,
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
@ -34,6 +101,7 @@ impl ChatMessage {
|
|||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
@ -45,6 +113,7 @@ impl ChatMessage {
|
|||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
@ -56,6 +125,7 @@ impl ChatMessage {
|
|||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
tool_name: Some(tool_name.into()),
|
tool_name: Some(tool_name.into()),
|
||||||
@ -74,8 +144,11 @@ pub struct InboundMessage {
|
|||||||
pub chat_id: String,
|
pub chat_id: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub timestamp: i64,
|
pub timestamp: i64,
|
||||||
pub media: Vec<String>,
|
pub media: Vec<MediaItem>,
|
||||||
|
/// Channel-specific data used internally by the channel (not forwarded).
|
||||||
pub metadata: HashMap<String, String>,
|
pub metadata: HashMap<String, String>,
|
||||||
|
/// Data forwarded from inbound to outbound (copied to OutboundMessage.metadata by gateway).
|
||||||
|
pub forwarded_metadata: HashMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InboundMessage {
|
impl InboundMessage {
|
||||||
@ -94,7 +167,7 @@ pub struct OutboundMessage {
|
|||||||
pub chat_id: String,
|
pub chat_id: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub reply_to: Option<String>,
|
pub reply_to: Option<String>,
|
||||||
pub media: Vec<String>,
|
pub media: Vec<MediaItem>,
|
||||||
pub metadata: HashMap<String, String>,
|
pub metadata: HashMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@ pub mod dispatcher;
|
|||||||
pub mod message;
|
pub mod message;
|
||||||
|
|
||||||
pub use dispatcher::OutboundDispatcher;
|
pub use dispatcher::OutboundDispatcher;
|
||||||
pub use message::{ChatMessage, InboundMessage, OutboundMessage};
|
pub use message::{ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage};
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
@ -33,6 +33,8 @@ impl MessageBus {
|
|||||||
|
|
||||||
/// Publish an inbound message (Channel -> Bus)
|
/// Publish an inbound message (Channel -> Bus)
|
||||||
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
|
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, content_len = %msg.content.len(), media_count = %msg.media.len(), "Bus: publishing inbound message");
|
||||||
self.inbound_tx
|
self.inbound_tx
|
||||||
.send(msg)
|
.send(msg)
|
||||||
.await
|
.await
|
||||||
@ -41,16 +43,21 @@ impl MessageBus {
|
|||||||
|
|
||||||
/// Consume an inbound message (Agent -> Bus)
|
/// Consume an inbound message (Agent -> Bus)
|
||||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||||
self.inbound_rx
|
let msg = self.inbound_rx
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.recv()
|
.recv()
|
||||||
.await
|
.await
|
||||||
.expect("bus inbound closed")
|
.expect("bus inbound closed");
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
|
||||||
|
msg
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Publish an outbound message (Agent -> Bus)
|
/// Publish an outbound message (Agent -> Bus)
|
||||||
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
|
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(channel = %msg.channel, chat_id = %msg.chat_id, content_len = %msg.content.len(), "Bus: publishing outbound message");
|
||||||
self.outbound_tx
|
self.outbound_tx
|
||||||
.send(msg)
|
.send(msg)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@ -62,37 +62,18 @@ pub trait Channel: Send + Sync + 'static {
|
|||||||
async fn handle_and_publish(
|
async fn handle_and_publish(
|
||||||
&self,
|
&self,
|
||||||
bus: &Arc<MessageBus>,
|
bus: &Arc<MessageBus>,
|
||||||
sender_id: &str,
|
msg: &InboundMessage,
|
||||||
chat_id: &str,
|
|
||||||
content: &str,
|
|
||||||
) -> Result<(), ChannelError> {
|
) -> Result<(), ChannelError> {
|
||||||
if !self.is_allowed(sender_id) {
|
if !self.is_allowed(&msg.sender_id) {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
channel = %self.name(),
|
channel = %self.name(),
|
||||||
sender = %sender_id,
|
sender = %msg.sender_id,
|
||||||
"Access denied"
|
"Access denied"
|
||||||
);
|
);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let msg = InboundMessage {
|
bus.publish_inbound(msg.clone()).await?;
|
||||||
channel: self.name().to_string(),
|
|
||||||
sender_id: sender_id.to_string(),
|
|
||||||
chat_id: chat_id.to_string(),
|
|
||||||
content: content.to_string(),
|
|
||||||
timestamp: current_timestamp(),
|
|
||||||
media: vec![],
|
|
||||||
metadata: std::collections::HashMap::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
bus.publish_inbound(msg).await?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn current_timestamp() -> i64 {
|
|
||||||
std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.unwrap()
|
|
||||||
.as_millis() as i64
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,17 +1,33 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::sync::{broadcast, RwLock};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use prost::{Message as ProstMessage, bytes::Bytes};
|
use prost::{Message as ProstMessage, bytes::Bytes};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::sync::{broadcast, RwLock};
|
||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundMessage};
|
use crate::bus::{MessageBus, MediaItem, OutboundMessage};
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
||||||
|
|
||||||
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
||||||
const FEISHU_WS_BASE: &str = "https://open.feishu.cn";
|
const FEISHU_WS_BASE: &str = "https://open.feishu.cn";
|
||||||
|
|
||||||
|
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
|
||||||
|
/// If no binary frame (pong or event) is received within this window, reconnect.
|
||||||
|
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
|
||||||
|
/// Refresh tenant token this many seconds before the announced expiry.
|
||||||
|
const TOKEN_REFRESH_SKEW: Duration = Duration::from_secs(120);
|
||||||
|
/// Default tenant token TTL when `expire`/`expires_in` is absent.
|
||||||
|
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(7200);
|
||||||
|
/// Feishu API business code for expired/invalid tenant access token.
|
||||||
|
const INVALID_ACCESS_TOKEN_CODE: i32 = 99991663;
|
||||||
|
/// Dedup cache TTL (30 minutes).
|
||||||
|
const DEDUP_CACHE_TTL: Duration = Duration::from_secs(30 * 60);
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
// Protobuf types for Feishu WebSocket protocol (pbbp2.proto)
|
// Protobuf types for Feishu WebSocket protocol (pbbp2.proto)
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
@ -124,6 +140,13 @@ struct LarkMessage {
|
|||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Cached tenant token with proactive refresh metadata.
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CachedTenantToken {
|
||||||
|
value: String,
|
||||||
|
refresh_after: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FeishuChannel {
|
pub struct FeishuChannel {
|
||||||
config: FeishuChannelConfig,
|
config: FeishuChannelConfig,
|
||||||
@ -131,6 +154,10 @@ pub struct FeishuChannel {
|
|||||||
running: Arc<RwLock<bool>>,
|
running: Arc<RwLock<bool>>,
|
||||||
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
||||||
connected: Arc<RwLock<bool>>,
|
connected: Arc<RwLock<bool>>,
|
||||||
|
/// Cached tenant access token with proactive refresh.
|
||||||
|
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
|
||||||
|
/// Dedup cache: WS message_ids seen in the last ~30 min.
|
||||||
|
seen_message_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parsed message data from a Feishu frame
|
/// Parsed message data from a Feishu frame
|
||||||
@ -139,6 +166,7 @@ struct ParsedMessage {
|
|||||||
open_id: String,
|
open_id: String,
|
||||||
chat_id: String,
|
chat_id: String,
|
||||||
content: String,
|
content: String,
|
||||||
|
media: Option<MediaItem>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeishuChannel {
|
impl FeishuChannel {
|
||||||
@ -152,6 +180,8 @@ impl FeishuChannel {
|
|||||||
running: Arc::new(RwLock::new(false)),
|
running: Arc::new(RwLock::new(false)),
|
||||||
shutdown_tx: Arc::new(RwLock::new(None)),
|
shutdown_tx: Arc::new(RwLock::new(None)),
|
||||||
connected: Arc::new(RwLock::new(false)),
|
connected: Arc::new(RwLock::new(false)),
|
||||||
|
tenant_token: Arc::new(RwLock::new(None)),
|
||||||
|
seen_message_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -188,8 +218,42 @@ impl FeishuChannel {
|
|||||||
Ok((ep.url, client_config))
|
Ok((ep.url, client_config))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get tenant access token
|
/// Get tenant access token (cached with proactive refresh).
|
||||||
async fn get_tenant_token(&self) -> Result<String, ChannelError> {
|
async fn get_tenant_access_token(&self) -> Result<String, ChannelError> {
|
||||||
|
// 1. Check cache
|
||||||
|
{
|
||||||
|
let cached = self.tenant_token.read().await;
|
||||||
|
if let Some(ref token) = *cached {
|
||||||
|
if Instant::now() < token.refresh_after {
|
||||||
|
return Ok(token.value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Fetch new token
|
||||||
|
let (token, ttl) = self.fetch_new_token().await?;
|
||||||
|
|
||||||
|
// 3. Cache with proactive refresh time (提前 120 秒)
|
||||||
|
let refresh_after = Instant::now() + ttl.saturating_sub(TOKEN_REFRESH_SKEW);
|
||||||
|
{
|
||||||
|
let mut cached = self.tenant_token.write().await;
|
||||||
|
*cached = Some(CachedTenantToken {
|
||||||
|
value: token.clone(),
|
||||||
|
refresh_after,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invalidate cached token (called when API reports expired tenant token).
|
||||||
|
async fn invalidate_token(&self) {
|
||||||
|
let mut cached = self.tenant_token.write().await;
|
||||||
|
*cached = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch a new tenant access token from Feishu.
|
||||||
|
async fn fetch_new_token(&self) -> Result<(String, Duration), ChannelError> {
|
||||||
let resp = self.http_client
|
let resp = self.http_client
|
||||||
.post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE))
|
.post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
@ -205,6 +269,7 @@ impl FeishuChannel {
|
|||||||
struct TokenResponse {
|
struct TokenResponse {
|
||||||
code: i32,
|
code: i32,
|
||||||
tenant_access_token: Option<String>,
|
tenant_access_token: Option<String>,
|
||||||
|
expire: Option<i64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
let token_resp: TokenResponse = resp
|
let token_resp: TokenResponse = resp
|
||||||
@ -216,15 +281,401 @@ impl FeishuChannel {
|
|||||||
return Err(ChannelError::Other("Auth failed".to_string()));
|
return Err(ChannelError::Other("Auth failed".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
token_resp.tenant_access_token
|
let token = token_resp.tenant_access_token
|
||||||
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))
|
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))?;
|
||||||
|
|
||||||
|
let ttl = token_resp.expire
|
||||||
|
.and_then(|v| u64::try_from(v).ok())
|
||||||
|
.map(Duration::from_secs)
|
||||||
|
.unwrap_or(DEFAULT_TOKEN_TTL);
|
||||||
|
|
||||||
|
Ok((token, ttl))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if message_id has been seen (dedup), and mark it as seen if not.
|
||||||
|
/// Returns true if the message was already processed.
|
||||||
|
/// Note: GC of stale entries is handled in the heartbeat timeout_check loop.
|
||||||
|
async fn is_message_seen(&self, message_id: &str) -> bool {
|
||||||
|
let mut seen = self.seen_message_ids.write().await;
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
if seen.contains_key(message_id) {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
seen.insert(message_id.to_string(), now);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download media and save locally, return (description, media_item)
|
||||||
|
async fn download_media(
|
||||||
|
&self,
|
||||||
|
msg_type: &str,
|
||||||
|
content_json: &serde_json::Value,
|
||||||
|
message_id: &str,
|
||||||
|
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||||
|
let media_dir = Path::new(&self.config.media_dir);
|
||||||
|
tokio::fs::create_dir_all(media_dir).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to create media dir: {}", e)))?;
|
||||||
|
|
||||||
|
match msg_type {
|
||||||
|
"image" => self.download_image(content_json, message_id, media_dir).await,
|
||||||
|
"audio" | "file" | "media" => self.download_file(content_json, message_id, media_dir, msg_type).await,
|
||||||
|
_ => Ok((format!("[unsupported media type: {}]", msg_type), None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download image from Feishu
|
||||||
|
async fn download_image(
|
||||||
|
&self,
|
||||||
|
content_json: &serde_json::Value,
|
||||||
|
message_id: &str,
|
||||||
|
media_dir: &Path,
|
||||||
|
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||||
|
let image_key = content_json.get("image_key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| ChannelError::Other("No image_key in message".to_string()))?;
|
||||||
|
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
// Use message resource API for downloading message images
|
||||||
|
let url = format!("{}/im/v1/messages/{}/resources/{}?type=image", FEISHU_API_BASE, message_id, image_key);
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(url = %url, image_key = %image_key, message_id = %message_id, "Downloading image from Feishu via message resource API");
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.get(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(status = %status, "Image download response status");
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = resp.text().await.unwrap_or_default();
|
||||||
|
return Err(ChannelError::Other(format!("Image download failed {}: {}", status, error_text)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = resp.bytes().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))?
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(data_len = %data.len(), "Downloaded image data");
|
||||||
|
|
||||||
|
let filename = format!("{}_{}.jpg", message_id, &image_key[..8.min(image_key.len())]);
|
||||||
|
let file_path = media_dir.join(&filename);
|
||||||
|
|
||||||
|
tokio::fs::write(&file_path, &data).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?;
|
||||||
|
|
||||||
|
let media_item = MediaItem::new(
|
||||||
|
file_path.to_string_lossy().to_string(),
|
||||||
|
"image",
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image");
|
||||||
|
|
||||||
|
Ok((format!("[image: {}]", filename), Some(media_item)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download file/audio from Feishu
|
||||||
|
async fn download_file(
|
||||||
|
&self,
|
||||||
|
content_json: &serde_json::Value,
|
||||||
|
message_id: &str,
|
||||||
|
media_dir: &Path,
|
||||||
|
file_type: &str,
|
||||||
|
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||||
|
let file_key = content_json.get("file_key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| ChannelError::Other("No file_key in message".to_string()))?;
|
||||||
|
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
// Use message resource API for downloading message files
|
||||||
|
let url = format!("{}/im/v1/messages/{}/resources/{}?type=file", FEISHU_API_BASE, message_id, file_key);
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(url = %url, file_key = %file_key, message_id = %message_id, "Downloading file from Feishu via message resource API");
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.get(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let error_text = resp.text().await.unwrap_or_default();
|
||||||
|
return Err(ChannelError::Other(format!("File download failed {}: {}", status, error_text)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = resp.bytes().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))?
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
let extension = match file_type {
|
||||||
|
"audio" => "mp3",
|
||||||
|
"video" => "mp4",
|
||||||
|
_ => "bin",
|
||||||
|
};
|
||||||
|
let filename = format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension);
|
||||||
|
let file_path = media_dir.join(&filename);
|
||||||
|
|
||||||
|
tokio::fs::write(&file_path, &data).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?;
|
||||||
|
|
||||||
|
let media_item = MediaItem::new(
|
||||||
|
file_path.to_string_lossy().to_string(),
|
||||||
|
file_type,
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file");
|
||||||
|
|
||||||
|
Ok((format!("[{}: {}]", file_type, filename), Some(media_item)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upload image to Feishu and return the image_key
|
||||||
|
async fn upload_image(&self, file_path: &str) -> Result<String, ChannelError> {
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
let mime = mime_guess::from_path(file_path)
|
||||||
|
.first_or_octet_stream()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let file_name = std::path::Path::new(file_path)
|
||||||
|
.file_name()
|
||||||
|
.and_then(|n| n.to_str())
|
||||||
|
.unwrap_or("image.jpg");
|
||||||
|
|
||||||
|
let file_data = tokio::fs::read(file_path).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?;
|
||||||
|
|
||||||
|
let part = reqwest::multipart::Part::bytes(file_data)
|
||||||
|
.file_name(file_name.to_string())
|
||||||
|
.mime_str(&mime)
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Invalid mime type: {}", e)))?;
|
||||||
|
|
||||||
|
let form = reqwest::multipart::Form::new()
|
||||||
|
.text("image_type", "message".to_string())
|
||||||
|
.part("image", part);
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.post(format!("{}/im/v1/images/upload", FEISHU_API_BASE))
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.multipart(form)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct UploadResp {
|
||||||
|
code: i32,
|
||||||
|
msg: Option<String>,
|
||||||
|
data: Option<UploadData>,
|
||||||
|
}
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct UploadData {
|
||||||
|
image_key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: UploadResp = resp.json().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?;
|
||||||
|
|
||||||
|
if result.code != 0 {
|
||||||
|
return Err(ChannelError::Other(format!(
|
||||||
|
"Upload image failed: code={} msg={}",
|
||||||
|
result.code,
|
||||||
|
result.msg.as_deref().unwrap_or("unknown")
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
result.data
|
||||||
|
.map(|d| d.image_key)
|
||||||
|
.ok_or_else(|| ChannelError::Other("No image_key in response".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upload file to Feishu and return the file_key
|
||||||
|
async fn upload_file(&self, file_path: &str) -> Result<String, ChannelError> {
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
let file_name = std::path::Path::new(file_path)
|
||||||
|
.file_name()
|
||||||
|
.and_then(|n| n.to_str())
|
||||||
|
.unwrap_or("file.bin");
|
||||||
|
|
||||||
|
let extension = std::path::Path::new(file_path)
|
||||||
|
.extension()
|
||||||
|
.and_then(|e| e.to_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_lowercase();
|
||||||
|
|
||||||
|
let file_type = match extension.as_str() {
|
||||||
|
"mp3" | "m4a" | "wav" | "ogg" => "audio",
|
||||||
|
"mp4" | "mov" | "avi" | "mkv" => "video",
|
||||||
|
"pdf" | "doc" | "docx" | "xls" | "xlsx" => "doc",
|
||||||
|
_ => "file",
|
||||||
|
};
|
||||||
|
|
||||||
|
let file_data = tokio::fs::read(file_path).await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?;
|
||||||
|
|
||||||
|
let part = reqwest::multipart::Part::bytes(file_data)
|
||||||
|
.file_name(file_name.to_string())
|
||||||
|
.mime_str("application/octet-stream")
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Invalid mime type: {}", e)))?;
|
||||||
|
|
||||||
|
let form = reqwest::multipart::Form::new()
|
||||||
|
.text("file_type", file_type.to_string())
|
||||||
|
.text("file_name", file_name.to_string())
|
||||||
|
.part("file", part);
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.post(format!("{}/im/v1/files", FEISHU_API_BASE))
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.multipart(form)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Upload file HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct UploadResp {
|
||||||
|
code: i32,
|
||||||
|
msg: Option<String>,
|
||||||
|
data: Option<UploadData>,
|
||||||
|
}
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct UploadData {
|
||||||
|
file_key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: UploadResp = resp.json().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?;
|
||||||
|
|
||||||
|
if result.code != 0 {
|
||||||
|
return Err(ChannelError::Other(format!(
|
||||||
|
"Upload file failed: code={} msg={}",
|
||||||
|
result.code,
|
||||||
|
result.msg.as_deref().unwrap_or("unknown")
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
result.data
|
||||||
|
.map(|d| d.file_key)
|
||||||
|
.ok_or_else(|| ChannelError::Other("No file_key in response".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a reaction emoji to a message and store the reaction_id for later removal.
|
||||||
|
/// Returns the reaction_id if successful, None otherwise.
|
||||||
|
async fn add_reaction(&self, message_id: &str) -> Result<Option<String>, ChannelError> {
|
||||||
|
let emoji = self.config.reaction_emoji.as_str();
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.post(format!("{}/im/v1/messages/{}/reactions", FEISHU_API_BASE, message_id))
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"reaction_type": { "emoji_type": emoji }
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Add reaction HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ReactionResp {
|
||||||
|
code: i32,
|
||||||
|
msg: Option<String>,
|
||||||
|
data: Option<ReactionData>,
|
||||||
|
}
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ReactionData {
|
||||||
|
reaction_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ReactionResp = resp.json().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse reaction response error: {}", e)))?;
|
||||||
|
|
||||||
|
if result.code != 0 {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to add reaction to message {}: code={} msg={}",
|
||||||
|
message_id,
|
||||||
|
result.code,
|
||||||
|
result.msg.as_deref().unwrap_or("unknown")
|
||||||
|
);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let reaction_id = result.data.and_then(|d| d.reaction_id);
|
||||||
|
Ok(reaction_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove reaction using feishu metadata propagated through OutboundMessage.
|
||||||
|
/// Reads feishu.message_id and feishu.reaction_id from metadata.
|
||||||
|
async fn remove_reaction_from_metadata(&self, metadata: &std::collections::HashMap<String, String>) {
|
||||||
|
let (message_id, reaction_id) = match (
|
||||||
|
metadata.get("feishu.message_id"),
|
||||||
|
metadata.get("feishu.reaction_id"),
|
||||||
|
) {
|
||||||
|
(Some(msg_id), Some(rid)) => (msg_id.clone(), rid.clone()),
|
||||||
|
_ => return,
|
||||||
|
};
|
||||||
|
if let Err(e) = self.remove_reaction(&message_id, &reaction_id).await {
|
||||||
|
tracing::debug!(error = %e, message_id = %message_id, "Failed to remove reaction");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a reaction emoji from a message.
|
||||||
|
async fn remove_reaction(&self, message_id: &str, reaction_id: &str) -> Result<(), ChannelError> {
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.delete(format!("{}/im/v1/messages/{}/reactions/{}", FEISHU_API_BASE, message_id, reaction_id))
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Remove reaction HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ReactionResp {
|
||||||
|
code: i32,
|
||||||
|
msg: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ReactionResp = resp.json().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse remove reaction response error: {}", e)))?;
|
||||||
|
|
||||||
|
if result.code != 0 {
|
||||||
|
tracing::debug!(
|
||||||
|
"Failed to remove reaction {} from message {}: code={} msg={}",
|
||||||
|
reaction_id,
|
||||||
|
message_id,
|
||||||
|
result.code,
|
||||||
|
result.msg.as_deref().unwrap_or("unknown")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send a text message to Feishu chat (implements Channel trait)
|
/// Send a text message to Feishu chat (implements Channel trait)
|
||||||
async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
||||||
let token = self.get_tenant_token().await?;
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
let text_content = serde_json::json!({ "text": content }).to_string();
|
// Feishu text messages have content limits (~64KB).
|
||||||
|
// Truncate if content is too long to avoid API error 230001.
|
||||||
|
const MAX_TEXT_LENGTH: usize = 60_000;
|
||||||
|
let truncated = if content.len() > MAX_TEXT_LENGTH {
|
||||||
|
format!("{}...\n\n[Content truncated due to length limit]", &content[..MAX_TEXT_LENGTH])
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let text_content = serde_json::json!({ "text": truncated }).to_string();
|
||||||
|
|
||||||
let resp = self.http_client
|
let resp = self.http_client
|
||||||
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
|
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
|
||||||
@ -285,10 +736,15 @@ impl FeishuChannel {
|
|||||||
let payload = frame.payload.as_deref()
|
let payload = frame.payload.as_deref()
|
||||||
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(payload_len = %payload.len(), "Received frame payload");
|
||||||
|
|
||||||
let event: LarkEvent = serde_json::from_slice(payload)
|
let event: LarkEvent = serde_json::from_slice(payload)
|
||||||
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
|
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
|
||||||
|
|
||||||
let event_type = event.header.event_type.as_str();
|
let event_type = event.header.event_type.as_str();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(event_type = %event_type, "Received event type");
|
||||||
if event_type != "im.message.receive_v1" {
|
if event_type != "im.message.receive_v1" {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
@ -303,22 +759,74 @@ impl FeishuChannel {
|
|||||||
|
|
||||||
let message_id = payload_data.message.message_id.clone();
|
let message_id = payload_data.message.message_id.clone();
|
||||||
|
|
||||||
|
// Deduplication check
|
||||||
|
if self.is_message_seen(&message_id).await {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(message_id = %message_id, "Duplicate message, skipping");
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(message_id = %message_id, "Received Feishu message");
|
||||||
|
|
||||||
let open_id = payload_data.sender.sender_id.open_id
|
let open_id = payload_data.sender.sender_id.open_id
|
||||||
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
|
||||||
|
|
||||||
let msg = payload_data.message;
|
let msg = payload_data.message;
|
||||||
let chat_id = msg.chat_id.clone();
|
let chat_id = msg.chat_id.clone();
|
||||||
let msg_type = msg.message_type.as_str();
|
let msg_type = msg.message_type.as_str();
|
||||||
let content = parse_message_content(msg_type, &msg.content);
|
let raw_content = msg.content.clone();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(msg_type = %msg_type, chat_id = %chat_id, open_id = %open_id, "Parsing message content");
|
||||||
|
|
||||||
|
let (content, media) = self.parse_and_download_message(msg_type, &raw_content, &message_id).await?;
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
if let Some(ref m) = media {
|
||||||
|
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
|
||||||
|
}
|
||||||
|
|
||||||
Ok(Some(ParsedMessage {
|
Ok(Some(ParsedMessage {
|
||||||
message_id,
|
message_id,
|
||||||
open_id,
|
open_id,
|
||||||
chat_id,
|
chat_id,
|
||||||
content,
|
content,
|
||||||
|
media,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse message content and download media if needed
|
||||||
|
async fn parse_and_download_message(
|
||||||
|
&self,
|
||||||
|
msg_type: &str,
|
||||||
|
content: &str,
|
||||||
|
message_id: &str,
|
||||||
|
) -> Result<(String, Option<MediaItem>), ChannelError> {
|
||||||
|
match msg_type {
|
||||||
|
"text" => {
|
||||||
|
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
};
|
||||||
|
Ok((text, None))
|
||||||
|
}
|
||||||
|
"post" => {
|
||||||
|
let text = parse_post_content(content);
|
||||||
|
Ok((text, None))
|
||||||
|
}
|
||||||
|
"image" | "audio" | "file" | "media" => {
|
||||||
|
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
self.download_media(msg_type, &content_json, message_id).await
|
||||||
|
} else {
|
||||||
|
Ok((format!("[{}: content unavailable]", msg_type), None))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Ok((content.to_string(), None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Send acknowledgment for a message
|
/// Send acknowledgment for a message
|
||||||
async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> {
|
async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> {
|
||||||
let mut ack = frame.clone();
|
let mut ack = frame.clone();
|
||||||
@ -366,16 +874,20 @@ impl FeishuChannel {
|
|||||||
|
|
||||||
let ping_interval = client_config.ping_interval.unwrap_or(120).max(10);
|
let ping_interval = client_config.ping_interval.unwrap_or(120).max(10);
|
||||||
let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval));
|
let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval));
|
||||||
|
let mut timeout_check = tokio::time::interval(tokio::time::Duration::from_secs(10));
|
||||||
let mut seq: u64 = 1;
|
let mut seq: u64 = 1;
|
||||||
|
let mut last_recv = Instant::now();
|
||||||
|
|
||||||
// Consume the immediate tick
|
// Consume the immediate tick
|
||||||
ping_interval_tok.tick().await;
|
ping_interval_tok.tick().await;
|
||||||
|
timeout_check.tick().await;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
msg = read.next() => {
|
msg = read.next() => {
|
||||||
match msg {
|
match msg {
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
||||||
|
last_recv = Instant::now();
|
||||||
let bytes: Bytes = data;
|
let bytes: Bytes = data;
|
||||||
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
||||||
match self.handle_frame(&frame).await {
|
match self.handle_frame(&frame).await {
|
||||||
@ -385,12 +897,49 @@ impl FeishuChannel {
|
|||||||
tracing::error!(error = %e, "Failed to send ACK to Feishu");
|
tracing::error!(error = %e, "Failed to send ACK to Feishu");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add reaction emoji (await so we get the reaction_id for later removal)
|
||||||
|
let message_id = parsed.message_id.clone();
|
||||||
|
let reaction_id = match self.add_reaction(&message_id).await {
|
||||||
|
Ok(Some(rid)) => Some(rid),
|
||||||
|
Ok(None) => None,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(error = %e, message_id = %message_id, "Failed to add reaction");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// forwarded_metadata is copied to OutboundMessage.metadata by the gateway.
|
||||||
|
let mut forwarded_metadata = std::collections::HashMap::new();
|
||||||
|
forwarded_metadata.insert("feishu.message_id".to_string(), message_id.clone());
|
||||||
|
if let Some(ref rid) = reaction_id {
|
||||||
|
forwarded_metadata.insert("feishu.reaction_id".to_string(), rid.clone());
|
||||||
|
}
|
||||||
|
|
||||||
// Publish to bus asynchronously
|
// Publish to bus asynchronously
|
||||||
let channel = self.clone();
|
let channel = self.clone();
|
||||||
let bus = bus.clone();
|
let bus = bus.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = channel.handle_and_publish(&bus, &parsed.open_id, &parsed.chat_id, &parsed.content).await {
|
let media_count = if parsed.media.is_some() { 1 } else { 0 };
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
|
||||||
|
let msg = crate::bus::InboundMessage {
|
||||||
|
channel: "feishu".to_string(),
|
||||||
|
sender_id: parsed.open_id.clone(),
|
||||||
|
chat_id: parsed.chat_id.clone(),
|
||||||
|
content: parsed.content.clone(),
|
||||||
|
timestamp: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_millis() as i64,
|
||||||
|
media: parsed.media.map(|m| vec![m]).unwrap_or_default(),
|
||||||
|
metadata: std::collections::HashMap::new(),
|
||||||
|
forwarded_metadata,
|
||||||
|
};
|
||||||
|
if let Err(e) = channel.handle_and_publish(&bus, &msg).await {
|
||||||
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus");
|
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus");
|
||||||
|
} else {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Message published to bus successfully");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -402,6 +951,7 @@ impl FeishuChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => {
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => {
|
||||||
|
last_recv = Instant::now();
|
||||||
let pong = PbFrame {
|
let pong = PbFrame {
|
||||||
seq_id: seq.wrapping_add(1),
|
seq_id: seq.wrapping_add(1),
|
||||||
log_id: 0,
|
log_id: 0,
|
||||||
@ -415,7 +965,11 @@ impl FeishuChannel {
|
|||||||
};
|
};
|
||||||
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
|
||||||
}
|
}
|
||||||
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Pong(_))) => {
|
||||||
|
last_recv = Instant::now();
|
||||||
|
}
|
||||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!("Feishu WebSocket closed");
|
tracing::debug!("Feishu WebSocket closed");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -444,6 +998,16 @@ impl FeishuChannel {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = timeout_check.tick() => {
|
||||||
|
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
|
||||||
|
tracing::warn!("Feishu WebSocket heartbeat timeout, reconnecting");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// GC dedup cache: remove entries older than TTL (matches zeroclaw pattern)
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut seen = self.seen_message_ids.write().await;
|
||||||
|
seen.retain(|_, ts| now.duration_since(*ts) < DEDUP_CACHE_TTL);
|
||||||
|
}
|
||||||
_ = shutdown_rx.recv() => {
|
_ = shutdown_rx.recv() => {
|
||||||
tracing::info!("Feishu channel shutdown signal received");
|
tracing::info!("Feishu channel shutdown signal received");
|
||||||
break;
|
break;
|
||||||
@ -456,16 +1020,7 @@ impl FeishuChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_message_content(msg_type: &str, content: &str) -> String {
|
fn parse_post_content(content: &str) -> String {
|
||||||
match msg_type {
|
|
||||||
"text" => {
|
|
||||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
|
||||||
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
|
|
||||||
} else {
|
|
||||||
content.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"post" => {
|
|
||||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
let mut texts = vec![];
|
let mut texts = vec![];
|
||||||
if let Some(post) = parsed.get("post") {
|
if let Some(post) = parsed.get("post") {
|
||||||
@ -492,9 +1047,6 @@ fn parse_message_content(msg_type: &str, content: &str) -> String {
|
|||||||
content.to_string()
|
content.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => content.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Channel for FeishuChannel {
|
impl Channel for FeishuChannel {
|
||||||
@ -574,6 +1126,119 @@ impl Channel for FeishuChannel {
|
|||||||
let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) };
|
let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) };
|
||||||
let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||||
|
|
||||||
self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await
|
// If no media, send text only
|
||||||
|
if msg.media.is_empty() {
|
||||||
|
let result = self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await;
|
||||||
|
// Remove pending reaction after sending (using metadata propagated from inbound)
|
||||||
|
self.remove_reaction_from_metadata(&msg.metadata).await;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multimodal message - send with media
|
||||||
|
let token = self.get_tenant_access_token().await?;
|
||||||
|
|
||||||
|
// Build content with media references
|
||||||
|
let mut content_parts = Vec::new();
|
||||||
|
|
||||||
|
// Add text content if present (truncate if too long for Feishu)
|
||||||
|
if !msg.content.is_empty() {
|
||||||
|
const MAX_TEXT_LENGTH: usize = 60_000;
|
||||||
|
let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH {
|
||||||
|
format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..MAX_TEXT_LENGTH])
|
||||||
|
} else {
|
||||||
|
msg.content.clone()
|
||||||
|
};
|
||||||
|
content_parts.push(serde_json::json!({
|
||||||
|
"tag": "text",
|
||||||
|
"text": truncated_text
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload and add media
|
||||||
|
for media_item in &msg.media {
|
||||||
|
let path = &media_item.path;
|
||||||
|
match media_item.media_type.as_str() {
|
||||||
|
"image" => {
|
||||||
|
match self.upload_image(path).await {
|
||||||
|
Ok(image_key) => {
|
||||||
|
content_parts.push(serde_json::json!({
|
||||||
|
"tag": "image",
|
||||||
|
"image_key": image_key
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, path = %path, "Failed to upload image");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"audio" | "file" | "video" => {
|
||||||
|
match self.upload_file(path).await {
|
||||||
|
Ok(file_key) => {
|
||||||
|
content_parts.push(serde_json::json!({
|
||||||
|
"tag": "file",
|
||||||
|
"file_key": file_key
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, path = %path, "Failed to upload file");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
tracing::warn!(media_type = %media_item.media_type, "Unsupported media type for sending");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no content parts after processing, just send empty text
|
||||||
|
if content_parts.is_empty() {
|
||||||
|
let result = self.send_message_to_feishu(receive_id, receive_id_type, "").await;
|
||||||
|
// Remove pending reaction after sending (using metadata propagated from inbound)
|
||||||
|
self.remove_reaction_from_metadata(&msg.metadata).await;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine message type
|
||||||
|
let has_image = msg.media.iter().any(|m| m.media_type == "image");
|
||||||
|
let msg_type = if has_image && msg.content.is_empty() {
|
||||||
|
"image"
|
||||||
|
} else {
|
||||||
|
"post"
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = serde_json::json!({
|
||||||
|
"content": content_parts
|
||||||
|
}).to_string();
|
||||||
|
|
||||||
|
let resp = self.http_client
|
||||||
|
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"receive_id": receive_id,
|
||||||
|
"msg_type": msg_type,
|
||||||
|
"content": content
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)))?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct SendResp {
|
||||||
|
code: i32,
|
||||||
|
msg: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
let send_resp: SendResp = resp.json().await
|
||||||
|
.map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?;
|
||||||
|
|
||||||
|
if send_resp.code != 0 {
|
||||||
|
return Err(ChannelError::Other(format!("Send multimodal message failed: code={} msg={}", send_resp.code, send_resp.msg)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove pending reaction after successfully sending
|
||||||
|
self.remove_reaction_from_metadata(&msg.metadata).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,6 +35,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
input.write_output(&format!("Error: {}", message)).await?;
|
input.write_output(&format!("Error: {}", message)).await?;
|
||||||
}
|
}
|
||||||
WsOutbound::SessionEstablished { session_id } => {
|
WsOutbound::SessionEstablished { session_id } => {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %session_id, "Session established");
|
tracing::debug!(session_id = %session_id, "Session established");
|
||||||
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,12 +28,26 @@ pub struct FeishuChannelConfig {
|
|||||||
pub allow_from: Vec<String>,
|
pub allow_from: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub agent: String,
|
pub agent: String,
|
||||||
|
#[serde(default = "default_media_dir")]
|
||||||
|
pub media_dir: String,
|
||||||
|
/// Emoji type for message reactions (e.g. "THUMBSUP", "OK", "EYES").
|
||||||
|
#[serde(default = "default_reaction_emoji")]
|
||||||
|
pub reaction_emoji: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_allow_from() -> Vec<String> {
|
fn default_allow_from() -> Vec<String> {
|
||||||
vec!["*".to_string()]
|
vec!["*".to_string()]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_media_dir() -> String {
|
||||||
|
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||||
|
home.join(".picobot/media/feishu").to_string_lossy().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_reaction_emoji() -> String {
|
||||||
|
"Typing".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ProviderConfig {
|
pub struct ProviderConfig {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
@ -59,6 +73,12 @@ pub struct ModelConfig {
|
|||||||
pub struct AgentConfig {
|
pub struct AgentConfig {
|
||||||
pub provider: String,
|
pub provider: String,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
#[serde(default = "default_max_iterations")]
|
||||||
|
pub max_iterations: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_iterations() -> u32 {
|
||||||
|
15
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -118,6 +138,7 @@ pub struct LLMProviderConfig {
|
|||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
pub model_extra: HashMap<String, serde_json::Value>,
|
pub model_extra: HashMap<String, serde_json::Value>,
|
||||||
|
pub max_iterations: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_default_config_path() -> PathBuf {
|
fn get_default_config_path() -> PathBuf {
|
||||||
@ -177,6 +198,7 @@ impl Config {
|
|||||||
temperature: model.temperature,
|
temperature: model.temperature,
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
model_extra: model.extra.clone(),
|
model_extra: model.extra.clone(),
|
||||||
|
max_iterations: agent.max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -53,11 +53,22 @@ impl GatewayState {
|
|||||||
tracing::info!("Inbound processor started");
|
tracing::info!("Inbound processor started");
|
||||||
loop {
|
loop {
|
||||||
let inbound = bus_for_inbound.consume_inbound().await;
|
let inbound = bus_for_inbound.consume_inbound().await;
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
channel = %inbound.channel,
|
channel = %inbound.channel,
|
||||||
chat_id = %inbound.chat_id,
|
chat_id = %inbound.chat_id,
|
||||||
|
sender = %inbound.sender_id,
|
||||||
|
content = %inbound.content,
|
||||||
|
media_count = %inbound.media.len(),
|
||||||
"Processing inbound message"
|
"Processing inbound message"
|
||||||
);
|
);
|
||||||
|
if !inbound.media.is_empty() {
|
||||||
|
for (i, m) in inbound.media.iter().enumerate() {
|
||||||
|
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Process via session manager
|
// Process via session manager
|
||||||
match session_manager.handle_message(
|
match session_manager.handle_message(
|
||||||
@ -65,15 +76,19 @@ impl GatewayState {
|
|||||||
&inbound.sender_id,
|
&inbound.sender_id,
|
||||||
&inbound.chat_id,
|
&inbound.chat_id,
|
||||||
&inbound.content,
|
&inbound.content,
|
||||||
|
inbound.media,
|
||||||
).await {
|
).await {
|
||||||
Ok(response_content) => {
|
Ok(response_content) => {
|
||||||
|
// Forward channel-specific metadata from inbound to outbound.
|
||||||
|
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
|
||||||
|
// without gateway needing channel-specific code.
|
||||||
let outbound = crate::bus::OutboundMessage {
|
let outbound = crate::bus::OutboundMessage {
|
||||||
channel: inbound.channel,
|
channel: inbound.channel.clone(),
|
||||||
chat_id: inbound.chat_id,
|
chat_id: inbound.chat_id.clone(),
|
||||||
content: response_content,
|
content: response_content,
|
||||||
reply_to: None,
|
reply_to: None,
|
||||||
media: vec![],
|
media: vec![],
|
||||||
metadata: std::collections::HashMap::new(),
|
metadata: inbound.forwarded_metadata,
|
||||||
};
|
};
|
||||||
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
||||||
tracing::error!(error = %e, "Failed to publish outbound");
|
tracing::error!(error = %e, "Failed to publish outbound");
|
||||||
|
|||||||
@ -7,7 +7,10 @@ use crate::bus::ChatMessage;
|
|||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::agent::{AgentLoop, AgentError};
|
use crate::agent::{AgentLoop, AgentError};
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::tools::{CalculatorTool, ToolRegistry};
|
use crate::tools::{
|
||||||
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||||
|
HttpRequestTool, ToolRegistry, WebFetchTool,
|
||||||
|
};
|
||||||
|
|
||||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||||
@ -56,6 +59,12 @@ impl Session {
|
|||||||
history.push(ChatMessage::user(content));
|
history.push(ChatMessage::user(content));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 添加带媒体的用户消息到指定 chat_id 的历史
|
||||||
|
pub fn add_user_message_with_media(&mut self, chat_id: &str, content: &str, media_refs: Vec<String>) {
|
||||||
|
let history = self.get_or_create_history(chat_id);
|
||||||
|
history.push(ChatMessage::user_with_media(content, media_refs));
|
||||||
|
}
|
||||||
|
|
||||||
/// 添加助手响应到指定 chat_id 的历史
|
/// 添加助手响应到指定 chat_id 的历史
|
||||||
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
|
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||||
@ -68,6 +77,7 @@ impl Session {
|
|||||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||||
let len = history.len();
|
let len = history.len();
|
||||||
history.clear();
|
history.clear();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,6 +86,7 @@ impl Session {
|
|||||||
pub fn clear_all_history(&mut self) {
|
pub fn clear_all_history(&mut self) {
|
||||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||||
self.chat_histories.clear();
|
self.chat_histories.clear();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,6 +117,17 @@ struct SessionManagerInner {
|
|||||||
fn default_tools() -> ToolRegistry {
|
fn default_tools() -> ToolRegistry {
|
||||||
let mut registry = ToolRegistry::new();
|
let mut registry = ToolRegistry::new();
|
||||||
registry.register(CalculatorTool::new());
|
registry.register(CalculatorTool::new());
|
||||||
|
registry.register(FileReadTool::new());
|
||||||
|
registry.register(FileWriteTool::new());
|
||||||
|
registry.register(FileEditTool::new());
|
||||||
|
registry.register(BashTool::new());
|
||||||
|
registry.register(HttpRequestTool::new(
|
||||||
|
vec!["*".to_string()], // 允许所有域名,实际使用时建议限制
|
||||||
|
1_000_000, // max_response_size
|
||||||
|
30, // timeout_secs
|
||||||
|
false, // allow_private_hosts
|
||||||
|
));
|
||||||
|
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
|
||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,6 +161,7 @@ impl SessionManager {
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(channel = %channel_name, "Creating new session");
|
tracing::debug!(channel = %channel_name, "Creating new session");
|
||||||
true
|
true
|
||||||
};
|
};
|
||||||
@ -184,13 +207,21 @@ impl SessionManager {
|
|||||||
_sender_id: &str,
|
_sender_id: &str,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
|
media: Vec<crate::bus::MediaItem>,
|
||||||
) -> Result<String, AgentError> {
|
) -> Result<String, AgentError> {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
channel = %channel_name,
|
channel = %channel_name,
|
||||||
chat_id = %chat_id,
|
chat_id = %chat_id,
|
||||||
content_len = content.len(),
|
content_len = content.len(),
|
||||||
|
media_count = %media.len(),
|
||||||
"Routing message to agent"
|
"Routing message to agent"
|
||||||
);
|
);
|
||||||
|
for (i, m) in media.iter().enumerate() {
|
||||||
|
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 确保 session 存在(可能需要重建)
|
// 确保 session 存在(可能需要重建)
|
||||||
self.ensure_session(channel_name).await?;
|
self.ensure_session(channel_name).await?;
|
||||||
@ -209,7 +240,14 @@ impl SessionManager {
|
|||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
// 添加用户消息到历史
|
// 添加用户消息到历史
|
||||||
|
if media.is_empty() {
|
||||||
session_guard.add_user_message(chat_id, content);
|
session_guard.add_user_message(chat_id, content);
|
||||||
|
} else {
|
||||||
|
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||||||
|
session_guard.add_user_message_with_media(chat_id, content, media_refs);
|
||||||
|
}
|
||||||
|
|
||||||
// 获取完整历史
|
// 获取完整历史
|
||||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||||
@ -224,6 +262,7 @@ impl SessionManager {
|
|||||||
response
|
response
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
channel = %channel_name,
|
channel = %channel_name,
|
||||||
chat_id = %chat_id,
|
chat_id = %chat_id,
|
||||||
|
|||||||
@ -62,6 +62,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
if let Ok(text) = serialize_outbound(&msg) {
|
if let Ok(text) = serialize_outbound(&msg) {
|
||||||
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -91,6 +92,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(WsMessage::Close(_)) | Err(_) => {
|
Ok(WsMessage::Close(_)) | Err(_) => {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %session_id, "WebSocket closed");
|
tracing::debug!(session_id = %session_id, "WebSocket closed");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -145,6 +147,7 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
|||||||
|
|
||||||
match agent.process(history).await {
|
match agent.process(history).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
||||||
// 添加助手响应到历史
|
// 添加助手响应到历史
|
||||||
session_guard.add_assistant_message(&chat_id, response.clone());
|
session_guard.add_assistant_message(&chat_id, response.clone());
|
||||||
|
|||||||
@ -3,9 +3,55 @@ use reqwest::Client;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
|
||||||
|
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||||
|
blocks.iter().map(|b| match b {
|
||||||
|
ContentBlock::Text { text } => {
|
||||||
|
serde_json::json!({ "type": "text", "text": text })
|
||||||
|
}
|
||||||
|
ContentBlock::ImageUrl { image_url } => {
|
||||||
|
convert_image_url_to_anthropic(&image_url.url)
|
||||||
|
}
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||||
|
// data:image/png;base64,... -> Anthropic image block
|
||||||
|
if let Some(caps) = regex::Regex::new(r"data:(image/\w+);base64,(.+)")
|
||||||
|
.ok()
|
||||||
|
.and_then(|re| re.captures(url))
|
||||||
|
{
|
||||||
|
let media_type = caps.get(1).map(|m| m.as_str()).unwrap_or("image/png");
|
||||||
|
let data = caps.get(2).map(|d| d.as_str()).unwrap_or("");
|
||||||
|
return serde_json::json!({
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": data
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Regular URL -> Anthropic image block with url source
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "url",
|
||||||
|
"url": url
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AnthropicProvider {
|
pub struct AnthropicProvider {
|
||||||
client: Client,
|
client: Client,
|
||||||
name: String,
|
name: String,
|
||||||
@ -58,7 +104,8 @@ struct AnthropicRequest {
|
|||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct AnthropicMessage {
|
struct AnthropicMessage {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
#[serde(serialize_with = "serialize_content_blocks")]
|
||||||
|
content: Vec<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@ -122,7 +169,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|m| AnthropicMessage {
|
.map(|m| AnthropicMessage {
|
||||||
role: m.role.clone(),
|
role: m.role.clone(),
|
||||||
content: m.content.clone(),
|
content: convert_content_blocks(&m.content),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
max_tokens,
|
max_tokens,
|
||||||
|
|||||||
@ -1,12 +1,27 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
|
||||||
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||||
|
if blocks.len() == 1 {
|
||||||
|
if let ContentBlock::Text { text } = &blocks[0] {
|
||||||
|
return Value::String(text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value::Array(blocks.iter().map(|b| match b {
|
||||||
|
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||||
|
ContentBlock::ImageUrl { image_url } => {
|
||||||
|
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||||
|
}
|
||||||
|
}).collect())
|
||||||
|
}
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
pub struct OpenAIProvider {
|
||||||
client: Client,
|
client: Client,
|
||||||
name: String,
|
name: String,
|
||||||
@ -107,14 +122,14 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
if m.role == "tool" {
|
if m.role == "tool" {
|
||||||
json!({
|
json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": m.content,
|
"content": convert_content_blocks(&m.content),
|
||||||
"tool_call_id": m.tool_call_id,
|
"tool_call_id": m.tool_call_id,
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
json!({
|
json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": m.content
|
"content": convert_content_blocks(&m.content)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}).collect::<Vec<_>>(),
|
}).collect::<Vec<_>>(),
|
||||||
@ -131,6 +146,30 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
body["tools"] = json!(tools);
|
body["tools"] = json!(tools);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Debug: Log LLM request summary (only in debug builds)
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
// Log messages summary
|
||||||
|
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
|
||||||
|
tracing::debug!(msg_count = msg_count, "LLM request messages count");
|
||||||
|
|
||||||
|
// Log first 20 bytes of base64 images (don't log full base64)
|
||||||
|
if let Some(msgs) = body["messages"].as_array() {
|
||||||
|
for (i, msg) in msgs.iter().enumerate() {
|
||||||
|
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||||
|
for (j, item) in content.iter().enumerate() {
|
||||||
|
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||||||
|
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
||||||
|
let prefix: String = url_str.chars().take(20).collect();
|
||||||
|
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut req_builder = self
|
let mut req_builder = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
@ -146,6 +185,13 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let status = resp.status();
|
let status = resp.status();
|
||||||
let text = resp.text().await?;
|
let text = resp.text().await?;
|
||||||
|
|
||||||
|
// Debug: Log LLM response (only in debug builds)
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let resp_preview: String = text.chars().take(100).collect();
|
||||||
|
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
|
||||||
|
}
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
return Err(format!("API error {}: {}", status, text).into());
|
return Err(format!("API error {}: {}", status, text).into());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,16 +1,64 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: String,
|
pub content: Vec<ContentBlock>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
pub fn user(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![ContentBlock::text(content)],
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![ContentBlock::text(content)],
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn system(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: vec![ContentBlock::text(content)],
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content: vec![ContentBlock::text(content)],
|
||||||
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
|
name: Some(tool_name.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Tool {
|
pub struct Tool {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
|
|||||||
315
src/tools/bash.rs
Normal file
315
src/tools/bash.rs
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
|
use tokio::process::Command;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||||
|
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||||
|
|
||||||
|
pub struct BashTool {
|
||||||
|
timeout_secs: u64,
|
||||||
|
working_dir: Option<String>,
|
||||||
|
deny_patterns: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BashTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
timeout_secs: 60,
|
||||||
|
working_dir: None,
|
||||||
|
deny_patterns: vec![
|
||||||
|
r"\brm\s+-[rf]{1,2}\b".to_string(),
|
||||||
|
r"\bdel\s+/[fq]\b".to_string(),
|
||||||
|
r"\brmdir\s+/s\b".to_string(),
|
||||||
|
r":\(\)\s*\{.*\};\s*:".to_string(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
|
||||||
|
self.timeout_secs = timeout_secs;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_working_dir(mut self, dir: String) -> Self {
|
||||||
|
self.working_dir = Some(dir);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn guard_command(&self, command: &str) -> Option<String> {
|
||||||
|
let lower = command.to_lowercase();
|
||||||
|
for pattern in &self.deny_patterns {
|
||||||
|
if regex::Regex::new(pattern)
|
||||||
|
.ok()
|
||||||
|
.map(|re| re.is_match(&lower))
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
return Some(format!(
|
||||||
|
"Command blocked by safety guard (dangerous pattern: {})",
|
||||||
|
pattern
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_output(&self, output: &str) -> String {
|
||||||
|
if output.len() <= MAX_OUTPUT_CHARS {
|
||||||
|
return output.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let half = MAX_OUTPUT_CHARS / 2;
|
||||||
|
format!(
|
||||||
|
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
||||||
|
&output[..half],
|
||||||
|
output.len() - MAX_OUTPUT_CHARS,
|
||||||
|
&output[output.len() - half..]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BashTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for BashTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"bash"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Execute a bash shell command and return its output. Use with caution."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The shell command to execute"
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": MAX_TIMEOUT_SECS
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn exclusive(&self) -> bool {
|
||||||
|
true // Shell commands should not run concurrently
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||||
|
Some(c) => c,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: command".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Safety check
|
||||||
|
if let Some(error) = self.guard_command(command) {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(error),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let timeout_secs = args
|
||||||
|
.get("timeout")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(self.timeout_secs)
|
||||||
|
.min(MAX_TIMEOUT_SECS);
|
||||||
|
|
||||||
|
let cwd = self
|
||||||
|
.working_dir
|
||||||
|
.as_ref()
|
||||||
|
.map(|d| Path::new(d))
|
||||||
|
.unwrap_or_else(|| Path::new("."));
|
||||||
|
|
||||||
|
let result = timeout(
|
||||||
|
Duration::from_secs(timeout_secs),
|
||||||
|
self.run_command(command, cwd),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(output)) => Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output,
|
||||||
|
error: None,
|
||||||
|
}),
|
||||||
|
Ok(Err(e)) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
}),
|
||||||
|
Err(_) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Command timed out after {} seconds",
|
||||||
|
timeout_secs
|
||||||
|
)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BashTool {
|
||||||
|
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
|
||||||
|
let mut cmd = Command::new("bash");
|
||||||
|
cmd.args(["-c", command])
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
.stderr(Stdio::piped())
|
||||||
|
.current_dir(cwd);
|
||||||
|
|
||||||
|
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||||
|
|
||||||
|
let mut stdout = Vec::new();
|
||||||
|
let mut stderr = Vec::new();
|
||||||
|
|
||||||
|
if let Some(ref mut out) = child.stdout {
|
||||||
|
out.read_to_end(&mut stdout)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read stdout: {}", e))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref mut err) = child.stderr {
|
||||||
|
err.read_to_end(&mut stderr)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read stderr: {}", e))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let status = child
|
||||||
|
.wait()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to wait: {}", e))?;
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
|
||||||
|
if !stdout.is_empty() {
|
||||||
|
let stdout_str = String::from_utf8_lossy(&stdout);
|
||||||
|
output.push_str(&stdout_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stderr.is_empty() {
|
||||||
|
let stderr_str = String::from_utf8_lossy(&stderr);
|
||||||
|
if !stderr_str.trim().is_empty() {
|
||||||
|
if !output.is_empty() {
|
||||||
|
output.push_str("\n");
|
||||||
|
}
|
||||||
|
output.push_str("STDERR:\n");
|
||||||
|
output.push_str(&stderr_str);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
|
||||||
|
|
||||||
|
Ok(self.truncate_output(&output))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_simple_command() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "command": "echo 'Hello World'" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("Hello World"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_pwd_command() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "command": "pwd" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_ls_command() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_dangerous_rm() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "command": "rm -rf /" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("blocked"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_dangerous_fork_bomb() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "command": ":(){ :|:& };:" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("blocked"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_missing_command() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool.execute(json!({})).await.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("command"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_timeout() {
|
||||||
|
let tool = BashTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"command": "sleep 10",
|
||||||
|
"timeout": 1
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("timed out"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -92,6 +92,10 @@ impl Tool for CalculatorTool {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
let function = match args.get("function").and_then(|v| v.as_str()) {
|
let function = match args.get("function").and_then(|v| v.as_str()) {
|
||||||
Some(f) => f,
|
Some(f) => f,
|
||||||
|
|||||||
381
src/tools/file_edit.rs
Normal file
381
src/tools/file_edit.rs
Normal file
@ -0,0 +1,381 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
pub struct FileEditTool {
|
||||||
|
allowed_dir: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FileEditTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self { allowed_dir: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_allowed_dir(dir: String) -> Self {
|
||||||
|
Self {
|
||||||
|
allowed_dir: Some(dir),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
|
||||||
|
let p = Path::new(path);
|
||||||
|
let resolved = if p.is_absolute() {
|
||||||
|
p.to_path_buf()
|
||||||
|
} else {
|
||||||
|
std::env::current_dir()
|
||||||
|
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
||||||
|
.join(p)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check directory restriction
|
||||||
|
if let Some(ref allowed) = self.allowed_dir {
|
||||||
|
let allowed_path = Path::new(allowed);
|
||||||
|
if !resolved.starts_with(allowed_path) {
|
||||||
|
return Err(format!(
|
||||||
|
"Path '{}' is outside allowed directory '{}'",
|
||||||
|
path, allowed
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_match(&self, content: &str, old_text: &str) -> Option<(String, usize)> {
|
||||||
|
// Try exact match first
|
||||||
|
if content.contains(old_text) {
|
||||||
|
let count = content.matches(old_text).count();
|
||||||
|
return Some((old_text.to_string(), count));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try line-based matching for minor differences
|
||||||
|
let old_lines: Vec<&str> = old_text.lines().collect();
|
||||||
|
if old_lines.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let content_lines: Vec<&str> = content.lines().collect();
|
||||||
|
|
||||||
|
for i in 0..content_lines.len().saturating_sub(old_lines.len()) {
|
||||||
|
let window = &content_lines[i..i + old_lines.len()];
|
||||||
|
let stripped_old: Vec<&str> = old_lines.iter().map(|l| l.trim()).collect();
|
||||||
|
let stripped_window: Vec<&str> = window.iter().map(|l| l.trim()).collect();
|
||||||
|
|
||||||
|
if stripped_old == stripped_window {
|
||||||
|
let matched_text = window.join("\n");
|
||||||
|
return Some((matched_text, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FileEditTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for FileEditTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"file_edit"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Edit a file by replacing old_text with new_text. Supports minor whitespace differences. Set replace_all=true to replace every occurrence."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file path to edit"
|
||||||
|
},
|
||||||
|
"old_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The text to find and replace"
|
||||||
|
},
|
||||||
|
"new_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The text to replace with"
|
||||||
|
},
|
||||||
|
"replace_all": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Replace all occurrences (default false)",
|
||||||
|
"default": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "old_text", "new_text"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let path = match args.get("path").and_then(|v| v.as_str()) {
|
||||||
|
Some(p) => p,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: path".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let old_text = match args.get("old_text").and_then(|v| v.as_str()) {
|
||||||
|
Some(t) => t,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: old_text".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_text = match args.get("new_text").and_then(|v| v.as_str()) {
|
||||||
|
Some(t) => t,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: new_text".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let replace_all = args
|
||||||
|
.get("replace_all")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let resolved = match self.resolve_path(path) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !resolved.exists() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("File not found: {}", path)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resolved.is_file() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Not a file: {}", path)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file content
|
||||||
|
let content = match std::fs::read_to_string(&resolved) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to read file: {}", e)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Detect line ending style
|
||||||
|
let uses_crlf = content.contains("\r\n");
|
||||||
|
let norm_content = content.replace("\r\n", "\n");
|
||||||
|
let norm_old = old_text.replace("\r\n", "\n");
|
||||||
|
|
||||||
|
// Find match
|
||||||
|
let (matched_text, count) = match self.find_match(&norm_content, &norm_old) {
|
||||||
|
Some(m) => m,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"old_text not found in {}. Verify the file content.",
|
||||||
|
path
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Warn if multiple matches but replace_all is false
|
||||||
|
if count > 1 && !replace_all {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"old_text appears {} times. Provide more context to make it unique, or set replace_all=true.",
|
||||||
|
count
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform replacement
|
||||||
|
let norm_new = new_text.replace("\r\n", "\n");
|
||||||
|
let new_content = if replace_all {
|
||||||
|
norm_content.replace(&matched_text, &norm_new)
|
||||||
|
} else {
|
||||||
|
norm_content.replacen(&matched_text, &norm_new, 1)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Restore line endings if needed
|
||||||
|
let final_content = if uses_crlf {
|
||||||
|
new_content.replace("\n", "\r\n")
|
||||||
|
} else {
|
||||||
|
new_content
|
||||||
|
};
|
||||||
|
|
||||||
|
// Write back
|
||||||
|
match std::fs::write(&resolved, &final_content) {
|
||||||
|
Ok(_) => {
|
||||||
|
let replacements = if replace_all { count } else { 1 };
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!(
|
||||||
|
"Successfully edited {} ({} replacement{} made)",
|
||||||
|
resolved.display(),
|
||||||
|
replacements,
|
||||||
|
if replacements == 1 { "" } else { "s" }
|
||||||
|
),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(e) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to write file: {}", e)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_edit_simple() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
writeln!(file, "Hello World").unwrap();
|
||||||
|
writeln!(file, "Test line").unwrap();
|
||||||
|
|
||||||
|
let tool = FileEditTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"path": file.path().to_str().unwrap(),
|
||||||
|
"old_text": "Hello World",
|
||||||
|
"new_text": "Hello Universe"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
|
||||||
|
let content = std::fs::read_to_string(file.path()).unwrap();
|
||||||
|
assert!(content.contains("Hello Universe"));
|
||||||
|
assert!(!content.contains("Hello World"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_edit_replace_all() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
writeln!(file, "foo bar").unwrap();
|
||||||
|
writeln!(file, "foo baz").unwrap();
|
||||||
|
writeln!(file, "foo qux").unwrap();
|
||||||
|
|
||||||
|
let tool = FileEditTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"path": file.path().to_str().unwrap(),
|
||||||
|
"old_text": "foo",
|
||||||
|
"new_text": "bar",
|
||||||
|
"replace_all": true
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
|
||||||
|
let content = std::fs::read_to_string(file.path()).unwrap();
|
||||||
|
assert!(content.contains("bar bar"));
|
||||||
|
assert!(content.contains("bar baz"));
|
||||||
|
assert!(content.contains("bar qux"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_edit_file_not_found() {
|
||||||
|
let tool = FileEditTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"path": "/nonexistent/file.txt",
|
||||||
|
"old_text": "old",
|
||||||
|
"new_text": "new"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("not found"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_edit_old_text_not_found() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
writeln!(file, "Hello World").unwrap();
|
||||||
|
|
||||||
|
let tool = FileEditTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"path": file.path().to_str().unwrap(),
|
||||||
|
"old_text": "NonExistent",
|
||||||
|
"new_text": "New"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("not found"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_edit_multiline() {
|
||||||
|
let mut file = NamedTempFile::new().unwrap();
|
||||||
|
writeln!(file, "Line 1").unwrap();
|
||||||
|
writeln!(file, "Line 2").unwrap();
|
||||||
|
writeln!(file, "Line 3").unwrap();
|
||||||
|
|
||||||
|
let tool = FileEditTool::new();
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"path": file.path().to_str().unwrap(),
|
||||||
|
"old_text": "Line 1\nLine 2",
|
||||||
|
"new_text": "New Line 1\nNew Line 2"
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(result.success);
|
||||||
|
|
||||||
|
let content = std::fs::read_to_string(file.path()).unwrap();
|
||||||
|
assert!(content.contains("New Line 1"));
|
||||||
|
assert!(content.contains("New Line 2"));
|
||||||
|
}
|
||||||
|
}
|
||||||
321
src/tools/file_read.rs
Normal file
321
src/tools/file_read.rs
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
use std::io::Read;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
const MAX_CHARS: usize = 128_000;
|
||||||
|
const DEFAULT_LIMIT: usize = 2000;
|
||||||
|
|
||||||
|
/// Tool that reads a file from disk and returns its contents as numbered
/// lines, with optional offset/limit pagination for large files.
pub struct FileReadTool {
    // When set, reads are restricted to paths under this directory.
    allowed_dir: Option<String>,
}
|
||||||
|
|
||||||
|
impl FileReadTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self { allowed_dir: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_allowed_dir(dir: String) -> Self {
|
||||||
|
Self {
|
||||||
|
allowed_dir: Some(dir),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
|
||||||
|
let p = Path::new(path);
|
||||||
|
let resolved = if p.is_absolute() {
|
||||||
|
p.to_path_buf()
|
||||||
|
} else {
|
||||||
|
std::env::current_dir()
|
||||||
|
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
||||||
|
.join(p)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check directory restriction
|
||||||
|
if let Some(ref allowed) = self.allowed_dir {
|
||||||
|
let allowed_path = Path::new(allowed);
|
||||||
|
if !resolved.starts_with(allowed_path) {
|
||||||
|
return Err(format!(
|
||||||
|
"Path '{}' is outside allowed directory '{}'",
|
||||||
|
path, allowed
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FileReadTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for FileReadTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"file_read"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Read the contents of a file. Returns numbered lines. Use offset and limit to paginate through large files."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file path to read"
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Line number to start reading from (1-indexed, default 1)",
|
||||||
|
"minimum": 1
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of lines to read (default 2000)",
|
||||||
|
"minimum": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let path = match args.get("path").and_then(|v| v.as_str()) {
|
||||||
|
Some(p) => p,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: path".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let offset = args
|
||||||
|
.get("offset")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|v| v as usize)
|
||||||
|
.unwrap_or(1);
|
||||||
|
|
||||||
|
let limit = args
|
||||||
|
.get("limit")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.map(|v| v as usize)
|
||||||
|
.unwrap_or(DEFAULT_LIMIT);
|
||||||
|
|
||||||
|
let resolved = match self.resolve_path(path) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !resolved.exists() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("File not found: {}", path)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resolved.is_file() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Not a file: {}", path)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to read as text
|
||||||
|
match std::fs::read_to_string(&resolved) {
|
||||||
|
Ok(content) => {
|
||||||
|
let all_lines: Vec<&str> = content.lines().collect();
|
||||||
|
let total = all_lines.len();
|
||||||
|
|
||||||
|
if offset < 1 {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("offset must be at least 1, got {}", offset)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset > total {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"offset {} is beyond end of file ({} lines)",
|
||||||
|
offset, total
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let start = offset - 1;
|
||||||
|
let end = std::cmp::min(start + limit, total);
|
||||||
|
let lines: Vec<String> = all_lines[start..end]
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, line)| format!("{}| {}", start + i + 1, line))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut result = lines.join("\n");
|
||||||
|
|
||||||
|
// Truncate if too long
|
||||||
|
if result.len() > MAX_CHARS {
|
||||||
|
let mut truncated_chars = 0;
|
||||||
|
let mut end_idx = 0;
|
||||||
|
for (i, line) in lines.iter().enumerate() {
|
||||||
|
truncated_chars += line.len() + 1;
|
||||||
|
if truncated_chars > MAX_CHARS {
|
||||||
|
end_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
end_idx = i + 1;
|
||||||
|
}
|
||||||
|
result = lines[..end_idx].join("\n");
|
||||||
|
result.push_str(&format!(
|
||||||
|
"\n\n... ({} chars truncated) ...",
|
||||||
|
result.len() - MAX_CHARS
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if end < total {
|
||||||
|
result.push_str(&format!(
|
||||||
|
"\n\n(Showing lines {}-{} of {}. Use offset={} to continue.)",
|
||||||
|
offset,
|
||||||
|
end,
|
||||||
|
total,
|
||||||
|
end + 1
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: result,
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Try to read as binary and encode as base64
|
||||||
|
match std::fs::read(&resolved) {
|
||||||
|
Ok(bytes) => {
|
||||||
|
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||||
|
let encoded = STANDARD.encode(&bytes);
|
||||||
|
let mime = mime_guess::from_path(&resolved)
|
||||||
|
.first_or_octet_stream()
|
||||||
|
.to_string();
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!(
|
||||||
|
"(Binary file: {}, {} bytes, base64 encoded)\n{}",
|
||||||
|
mime,
|
||||||
|
bytes.len(),
|
||||||
|
encoded
|
||||||
|
),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(_) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to read file: {}", e)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temp file containing the given lines (one per line).
    fn temp_file_with_lines(lines: &[String]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file
    }

    #[tokio::test]
    async fn test_read_simple_file() {
        let rows: Vec<String> = (1..=3).map(|n| format!("Line {}", n)).collect();
        let file = temp_file_with_lines(&rows);

        let outcome = FileReadTool::new()
            .execute(json!({ "path": file.path().to_str().unwrap() }))
            .await
            .unwrap();

        assert!(outcome.success);
        for expected in ["Line 1", "Line 2", "Line 3"] {
            assert!(outcome.output.contains(expected));
        }
    }

    #[tokio::test]
    async fn test_read_with_offset_limit() {
        let rows: Vec<String> = (1..=10).map(|n| format!("Line {}", n)).collect();
        let file = temp_file_with_lines(&rows);

        let outcome = FileReadTool::new()
            .execute(json!({
                "path": file.path().to_str().unwrap(),
                "offset": 3,
                "limit": 2
            }))
            .await
            .unwrap();

        assert!(outcome.success);
        assert!(outcome.output.contains("Line 3"));
        assert!(outcome.output.contains("Line 4"));
        assert!(!outcome.output.contains("Line 2"));
    }

    #[tokio::test]
    async fn test_file_not_found() {
        let outcome = FileReadTool::new()
            .execute(json!({ "path": "/nonexistent/file.txt" }))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.unwrap().contains("not found"));
    }

    #[tokio::test]
    async fn test_is_directory() {
        let outcome = FileReadTool::new()
            .execute(json!({ "path": "." }))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.unwrap().contains("Not a file"));
    }
}
|
||||||
242
src/tools/file_write.rs
Normal file
242
src/tools/file_write.rs
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
/// Tool that writes text content to a file, creating missing parent
/// directories as needed.
pub struct FileWriteTool {
    // When set, writes are restricted to paths under this directory.
    allowed_dir: Option<String>,
}
|
||||||
|
|
||||||
|
impl FileWriteTool {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self { allowed_dir: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_allowed_dir(dir: String) -> Self {
|
||||||
|
Self {
|
||||||
|
allowed_dir: Some(dir),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
|
||||||
|
let p = Path::new(path);
|
||||||
|
let resolved = if p.is_absolute() {
|
||||||
|
p.to_path_buf()
|
||||||
|
} else {
|
||||||
|
std::env::current_dir()
|
||||||
|
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
||||||
|
.join(p)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check directory restriction
|
||||||
|
if let Some(ref allowed) = self.allowed_dir {
|
||||||
|
let allowed_path = Path::new(allowed);
|
||||||
|
if !resolved.starts_with(allowed_path) {
|
||||||
|
return Err(format!(
|
||||||
|
"Path '{}' is outside allowed directory '{}'",
|
||||||
|
path, allowed
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FileWriteTool {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for FileWriteTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"file_write"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Write content to a file at the given path. Creates parent directories if needed."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file path to write to"
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The content to write"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "content"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let path = match args.get("path").and_then(|v| v.as_str()) {
|
||||||
|
Some(p) => p,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: path".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = match args.get("content").and_then(|v| v.as_str()) {
|
||||||
|
Some(c) => c,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: content".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resolved = match self.resolve_path(path) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create parent directories if needed
|
||||||
|
if let Some(parent) = resolved.parent() {
|
||||||
|
if !parent.exists() {
|
||||||
|
if let Err(e) = std::fs::create_dir_all(parent) {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to create parent directory: {}", e)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match std::fs::write(&resolved, content) {
|
||||||
|
Ok(_) => Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!(
|
||||||
|
"Successfully wrote {} bytes to {}",
|
||||||
|
content.len(),
|
||||||
|
resolved.display()
|
||||||
|
),
|
||||||
|
error: None,
|
||||||
|
}),
|
||||||
|
Err(e) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to write file: {}", e)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Runs the write tool for `path`/`content` and returns the result.
    async fn run_write(path: &str, content: &str) -> ToolResult {
        FileWriteTool::new()
            .execute(json!({ "path": path, "content": content }))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_write_simple_file() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("test.txt");

        let outcome = run_write(target.to_str().unwrap(), "Hello, World!").await;

        assert!(outcome.success);
        assert!(outcome.output.contains("Successfully wrote"));
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "Hello, World!");
    }

    #[tokio::test]
    async fn test_write_creates_parent_dirs() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("subdir1/subdir2/test.txt");

        let outcome = run_write(target.to_str().unwrap(), "Nested content").await;

        assert!(outcome.success);
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "Nested content");
    }

    #[tokio::test]
    async fn test_write_missing_path() {
        let outcome = FileWriteTool::new()
            .execute(json!({ "content": "Hello" }))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.unwrap().contains("path"));
    }

    #[tokio::test]
    async fn test_write_missing_content() {
        let outcome = FileWriteTool::new()
            .execute(json!({ "path": "/tmp/test.txt" }))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.unwrap().contains("content"));
    }

    #[tokio::test]
    async fn test_overwrite_file() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("test.txt");
        std::fs::write(&target, "Initial content").unwrap();

        let outcome = run_write(target.to_str().unwrap(), "New content").await;

        assert!(outcome.success);
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "New content");
    }
}
|
||||||
444
src/tools/http_request.rs
Normal file
444
src/tools/http_request.rs
Normal file
@ -0,0 +1,444 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::header::HeaderMap;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
/// Tool for making outbound HTTP requests with SSRF protections:
/// a domain allowlist plus a block on local/private hosts.
pub struct HttpRequestTool {
    // Hosts (and their subdomains) requests may target; "*" allows all.
    allowed_domains: Vec<String>,
    // Maximum response body size in bytes before truncation; 0 = unlimited.
    max_response_size: usize,
    // Per-request timeout in seconds.
    timeout_secs: u64,
    // When true, requests to localhost/private IPs are permitted.
    allow_private_hosts: bool,
}
|
||||||
|
|
||||||
|
impl HttpRequestTool {
|
||||||
|
pub fn new(
|
||||||
|
allowed_domains: Vec<String>,
|
||||||
|
max_response_size: usize,
|
||||||
|
timeout_secs: u64,
|
||||||
|
allow_private_hosts: bool,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
allowed_domains: normalize_domains(allowed_domains),
|
||||||
|
max_response_size,
|
||||||
|
timeout_secs,
|
||||||
|
allow_private_hosts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_url(&self, url: &str) -> Result<String, String> {
|
||||||
|
let url = url.trim();
|
||||||
|
|
||||||
|
if url.is_empty() {
|
||||||
|
return Err("URL cannot be empty".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.chars().any(char::is_whitespace) {
|
||||||
|
return Err("URL cannot contain whitespace".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
|
||||||
|
if !self.allow_private_hosts && is_private_host(&host) {
|
||||||
|
return Err(format!("Blocked local/private host: {}", host));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||||
|
return Err(format!(
|
||||||
|
"Host '{}' is not in allowed_domains",
|
||||||
|
host
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(url.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
|
||||||
|
match method.to_uppercase().as_str() {
|
||||||
|
"GET" => Ok(reqwest::Method::GET),
|
||||||
|
"POST" => Ok(reqwest::Method::POST),
|
||||||
|
"PUT" => Ok(reqwest::Method::PUT),
|
||||||
|
"DELETE" => Ok(reqwest::Method::DELETE),
|
||||||
|
"PATCH" => Ok(reqwest::Method::PATCH),
|
||||||
|
_ => Err(format!(
|
||||||
|
"Unsupported HTTP method: {}. Supported: GET, POST, PUT, DELETE, PATCH",
|
||||||
|
method
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_headers(&self, headers: &serde_json::Value) -> HeaderMap {
|
||||||
|
let mut header_map = HeaderMap::new();
|
||||||
|
|
||||||
|
if let Some(obj) = headers.as_object() {
|
||||||
|
for (key, value) in obj {
|
||||||
|
if let Some(str_val) = value.as_str() {
|
||||||
|
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||||
|
if let Ok(val) =
|
||||||
|
reqwest::header::HeaderValue::from_str(str_val)
|
||||||
|
{
|
||||||
|
header_map.insert(name, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
header_map
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_response(&self, text: &str) -> String {
|
||||||
|
if self.max_response_size == 0 {
|
||||||
|
return text.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
if text.len() > self.max_response_size {
|
||||||
|
format!(
|
||||||
|
"{}\n\n... [Response truncated due to size limit] ...",
|
||||||
|
&text[..self.max_response_size]
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
text.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalizes an allowlist: cleans each entry, sorts, and deduplicates.
/// Entries that do not normalize to a valid host are dropped.
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
    let mut out = Vec::with_capacity(domains.len());
    for raw in domains {
        if let Some(domain) = normalize_domain(&raw) {
            out.push(domain);
        }
    }
    out.sort_unstable();
    out.dedup();
    out
}

/// Normalizes a single allowlist entry to a bare lowercase host:
/// strips an optional scheme, any path, stray leading/trailing dots,
/// and a port. Returns `None` for empty or whitespace-containing results.
fn normalize_domain(raw: &str) -> Option<String> {
    let lowered = raw.trim().to_lowercase();
    if lowered.is_empty() {
        return None;
    }

    // Drop an optional scheme prefix.
    let without_scheme = lowered
        .strip_prefix("https://")
        .or_else(|| lowered.strip_prefix("http://"))
        .unwrap_or(&lowered);

    // Keep only the part before any path.
    let authority = match without_scheme.split_once('/') {
        Some((host, _)) => host,
        None => without_scheme,
    };

    // Trim stray dots, then drop a port suffix (in this order, matching
    // the documented normalization sequence).
    let dotless = authority.trim_start_matches('.').trim_end_matches('.');
    let host = match dotless.split_once(':') {
        Some((h, _)) => h,
        None => dotless,
    };

    if host.is_empty() || host.chars().any(char::is_whitespace) {
        None
    } else {
        Some(host.to_string())
    }
}
|
||||||
|
|
||||||
|
/// Extracts the lowercase host from an http(s) URL, rejecting userinfo
/// (`user@host`) and bracketed IPv6 authorities.
fn extract_host(url: &str) -> Result<String, String> {
    let after_scheme = if let Some(rest) = url.strip_prefix("http://") {
        rest
    } else if let Some(rest) = url.strip_prefix("https://") {
        rest
    } else {
        return Err("Only http:// and https:// URLs are allowed".to_string());
    };

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = after_scheme
        .split(['/', '?', '#'])
        .next()
        .ok_or_else(|| "Invalid URL".to_string())?;

    if authority.is_empty() {
        return Err("URL must include a host".to_string());
    }

    if authority.contains('@') {
        return Err("URL userinfo is not allowed".to_string());
    }

    if authority.starts_with('[') {
        return Err("IPv6 hosts are not supported".to_string());
    }

    // Drop a port suffix, normalize case, and strip a trailing FQDN dot.
    let before_port = authority.split(':').next().unwrap_or_default();
    let host = before_port.trim().trim_end_matches('.').to_lowercase();

    if host.is_empty() {
        Err("URL must include a valid host".to_string())
    } else {
        Ok(host)
    }
}
|
||||||
|
|
||||||
|
/// Returns true when `host` is covered by the allowlist: either an exact
/// entry, a subdomain of an entry, or the wildcard "*".
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
    if allowed_domains.iter().any(|domain| domain == "*") {
        return true;
    }

    for domain in allowed_domains {
        if host == domain {
            return true;
        }
        // Subdomain match: "api.example.com" matches entry "example.com",
        // but "notexample.com" does not (the prefix must end with a dot).
        if let Some(prefix) = host.strip_suffix(domain.as_str()) {
            if prefix.ends_with('.') {
                return true;
            }
        }
    }

    false
}
|
||||||
|
|
||||||
|
/// Returns true when `host` refers to a local or private network location:
/// localhost names, the `.local` mDNS TLD, or a loopback/private IP.
fn is_private_host(host: &str) -> bool {
    // Check localhost names.
    if host == "localhost" || host.ends_with(".localhost") {
        return true;
    }

    // Check the .local TLD.
    if host.rsplit('.').next().is_some_and(|label| label == "local") {
        return true;
    }

    // Literal IP addresses get range checks.
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return is_private_ip(&ip);
    }

    false
}

/// Returns true for IP addresses that must not be reachable through this
/// tool (loopback, private ranges, link-local, unspecified, etc.).
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_unspecified()
                || v4.is_broadcast()
                || v4.is_multicast()
        }
        std::net::IpAddr::V6(v6) => {
            // BUG FIX: the original only checked loopback/unspecified/
            // multicast, so IPv6 unique-local (fc00::/7), link-local
            // (fe80::/10) and IPv4-mapped (::ffff:a.b.c.d) addresses
            // slipped through, allowing SSRF to private networks over
            // IPv6. Check those ranges explicitly (the dedicated
            // `is_unique_local`/`is_unicast_link_local` helpers were
            // unstable for a long time, so use the segment masks).
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(&std::net::IpAddr::V4(mapped));
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (first & 0xfe00) == 0xfc00 // unique local fc00::/7
                || (first & 0xffc0) == 0xfe80 // link local fe80::/10
        }
    }
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for HttpRequestTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"http_request"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH methods. Security: domain allowlist, no local/private hosts."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "HTTP or HTTPS URL to request"
|
||||||
|
},
|
||||||
|
"method": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
|
||||||
|
"default": "GET"
|
||||||
|
},
|
||||||
|
"headers": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Optional HTTP headers as key-value pairs"
|
||||||
|
},
|
||||||
|
"body": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional request body"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let url = match args.get("url").and_then(|v| v.as_str()) {
|
||||||
|
Some(u) => u,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: url".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let method_str = args
|
||||||
|
.get("method")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("GET");
|
||||||
|
|
||||||
|
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||||
|
let body = args.get("body").and_then(|v| v.as_str());
|
||||||
|
|
||||||
|
let url = match self.validate_url(url) {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let method = match self.validate_method(method_str) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let headers = self.parse_headers(&headers_val);
|
||||||
|
|
||||||
|
let client = match reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.build()
|
||||||
|
{
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to create HTTP client: {}", e)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut request = client.request(method, &url).headers(headers);
|
||||||
|
|
||||||
|
if let Some(body_str) = body {
|
||||||
|
request = request.body(body_str.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
match request.send().await {
|
||||||
|
Ok(response) => {
|
||||||
|
let status = response.status();
|
||||||
|
let status_code = status.as_u16();
|
||||||
|
|
||||||
|
let response_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map(|t| self.truncate_response(&t))
|
||||||
|
.unwrap_or_else(|_| "[Failed to read response body]".to_string());
|
||||||
|
|
||||||
|
let output = format!(
|
||||||
|
"Status: {} {}\n\nResponse Body:\n{}",
|
||||||
|
status_code,
|
||||||
|
status.canonical_reason().unwrap_or("Unknown"),
|
||||||
|
response_text
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: status.is_success(),
|
||||||
|
output,
|
||||||
|
error: if status.is_client_error() || status.is_server_error() {
|
||||||
|
Some(format!("HTTP {}", status_code))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(e) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("HTTP request failed: {}", e)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool with a 1 MB response cap, a 30 s timeout, and private
    /// hosts blocked.
    fn test_tool(domains: Vec<&str>) -> HttpRequestTool {
        let allowed = domains.into_iter().map(String::from).collect();
        HttpRequestTool::new(allowed, 1_000_000, 30, false)
    }

    #[tokio::test]
    async fn test_validate_url_success() {
        assert!(test_tool(vec!["example.com"])
            .validate_url("https://example.com/docs")
            .is_ok());
    }

    #[tokio::test]
    async fn test_validate_url_rejects_private() {
        let err = test_tool(vec!["example.com"])
            .validate_url("https://localhost:8080")
            .unwrap_err();
        assert!(err.contains("local/private"));
    }

    #[tokio::test]
    async fn test_validate_url_rejects_whitespace() {
        let err = test_tool(vec!["example.com"])
            .validate_url("https://example.com/hello world")
            .unwrap_err();
        assert!(err.contains("whitespace"));
    }

    #[tokio::test]
    async fn test_validate_url_requires_allowlist() {
        let tool = HttpRequestTool::new(vec![], 1_000_000, 30, false);
        let err = tool.validate_url("https://example.com").unwrap_err();
        assert!(err.contains("allowed_domains"));
    }

    #[tokio::test]
    async fn test_validate_method() {
        let tool = test_tool(vec!["example.com"]);
        for method in ["GET", "POST", "PUT", "DELETE", "PATCH"] {
            assert!(tool.validate_method(method).is_ok());
        }
        assert!(tool.validate_method("INVALID").is_err());
    }

    #[tokio::test]
    async fn test_blocks_loopback() {
        assert!(is_private_host("127.0.0.1"));
        assert!(is_private_host("localhost"));
    }

    #[tokio::test]
    async fn test_blocks_private_ranges() {
        for ip in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            assert!(is_private_host(ip));
        }
    }

    #[tokio::test]
    async fn test_blocks_local_tld() {
        assert!(is_private_host("service.local"));
    }
}
|
||||||
@ -1,7 +1,21 @@
|
|||||||
|
pub mod bash;
|
||||||
pub mod calculator;
|
pub mod calculator;
|
||||||
|
pub mod file_edit;
|
||||||
|
pub mod file_read;
|
||||||
|
pub mod file_write;
|
||||||
|
pub mod http_request;
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
|
pub mod schema;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
pub mod web_fetch;
|
||||||
|
|
||||||
|
pub use bash::BashTool;
|
||||||
pub use calculator::CalculatorTool;
|
pub use calculator::CalculatorTool;
|
||||||
|
pub use file_edit::FileEditTool;
|
||||||
|
pub use file_read::FileReadTool;
|
||||||
|
pub use file_write::FileWriteTool;
|
||||||
|
pub use http_request::HttpRequestTool;
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use traits::{Tool, ToolResult};
|
pub use traits::{Tool, ToolResult};
|
||||||
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|||||||
721
src/tools/schema.rs
Normal file
721
src/tools/schema.rs
Normal file
@ -0,0 +1,721 @@
|
|||||||
|
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
|
||||||
|
//!
|
||||||
|
//! Different providers support different subsets of JSON Schema. This module
|
||||||
|
//! normalizes tool schemas to improve cross-provider compatibility while
|
||||||
|
//! preserving semantic intent.
|
||||||
|
//!
|
||||||
|
//! ## What this module does
|
||||||
|
//!
|
||||||
|
//! 1. Removes unsupported keywords per provider strategy
|
||||||
|
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
|
||||||
|
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
|
||||||
|
//! 4. Strips nullable variants from unions and `type` arrays
|
||||||
|
//! 5. Converts `const` to single-value `enum`
|
||||||
|
//! 6. Detects circular references and stops recursion safely
|
||||||
|
|
||||||
|
use serde_json::{Map, Value, json};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
|
/// Keywords that Gemini rejects in tool schemas.
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
    // Schema composition
    "$ref",
    "$schema",
    "$id",
    "$defs",
    "definitions",
    // Property constraints
    "additionalProperties",
    "patternProperties",
    // String constraints
    "minLength",
    "maxLength",
    "pattern",
    "format",
    // Number constraints
    "minimum",
    "maximum",
    "multipleOf",
    // Array constraints
    "minItems",
    "maxItems",
    "uniqueItems",
    // Object constraints
    "minProperties",
    "maxProperties",
    // Non-standard
    "examples",
];

/// Metadata keywords that survive cleaning untouched.
const SCHEMA_META_KEYWORDS: &[&str] = &["description", "title", "default"];

/// Per-provider schema cleaning strategy.
///
/// Providers accept different JSON Schema subsets; the strategy selects
/// which keywords are stripped from tool schemas.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleaningStrategy {
    /// Gemini (Google AI / Vertex AI) - Most restrictive
    Gemini,
    /// Anthropic Claude - Moderately permissive
    Anthropic,
    /// OpenAI GPT - Most permissive
    OpenAI,
    /// Conservative: Remove only universally unsupported keywords
    Conservative,
}

impl CleaningStrategy {
    /// Keywords this strategy removes from schemas.
    pub fn unsupported_keywords(self) -> &'static [&'static str] {
        match self {
            Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
            Self::Anthropic => &["$ref", "$defs", "definitions"],
            Self::OpenAI => &[],
            Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"],
        }
    }
}
|
||||||
|
|
||||||
|
/// Cleans JSON Schemas for LLM tool-calling compatibility.
///
/// Stateless; all functionality is exposed as associated functions.
pub struct SchemaCleanr;
|
||||||
|
|
||||||
|
impl SchemaCleanr {
|
||||||
|
/// Clean schema for Gemini compatibility (strictest).
|
||||||
|
pub fn clean_for_gemini(schema: Value) -> Value {
|
||||||
|
Self::clean(schema, CleaningStrategy::Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean schema for Anthropic compatibility.
|
||||||
|
pub fn clean_for_anthropic(schema: Value) -> Value {
|
||||||
|
Self::clean(schema, CleaningStrategy::Anthropic)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean schema for OpenAI compatibility (most permissive).
|
||||||
|
pub fn clean_for_openai(schema: Value) -> Value {
|
||||||
|
Self::clean(schema, CleaningStrategy::OpenAI)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean schema with specified strategy.
|
||||||
|
pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
|
||||||
|
let defs = if let Some(obj) = schema.as_object() {
|
||||||
|
Self::extract_defs(obj)
|
||||||
|
} else {
|
||||||
|
HashMap::new()
|
||||||
|
};
|
||||||
|
Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that a schema is suitable for LLM tool calling.
|
||||||
|
pub fn validate(schema: &Value) -> anyhow::Result<()> {
|
||||||
|
let obj = schema
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;
|
||||||
|
|
||||||
|
if !obj.contains_key("type") {
|
||||||
|
anyhow::bail!("Schema missing required 'type' field");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(Value::String(t)) = obj.get("type") {
|
||||||
|
if t == "object" && !obj.contains_key("properties") {
|
||||||
|
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract $defs and definitions into a flat map for reference resolution.
|
||||||
|
fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
|
||||||
|
let mut defs = HashMap::new();
|
||||||
|
|
||||||
|
if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
|
||||||
|
for (key, value) in defs_obj {
|
||||||
|
defs.insert(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
|
||||||
|
for (key, value) in defs_obj {
|
||||||
|
defs.insert(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recursively clean a schema value.
|
||||||
|
fn clean_with_defs(
|
||||||
|
schema: Value,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
match schema {
|
||||||
|
Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
|
||||||
|
Value::Array(arr) => Value::Array(
|
||||||
|
arr.into_iter()
|
||||||
|
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean an object schema.
|
||||||
|
fn clean_object(
|
||||||
|
obj: Map<String, Value>,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
// Handle $ref resolution
|
||||||
|
if let Some(Value::String(ref_value)) = obj.get("$ref") {
|
||||||
|
return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle anyOf/oneOf simplification
|
||||||
|
if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
|
||||||
|
if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
||||||
|
return simplified;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build cleaned object
|
||||||
|
let mut cleaned = Map::new();
|
||||||
|
let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
|
||||||
|
let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");
|
||||||
|
|
||||||
|
for (key, value) in obj {
|
||||||
|
// Skip unsupported keywords
|
||||||
|
if unsupported.contains(key.as_str()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match key.as_str() {
|
||||||
|
// Convert const to enum
|
||||||
|
"const" => {
|
||||||
|
cleaned.insert("enum".to_string(), json!([value]));
|
||||||
|
}
|
||||||
|
// Skip type if we have anyOf/oneOf
|
||||||
|
"type" if has_union => {}
|
||||||
|
// Handle type arrays (remove null)
|
||||||
|
"type" if matches!(value, Value::Array(_)) => {
|
||||||
|
let cleaned_value = Self::clean_type_array(value);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
// Recursively clean nested schemas
|
||||||
|
"properties" => {
|
||||||
|
let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
"items" => {
|
||||||
|
let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
"anyOf" | "oneOf" | "allOf" => {
|
||||||
|
let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let cleaned_value = match value {
|
||||||
|
Value::Object(_) | Value::Array(_) => {
|
||||||
|
Self::clean_with_defs(value, defs, strategy, ref_stack)
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
};
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value::Object(cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve a $ref to its definition.
|
||||||
|
fn resolve_ref(
|
||||||
|
ref_value: &str,
|
||||||
|
obj: &Map<String, Value>,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
// Prevent circular references
|
||||||
|
if ref_stack.contains(ref_value) {
|
||||||
|
tracing::warn!("Circular $ref detected: {}", ref_value);
|
||||||
|
return Self::preserve_meta(obj, Value::Object(Map::new()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(def_name) = Self::parse_local_ref(ref_value) {
|
||||||
|
if let Some(definition) = defs.get(def_name.as_str()) {
|
||||||
|
ref_stack.insert(ref_value.to_string());
|
||||||
|
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||||
|
ref_stack.remove(ref_value);
|
||||||
|
return Self::preserve_meta(obj, cleaned);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!("Cannot resolve $ref: {}", ref_value);
|
||||||
|
Self::preserve_meta(obj, Value::Object(Map::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a local JSON Pointer ref (#/$defs/Name or #/definitions/Name).
|
||||||
|
fn parse_local_ref(ref_value: &str) -> Option<String> {
|
||||||
|
ref_value
|
||||||
|
.strip_prefix("#/$defs/")
|
||||||
|
.or_else(|| ref_value.strip_prefix("#/definitions/"))
|
||||||
|
.map(Self::decode_json_pointer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode JSON Pointer escaping per RFC 6901: `~0` -> `~`, `~1` -> `/`.
///
/// A `~` followed by anything else (or end of input) is kept verbatim.
fn decode_json_pointer(segment: &str) -> String {
    // Fast path: nothing to decode.
    if !segment.contains('~') {
        return segment.to_string();
    }

    let mut decoded = String::with_capacity(segment.len());
    let mut rest = segment.chars();

    while let Some(ch) = rest.next() {
        if ch != '~' {
            decoded.push(ch);
            continue;
        }
        // One-character lookahead via a cloned iterator; only commit the
        // advance when the escape is recognized.
        let mut lookahead = rest.clone();
        match lookahead.next() {
            Some('0') => {
                decoded.push('~');
                rest = lookahead;
            }
            Some('1') => {
                decoded.push('/');
                rest = lookahead;
            }
            _ => decoded.push('~'),
        }
    }

    decoded
}
|
||||||
|
|
||||||
|
/// Try to simplify anyOf/oneOf to a simpler form.
|
||||||
|
fn try_simplify_union(
|
||||||
|
obj: &Map<String, Value>,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Option<Value> {
|
||||||
|
let union_key = if obj.contains_key("anyOf") {
|
||||||
|
"anyOf"
|
||||||
|
} else if obj.contains_key("oneOf") {
|
||||||
|
"oneOf"
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let variants = obj.get(union_key)?.as_array()?;
|
||||||
|
|
||||||
|
let cleaned_variants: Vec<Value> = variants
|
||||||
|
.iter()
|
||||||
|
.map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Strip null variants
|
||||||
|
let non_null: Vec<Value> = cleaned_variants
|
||||||
|
.into_iter()
|
||||||
|
.filter(|v| !Self::is_null_schema(v))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// If only one variant remains after stripping nulls, return it
|
||||||
|
if non_null.len() == 1 {
|
||||||
|
return Some(Self::preserve_meta(obj, non_null[0].clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to flatten to enum if all variants are literals
|
||||||
|
if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
|
||||||
|
return Some(Self::preserve_meta(obj, enum_value));
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a schema represents null type.
|
||||||
|
fn is_null_schema(value: &Value) -> bool {
|
||||||
|
if let Some(obj) = value.as_object() {
|
||||||
|
if let Some(Value::Null) = obj.get("const") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if let Some(Value::Array(arr)) = obj.get("enum") {
|
||||||
|
if arr.len() == 1 && matches!(arr[0], Value::Null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(Value::String(t)) = obj.get("type") {
|
||||||
|
if t == "null" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to flatten anyOf/oneOf with only literal values to enum.
|
||||||
|
fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
|
||||||
|
if variants.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut all_values = Vec::new();
|
||||||
|
let mut common_type: Option<String> = None;
|
||||||
|
|
||||||
|
for variant in variants {
|
||||||
|
let obj = variant.as_object()?;
|
||||||
|
|
||||||
|
let literal_value = if let Some(const_val) = obj.get("const") {
|
||||||
|
const_val.clone()
|
||||||
|
} else if let Some(Value::Array(arr)) = obj.get("enum") {
|
||||||
|
if arr.len() == 1 {
|
||||||
|
arr[0].clone()
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let variant_type = obj.get("type")?.as_str()?;
|
||||||
|
match &common_type {
|
||||||
|
None => common_type = Some(variant_type.to_string()),
|
||||||
|
Some(t) if t != variant_type => return None,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_values.push(literal_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
common_type.map(|t| {
|
||||||
|
json!({
|
||||||
|
"type": t,
|
||||||
|
"enum": all_values
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean type array, removing null.
|
||||||
|
fn clean_type_array(value: Value) -> Value {
|
||||||
|
if let Value::Array(types) = value {
|
||||||
|
let non_null: Vec<Value> = types
|
||||||
|
.into_iter()
|
||||||
|
.filter(|v| v.as_str() != Some("null"))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
match non_null.len() {
|
||||||
|
0 => Value::String("null".to_string()),
|
||||||
|
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
||||||
|
_ => Value::Array(non_null),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean properties object.
|
||||||
|
fn clean_properties(
|
||||||
|
value: Value,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
if let Value::Object(props) = value {
|
||||||
|
let cleaned: Map<String, Value> = props
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
|
||||||
|
.collect();
|
||||||
|
Value::Object(cleaned)
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean union (anyOf/oneOf/allOf).
|
||||||
|
fn clean_union(
|
||||||
|
value: Value,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
if let Value::Array(variants) = value {
|
||||||
|
let cleaned: Vec<Value> = variants
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
|
||||||
|
.collect();
|
||||||
|
Value::Array(cleaned)
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Preserve metadata (description, title, default) from source to target.
|
||||||
|
fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
|
||||||
|
if let Value::Object(target_obj) = &mut target {
|
||||||
|
for &key in SCHEMA_META_KEYWORDS {
|
||||||
|
if let Some(value) = source.get(key) {
|
||||||
|
target_obj.insert(key.to_string(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_unsupported_keywords() {
        let schema = json!({
            "type": "string",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^[a-z]+$",
            "description": "A lowercase string"
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["type"], "string");
        assert_eq!(out["description"], "A lowercase string");
        for gone in ["minLength", "maxLength", "pattern"] {
            assert!(out.get(gone).is_none(), "{gone} should be stripped");
        }
    }

    #[test]
    fn test_resolve_ref() {
        let schema = json!({
            "type": "object",
            "properties": {
                "age": { "$ref": "#/$defs/Age" }
            },
            "$defs": {
                "Age": { "type": "integer", "minimum": 0 }
            }
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["properties"]["age"]["type"], "integer");
        assert!(out["properties"]["age"].get("minimum").is_none());
        assert!(out.get("$defs").is_none());
    }

    #[test]
    fn test_flatten_literal_union() {
        let schema = json!({
            "anyOf": [
                { "const": "admin", "type": "string" },
                { "const": "user", "type": "string" },
                { "const": "guest", "type": "string" }
            ]
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["type"], "string");
        let options = out["enum"].as_array().expect("enum should be an array");
        assert_eq!(options.len(), 3);
        for role in ["admin", "user", "guest"] {
            assert!(options.contains(&json!(role)));
        }
    }

    #[test]
    fn test_strip_null_from_union() {
        let schema = json!({
            "oneOf": [
                { "type": "string" },
                { "type": "null" }
            ]
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["type"], "string");
        assert!(out.get("oneOf").is_none());
    }

    #[test]
    fn test_const_to_enum() {
        let schema = json!({
            "const": "fixed_value",
            "description": "A constant"
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["enum"], json!(["fixed_value"]));
        assert_eq!(out["description"], "A constant");
        assert!(out.get("const").is_none());
    }

    #[test]
    fn test_preserve_metadata() {
        let schema = json!({
            "$ref": "#/$defs/Name",
            "description": "User's name",
            "title": "Name Field",
            "default": "Anonymous",
            "$defs": {
                "Name": { "type": "string" }
            }
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["type"], "string");
        assert_eq!(out["description"], "User's name");
        assert_eq!(out["title"], "Name Field");
        assert_eq!(out["default"], "Anonymous");
    }

    #[test]
    fn test_circular_ref_prevention() {
        // Node references itself; cleaning must still terminate.
        let schema = json!({
            "type": "object",
            "properties": {
                "parent": { "$ref": "#/$defs/Node" }
            },
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "child": { "$ref": "#/$defs/Node" }
                    }
                }
            }
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["properties"]["parent"]["type"], "object");
    }

    #[test]
    fn test_validate_schema() {
        let with_type = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert!(SchemaCleanr::validate(&with_type).is_ok());

        let missing_type = json!({
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert!(SchemaCleanr::validate(&missing_type).is_err());
    }

    #[test]
    fn test_strategy_differences() {
        let schema = json!({
            "type": "string",
            "minLength": 1,
            "description": "A string field"
        });

        // Gemini strips string constraints.
        let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
        assert!(gemini.get("minLength").is_none());
        assert_eq!(gemini["type"], "string");
        assert_eq!(gemini["description"], "A string field");

        // OpenAI keeps them.
        let openai = SchemaCleanr::clean_for_openai(schema.clone());
        assert_eq!(openai["minLength"], 1);
        assert_eq!(openai["type"], "string");
    }

    #[test]
    fn test_nested_properties() {
        let schema = json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "minLength": 1
                        }
                    },
                    "additionalProperties": false
                }
            }
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        let name = &out["properties"]["user"]["properties"]["name"];
        assert!(name.get("minLength").is_none());
        assert!(out["properties"]["user"].get("additionalProperties").is_none());
    }

    #[test]
    fn test_type_array_null_removal() {
        let schema = json!({
            "type": ["string", "null"]
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(out["type"], "string");
    }

    #[test]
    fn test_skip_type_when_non_simplifiable_union_exists() {
        let schema = json!({
            "type": "object",
            "oneOf": [
                {
                    "type": "object",
                    "properties": {
                        "a": { "type": "string" }
                    }
                },
                {
                    "type": "object",
                    "properties": {
                        "b": { "type": "number" }
                    }
                }
            ]
        });

        let out = SchemaCleanr::clean_for_gemini(schema);

        assert!(out.get("type").is_none());
        assert!(out.get("oneOf").is_some());
    }
}
|
||||||
@ -13,4 +13,19 @@ pub trait Tool: Send + Sync + 'static {
|
|||||||
fn description(&self) -> &str;
|
fn description(&self) -> &str;
|
||||||
fn parameters_schema(&self) -> serde_json::Value;
|
fn parameters_schema(&self) -> serde_json::Value;
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
|
||||||
|
|
||||||
|
/// Whether this tool is side-effect free and safe to parallelize.
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this tool can run alongside other concurrency-safe tools.
|
||||||
|
fn concurrency_safe(&self) -> bool {
|
||||||
|
self.read_only() && !self.exclusive()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this tool should run alone even if concurrency is enabled.
|
||||||
|
fn exclusive(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
411
src/tools/web_fetch.rs
Normal file
411
src/tools/web_fetch.rs
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::header::HeaderMap;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
|
/// Fetches a URL and returns its content reduced to readable text.
pub struct WebFetchTool {
    // Maximum bytes of body text to return (0 appears to mean unlimited —
    // see truncate_response).
    max_response_size: usize,
    // Per-request timeout, in seconds.
    timeout_secs: u64,
    // User-Agent header sent with every request.
    user_agent: String,
}
|
||||||
|
|
||||||
|
impl WebFetchTool {
|
||||||
|
pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
max_response_size,
|
||||||
|
timeout_secs,
|
||||||
|
user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_url(&self, url: &str) -> Result<String, String> {
|
||||||
|
let url = url.trim();
|
||||||
|
|
||||||
|
if url.is_empty() {
|
||||||
|
return Err("URL cannot be empty".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.chars().any(char::is_whitespace) {
|
||||||
|
return Err("URL cannot contain whitespace".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||||
|
return Err("Only http:// and https:// URLs are allowed".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = extract_host(url)?;
|
||||||
|
|
||||||
|
if is_private_host(&host) {
|
||||||
|
return Err(format!("Blocked local/private host: {}", host));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(url.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_response(&self, text: &str) -> String {
|
||||||
|
if self.max_response_size == 0 {
|
||||||
|
return text.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
if text.len() > self.max_response_size {
|
||||||
|
format!(
|
||||||
|
"{}\n\n... [Response truncated due to size limit] ...",
|
||||||
|
&text[..self.max_response_size]
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
text.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_content(&self, url: &str) -> Result<String, String> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(self.timeout_secs))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
reqwest::header::USER_AGENT,
|
||||||
|
self.user_agent.parse().unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.get(url)
|
||||||
|
.headers(headers)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Request failed: {}", e))?;
|
||||||
|
|
||||||
|
let content_type = response
|
||||||
|
.headers()
|
||||||
|
.get(reqwest::header::CONTENT_TYPE)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
|
// Handle HTML content
|
||||||
|
if content_type.contains("text/html") {
|
||||||
|
let html = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||||||
|
return Ok(self.extract_text_from_html(&html));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle JSON content
|
||||||
|
if content_type.contains("application/json") {
|
||||||
|
let text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read response: {}", e))?;
|
||||||
|
// Pretty print JSON
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
|
||||||
|
return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
|
||||||
|
}
|
||||||
|
return Ok(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other content types, return raw text
|
||||||
|
response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to read response: {}", e))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_from_html(&self, html: &str) -> String {
|
||||||
|
let mut text = html.to_string();
|
||||||
|
|
||||||
|
// Remove script and style tags with content using simple replacements
|
||||||
|
text = strip_tag(&text, "script");
|
||||||
|
text = strip_tag(&text, "style");
|
||||||
|
|
||||||
|
// Remove all HTML tags
|
||||||
|
text = strip_all_tags(&text);
|
||||||
|
|
||||||
|
// Decode HTML entities
|
||||||
|
text = self.decode_html_entities(&text);
|
||||||
|
|
||||||
|
// Clean up whitespace
|
||||||
|
let mut cleaned = String::new();
|
||||||
|
let mut last_was_space = true;
|
||||||
|
for c in text.chars() {
|
||||||
|
if c.is_whitespace() {
|
||||||
|
if !last_was_space {
|
||||||
|
cleaned.push(' ');
|
||||||
|
last_was_space = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cleaned.push(c);
|
||||||
|
last_was_space = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned.trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_html_entities(&self, text: &str) -> String {
|
||||||
|
let entities = [
|
||||||
|
(" ", " "),
|
||||||
|
("<", "<"),
|
||||||
|
(">", ">"),
|
||||||
|
("&", "&"),
|
||||||
|
(""", "\""),
|
||||||
|
("'", "'"),
|
||||||
|
("—", "—"),
|
||||||
|
("–", "–"),
|
||||||
|
("©", "©"),
|
||||||
|
("®", "®"),
|
||||||
|
("™", "™"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut result = text.to_string();
|
||||||
|
for (entity, replacement) in entities {
|
||||||
|
result = result.replace(entity, replacement);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Removes every `<tag ...> ... </tag>` section from `s`, including the
/// tags themselves. Matching is ASCII case-insensitive and tolerates
/// attributes in the opening tag. An opening tag without a matching close
/// is left intact (same as the previous behaviour).
fn strip_tag(s: &str, tag_name: &str) -> String {
    let tag = tag_name.to_ascii_lowercase();
    let open = format!("<{}", tag);
    let close = format!("</{}>", tag);

    // ASCII lowercasing maps byte-for-byte, so every offset found in
    // `lower` is a valid index into `s`. (str::to_lowercase can change
    // byte lengths for non-ASCII text, which made the old offset-based
    // slicing panic-prone.)
    let lower = s.to_ascii_lowercase();

    let mut out = String::with_capacity(s.len());
    let mut pos = 0;

    while let Some(rel) = lower[pos..].find(&open) {
        let start = pos + rel;

        // Make sure this is really `<tag ...>` and not a longer name such
        // as "<scripts>": the name must be followed by '>', '/', or
        // whitespace (or end-of-input, handled as "no close" below).
        let next = lower[start + open.len()..].bytes().next();
        let boundary = matches!(next, Some(b'>') | Some(b'/'))
            || next.is_some_and(|b| b.is_ascii_whitespace());
        if !boundary {
            out.push_str(&s[pos..start + 1]);
            pos = start + 1;
            continue;
        }

        out.push_str(&s[pos..start]);
        match lower[start..].find(&close) {
            Some(close_rel) => pos = start + close_rel + close.len(),
            None => {
                // Unclosed tag: keep the remainder untouched.
                out.push_str(&s[start..]);
                return out;
            }
        }
    }

    out.push_str(&s[pos..]);
    out
}
|
||||||
|
|
||||||
|
/// Strips every `<...>` tag from `s`. Each closing `>` is replaced with a
/// single space so text from adjacent elements stays separated.
fn strip_all_tags(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut inside_tag = false;

    for ch in s.chars() {
        match ch {
            '<' => inside_tag = true,
            '>' => {
                inside_tag = false;
                out.push(' ');
            }
            _ if !inside_tag => out.push(ch),
            // Characters inside a tag are dropped.
            _ => {}
        }
    }

    out
}
|
||||||
|
|
||||||
|
/// Attempts to decode a named HTML entity at the very start of `s`
/// (case-insensitive). On success returns the decoded character and the
/// byte length of the entity text consumed.
///
/// Besides the named entities listed below, only the numeric `&#39;`
/// (apostrophe) is recognised; other numeric entities ("&#123;",
/// "&#x1F;") are not decoded and yield `None`.
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
    const ENTITIES: [(&str, char); 12] = [
        ("&nbsp;", ' '),
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&amp;", '&'),
        ("&quot;", '"'),
        ("&apos;", '\''),
        ("&#39;", '\''),
        ("&mdash;", '—'),
        ("&ndash;", '–'),
        ("&copy;", '©'),
        ("&reg;", '®'),
        ("&trade;", '™'),
    ];

    let lowered = s.to_lowercase();
    ENTITIES
        .iter()
        .find(|(name, _)| lowered.starts_with(name))
        .map(|&(name, ch)| (ch, name.len()))
}
|
||||||
|
|
||||||
|
/// Extracts the lowercase host name from an http(s) URL.
///
/// Handles `user:pass@` userinfo, bracketed IPv6 literals ("[::1]:8080"),
/// and port suffixes. Returns an error for non-http(s) schemes or URLs
/// with no host.
///
/// Stripping userinfo matters for security: without it,
/// "http://admin@localhost/" would report host "admin@localhost" and slip
/// past a private-host check.
fn extract_host(url: &str) -> Result<String, String> {
    let rest = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;

    // Authority = everything up to the first path/query/fragment delimiter.
    let authority = rest
        .split(['/', '?', '#'])
        .next()
        .ok_or_else(|| "Invalid URL".to_string())?;

    if authority.is_empty() {
        return Err("URL must include a host".to_string());
    }

    // Drop any userinfo ("user:pass@host") so it cannot mask the real host.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    // Bracketed IPv6 literal: "[::1]:8080" -> "::1". Otherwise drop the
    // ":port" suffix.
    let host = if let Some(inner) = host_port.strip_prefix('[') {
        inner.split(']').next().unwrap_or_default()
    } else {
        host_port.split(':').next().unwrap_or_default()
    };

    let host = host.trim().to_lowercase();

    if host.is_empty() {
        return Err("URL must include a valid host".to_string());
    }

    Ok(host)
}
|
||||||
|
|
||||||
|
/// Returns true when `host` refers to a local or private network location:
/// localhost / *.localhost, mDNS ".local" names, loopback, RFC 1918
/// private ranges, link-local, unspecified, broadcast, multicast, IPv6
/// unique-local (fc00::/7), IPv6 link-local (fe80::/10), and IPv4-mapped
/// IPv6 addresses.
fn is_private_host(host: &str) -> bool {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    if host == "localhost" || host.ends_with(".localhost") {
        return true;
    }

    // mDNS / zeroconf names resolve on the local network only.
    if host.rsplit('.').next().is_some_and(|label| label == "local") {
        return true;
    }

    fn v4_private(v4: Ipv4Addr) -> bool {
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_unspecified()
            || v4.is_broadcast()
    }

    fn v6_private(v6: Ipv6Addr) -> bool {
        // IPv4-mapped addresses (::ffff:a.b.c.d) inherit the IPv4 rules;
        // otherwise "::ffff:127.0.0.1" would pass as public.
        if let Some(v4) = v6.to_ipv4_mapped() {
            return v4_private(v4);
        }
        let seg = v6.segments();
        v6.is_loopback()
            || v6.is_unspecified()
            || v6.is_multicast()
            // fc00::/7 unique-local (RFC 4193) — segment math keeps this on
            // stable Rust without the unstable `is_unique_local`.
            || (seg[0] & 0xfe00) == 0xfc00
            // fe80::/10 link-local (RFC 4291)
            || (seg[0] & 0xffc0) == 0xfe80
    }

    match host.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4_private(v4),
        Ok(IpAddr::V6(v6)) => v6_private(v6),
        Err(_) => false,
    }
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for WebFetchTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"web_fetch"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL to fetch"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
let url = match args.get("url").and_then(|v| v.as_str()) {
|
||||||
|
Some(u) => u,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Missing required parameter: url".to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = match self.validate_url(url) {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.fetch_content(&url).await {
|
||||||
|
Ok(content) => Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: self.truncate_response(&content),
|
||||||
|
error: None,
|
||||||
|
}),
|
||||||
|
Err(e) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(e),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool under test: 50 kB response cap, 30-second timeout.
    fn test_tool() -> WebFetchTool {
        WebFetchTool::new(50_000, 30)
    }

    // None of these tests await anything, so plain #[test] is used instead
    // of #[tokio::test] — no async runtime needs to be spun up per test.

    #[test]
    fn test_validate_url_success() {
        let tool = test_tool();
        let result = tool.validate_url("https://example.com");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_url_rejects_private() {
        let tool = test_tool();
        let result = tool.validate_url("https://localhost:8080");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("local/private"));
    }

    #[test]
    fn test_validate_url_rejects_whitespace() {
        let tool = test_tool();
        let result = tool.validate_url("https://example.com/hello world");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("whitespace"));
    }

    #[test]
    fn test_extract_text_simple() {
        let tool = test_tool();
        let html = "<html><body><p>Hello World</p></body></html>";
        let text = tool.extract_text_from_html(html);
        assert!(text.contains("Hello World"));
        assert!(!text.contains("<"));
    }

    #[test]
    fn test_extract_text_removes_scripts() {
        let tool = test_tool();
        let html = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
        let text = tool.extract_text_from_html(html);
        assert!(text.contains("Good"));
        assert!(!text.contains("alert"));
    }

    #[test]
    fn test_extract_text_removes_styles() {
        let tool = test_tool();
        let html = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
        let text = tool.extract_text_from_html(html);
        assert!(text.contains("Content"));
        assert!(!text.contains("color"));
    }
}
|
||||||
Loading…
x
Reference in New Issue
Block a user