Compare commits

...

15 Commits

Author SHA1 Message Date
0c0d0c1443 feat(agent): add parallel tool execution with concurrency-safe batching
Implement parallel tool execution in AgentLoop, following the approach
used in Nanobot (_partition_tool_batches) and Zeroclaw (parallel_tools).

Key changes:
- partition_tool_batches(): group tool calls into batches based on
  concurrency_safe flag. Safe tools run in parallel via join_all;
  exclusive tools (e.g. bash) run in their own sequential batch.
- execute_tools(): now uses batching instead of flat sequential loop.
- CalculatorTool: add read_only() -> true so it participates in
  parallel batches (it has no side effects, so concurrency_safe = true).

4 unit tests added covering: mixed safe/exclusive, all-safe single
batch, all-exclusive separate batches, unknown tool defaults.
2026-04-08 12:04:03 +08:00
21b4e60c44 feat(feishu): add reaction handling and metadata forwarding in messages 2026-04-08 10:24:15 +08:00
a4399037ac fix: use char-based slicing instead of byte-based to handle UTF-8
Byte index slicing like `&text[..100.min(text.len())]` panics when the
byte index falls inside a multi-byte UTF-8 character (e.g., Chinese).
Changed to `text.chars().take(100).collect::<String>()` for safe
character-based truncation.
2026-04-08 08:49:52 +08:00
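The fix above can be illustrated in isolation; a self-contained sketch (the string literal here is hypothetical) of why byte slicing panics on multi-byte UTF-8 while character-based truncation does not:

```rust
fn main() {
    let text = "数据分析结果"; // each CJK char is 3 bytes in UTF-8
    // Byte index 10 falls inside the 4th character, so &text[..10] would panic.
    assert!(!text.is_char_boundary(10));
    // Character-based truncation is always safe:
    let truncated: String = text.chars().take(4).collect();
    assert_eq!(truncated, "数据分析");
}
```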
075b92f231 fix: truncate long text content before sending to Feishu
Feishu API rejects messages with content exceeding ~64KB with error
230001 "invalid message content". Added truncation at 60,000 characters
to prevent this, with a notice appended to truncated content.
2026-04-08 08:42:56 +08:00
02a7fa68c6 docs: update implementation log with tools registration
- Add tools registration section in session.rs
- Add update log table with all commits
2026-04-08 08:32:51 +08:00
98bc9739c6 feat(gateway): register all tools in SessionManager
- Register FileReadTool, FileWriteTool, FileEditTool, BashTool
- Register HttpRequestTool with allow-all domains for testing
- Register WebFetchTool
- CalculatorTool already registered
2026-04-08 08:32:06 +08:00
b13bb8c556 docs: add implementation log for tools
Document all implemented tools and mechanisms:
- SchemaCleanr for cross-provider schema normalization
- Tool trait enhancements (read_only, concurrency_safe, exclusive)
- file_read, file_write, file_edit, bash, http_request, web_fetch
2026-04-07 23:52:40 +08:00
8936e70a12 feat(tools): add web_fetch tool for HTML content extraction
- Fetch URL and extract readable text
- HTML to plain text conversion
- Removes scripts, styles, and HTML tags
- Decodes HTML entities
- JSON pretty printing
- SSRF protection
- Includes 6 unit tests
2026-04-07 23:52:06 +08:00
1581732ef9 feat(tools): add http_request tool with security features
- HTTP client with GET/POST/PUT/DELETE/PATCH support
- Domain allowlist for security
- SSRF protection (blocks private IPs, localhost)
- Response size limit and truncation
- Timeout control
- Includes 8 unit tests
2026-04-07 23:49:15 +08:00
68e3663c2f feat(tools): add bash tool with safety guards
- Execute shell commands with timeout
- Safety guards block dangerous commands (rm -rf, fork bombs)
- Output truncation for large outputs
- Working directory support
- Includes 7 unit tests
2026-04-07 23:47:44 +08:00
f3187ceddd feat(tools): add file_edit tool with fuzzy matching
- Edit file by replacing old_text with new_text
- Supports multiline edits
- Fuzzy line-based matching for minor differences
- replace_all option for batch replacement
- Includes 5 unit tests
2026-04-07 23:46:34 +08:00
16b052bd21 feat(tools): add file_write tool with directory creation
- Write content to file, creating parent directories if needed
- Overwrites existing files
- Includes 5 unit tests
2026-04-07 23:44:45 +08:00
a9e7aabed4 feat(tools): add file_read tool with pagination support
- Read file contents with offset/limit pagination
- Returns numbered lines for easy reference
- Handles binary files as base64 encoded
- Supports directory restriction for security
- Includes 4 unit tests
2026-04-07 23:43:47 +08:00
d5b6cd24fc feat(tools): add SchemaCleanr for cross-provider schema normalization
- Add SchemaCleanr with CleaningStrategy enum (Gemini, Anthropic, OpenAI, Conservative)
- Support cleaning JSON schemas for different LLM provider compatibility
- Add $ref resolution, anyOf/oneOf flattening, const-to-enum conversion
- Add read_only, concurrency_safe, exclusive methods to Tool trait
- Add comprehensive unit tests for all schema cleaning features
2026-04-07 23:41:20 +08:00
2dada36bc6 feat: introduce multimodal content handling with media support
- Added ContentBlock enum for multimodal content representation (text, image).
- Enhanced ChatMessage struct to include media references.
- Updated InboundMessage and OutboundMessage to use MediaItem for media handling.
- Implemented media download and upload functionality in FeishuChannel.
- Modified message processing in the gateway to handle media items.
- Improved logging for message processing and media handling in debug mode.
- Refactored message serialization for LLM providers to support content blocks.
2026-04-07 23:09:31 +08:00
26 changed files with 4564 additions and 183 deletions

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024"
[dependencies]
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls"] }
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] }
dotenv = "0.15"
serde = { version = "1.0", features = ["derive"] }
regex = "1.0"
@ -23,3 +23,6 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
tracing-appender = "0.2"
anyhow = "1.0"
mime_guess = "2.0"
base64 = "0.22"
tempfile = "3"

346
IMPLEMENTATION_LOG.md Normal file
View File

@ -0,0 +1,346 @@
# Picobot Tool Mechanism Enhancement Implementation Log
## Implementation Record
### 1. SchemaCleanr - Cross-Provider Schema Normalization
**Date**: 2026-04-07
**Commit**: `d5b6cd2`
#### Background
JSON Schema support varies widely across LLM providers:
- **Gemini**: strictest; does not support `minLength`, `maxLength`, `pattern`, `minimum`, `maximum`
- **Anthropic**: moderate; only requires `$ref` resolution
- **OpenAI**: most permissive; supports most keywords
#### Implementation
Created `src/tools/schema.rs`, which provides:
1. **`CleaningStrategy` enum**
```rust
pub enum CleaningStrategy {
    Gemini,       // strictest
    Anthropic,    // moderate
    OpenAI,       // most permissive
    Conservative,
}
```
2. **`SchemaCleanr::clean()`** - the core cleaning function
- Removes keywords the target provider does not support
- Resolves `$ref` references via `$defs`/`definitions`
- Flattens `anyOf`/`oneOf` into `enum`
- Converts `const` to `enum`
- Removes `null` from `type` arrays
3. **`SchemaCleanr::validate()`** - schema validation
#### Usage
```rust
use picobot::tools::{SchemaCleanr, CleaningStrategy};

// Gemini-compatible cleaning (strictest)
let cleaned = SchemaCleanr::clean_for_gemini(schema);

// Anthropic-compatible cleaning
let cleaned = SchemaCleanr::clean_for_anthropic(schema);

// OpenAI-compatible cleaning (most permissive)
let cleaned = SchemaCleanr::clean_for_openai(schema);

// Custom strategy
let cleaned = SchemaCleanr::clean(schema, CleaningStrategy::Conservative);
```
#### Tool Trait Enhancements
New methods added to the `Tool` trait in `src/tools/traits.rs`:
```rust
pub trait Tool: Send + Sync + 'static {
    // ... existing methods ...

    /// Is this tool read-only (no side effects)?
    fn read_only(&self) -> bool { false }

    /// Can this tool run in parallel with other tools?
    fn concurrency_safe(&self) -> bool {
        self.read_only() && !self.exclusive()
    }

    /// Does this tool require exclusive execution?
    fn exclusive(&self) -> bool { false }
}
```
These attributes lay the groundwork for parallel tool execution.
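A minimal, self-contained sketch of how these defaults interact (the trait shape follows the snippet above; the two tool structs here are illustrative stand-ins, not the real implementations):

```rust
trait Tool {
    fn read_only(&self) -> bool { false }
    fn exclusive(&self) -> bool { false }
    fn concurrency_safe(&self) -> bool { self.read_only() && !self.exclusive() }
}

struct Calculator; // pure computation, no side effects
impl Tool for Calculator {
    fn read_only(&self) -> bool { true }
}

struct Bash; // mutates the environment, must run alone
impl Tool for Bash {
    fn exclusive(&self) -> bool { true }
}

fn main() {
    assert!(Calculator.concurrency_safe());  // read-only and not exclusive
    assert!(!Bash.concurrency_safe());       // exclusive tools never parallelize
}
```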
#### Tests
- 12 unit tests covering all cleaning logic
- Run `cargo test --lib tools::schema` to verify
### 2. file_read Tool
**Date**: 2026-04-07
**Commit**: `a9e7aab`
#### Features
- Reads file contents (with offset/limit pagination)
- Returns numbered lines for easy reference
- Handles binary files automatically (base64 encoded)
- Optional directory restriction (security sandboxing)
#### Schema
```json
{
  "type": "object",
  "properties": {
    "path": { "type": "string", "description": "File path" },
    "offset": { "type": "integer", "description": "Starting line number (1-indexed)" },
    "limit": { "type": "integer", "description": "Maximum number of lines" }
  },
  "required": ["path"]
}
```
#### Usage
```rust
use picobot::tools::FileReadTool;

// Basic usage
let tool = FileReadTool::new();
let result = tool.execute(json!({
    "path": "/some/file.txt",
    "offset": 1,
    "limit": 100
})).await;
```
#### Tests
- 4 unit tests
- `cargo test --lib tools::file_read`
### 3. file_write Tool
**Date**: 2026-04-07
**Commit**: `16b052b`
#### Features
- Writes content to a file
- Creates parent directories automatically
- Overwrites existing files
#### Schema
```json
{
  "type": "object",
  "properties": {
    "path": { "type": "string", "description": "File path" },
    "content": { "type": "string", "description": "Content to write" }
  },
  "required": ["path", "content"]
}
```
#### Tests
- 5 unit tests
- `cargo test --lib tools::file_write`
### 4. file_edit Tool
**Date**: 2026-04-07
**Commit**: `f3187ce`
#### Features
- Edits a file by replacing old_text with new_text
- Supports multiline edits
- Fuzzy line-based matching tolerates minor differences
- replace_all option for batch replacement
#### Schema
```json
{
  "type": "object",
  "properties": {
    "path": { "type": "string", "description": "File path" },
    "old_text": { "type": "string", "description": "Text to replace" },
    "new_text": { "type": "string", "description": "Replacement text" },
    "replace_all": { "type": "boolean", "description": "Replace all matches", "default": false }
  },
  "required": ["path", "old_text", "new_text"]
}
```
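The replace_all option mirrors the standard first-match vs. every-match distinction; a self-contained illustration (not the tool's actual code):

```rust
fn main() {
    let content = "foo bar foo";
    let first_only = content.replacen("foo", "baz", 1); // default: first match only
    let all = content.replace("foo", "baz");            // replace_all: true
    assert_eq!(first_only, "baz bar foo");
    assert_eq!(all, "baz bar baz");
}
```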
#### Tests
- 5 unit tests
- `cargo test --lib tools::file_edit`
### 5. bash Tool
**Date**: 2026-04-07
**Commit**: `68e3663`
#### Features
- Executes shell commands
- Timeout control
- Dangerous-command detection (rm -rf, fork bombs)
- Output truncation
- Working directory support
#### Schema
```json
{
  "type": "object",
  "properties": {
    "command": { "type": "string", "description": "Shell command" },
    "timeout": { "type": "integer", "description": "Timeout in seconds", "minimum": 1, "maximum": 600 }
  },
  "required": ["command"]
}
```
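A hedged sketch of what a dangerous-command guard can look like (the pattern list is an assumption for illustration; the real checks may be broader):

```rust
fn is_dangerous(command: &str) -> bool {
    // Substring patterns for catastrophic commands (illustrative set)
    const PATTERNS: &[&str] = &["rm -rf /", ":(){ :|:& };:"];
    PATTERNS.iter().any(|p| command.contains(p))
}

fn main() {
    assert!(is_dangerous("rm -rf / --no-preserve-root"));
    assert!(is_dangerous(":(){ :|:& };:")); // classic fork bomb
    assert!(!is_dangerous("ls -la"));
}
```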
#### Tests
- 7 unit tests
- `cargo test --lib tools::bash`
### 6. http_request Tool
**Date**: 2026-04-07
**Commit**: `1581732`
#### Features
- HTTP client supporting GET/POST/PUT/DELETE/PATCH
- Domain allowlist
- SSRF protection (blocks private IPs, localhost)
- Response size limit and truncation
- Timeout control
#### Schema
```json
{
  "type": "object",
  "properties": {
    "url": { "type": "string", "description": "Request URL" },
    "method": { "type": "string", "description": "HTTP method", "enum": ["GET", "POST", "PUT", "DELETE", "PATCH"] },
    "headers": { "type": "object", "description": "Request headers" },
    "body": { "type": "string", "description": "Request body" }
  },
  "required": ["url"]
}
```
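The SSRF check for literal IPs can be sketched with the standard library (assumed logic; hostname targets would additionally need DNS resolution before checking the resolved addresses):

```rust
use std::net::IpAddr;

fn is_blocked_host(host: &str) -> bool {
    if host == "localhost" {
        return true;
    }
    match host.parse::<IpAddr>() {
        Ok(IpAddr::V4(ip)) => ip.is_loopback() || ip.is_private() || ip.is_link_local(),
        Ok(IpAddr::V6(ip)) => ip.is_loopback(),
        Err(_) => false, // not a literal IP; resolve before deciding
    }
}

fn main() {
    assert!(is_blocked_host("127.0.0.1"));
    assert!(is_blocked_host("10.0.0.5"));
    assert!(is_blocked_host("localhost"));
    assert!(!is_blocked_host("93.184.216.34")); // public address passes
}
```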
#### Tests
- 8 unit tests
- `cargo test --lib tools::http_request`
### 7. web_fetch Tool
**Date**: 2026-04-07
**Commit**: `8936e70`
#### Features
- Fetches a URL and extracts readable text
- HTML to plain-text conversion
- Removes scripts, styles, and HTML tags
- Decodes HTML entities
- Pretty-prints JSON responses
- SSRF protection
#### Schema
```json
{
  "type": "object",
  "properties": {
    "url": { "type": "string", "description": "URL to fetch" }
  },
  "required": ["url"]
}
```
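Tag removal at its simplest is a two-state scan over the input; a minimal sketch (the tool also drops script/style contents and decodes entities, which this omits):

```rust
fn strip_tags(html: &str) -> String {
    let mut out = String::new();
    let mut in_tag = false;
    for c in html.chars() {
        match c {
            '<' => in_tag = true,         // entering a tag
            '>' => in_tag = false,        // leaving a tag
            _ if !in_tag => out.push(c),  // keep text outside tags
            _ => {}
        }
    }
    out
}

fn main() {
    assert_eq!(strip_tags("<p>Hello <b>world</b></p>"), "Hello world");
}
```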
#### Tests
- 6 unit tests
- `cargo test --lib tools::web_fetch`
---
## Tool Inventory
| Tool | Name | File | Purpose |
|------|------|------|---------|
| calculator | Calculator | `src/tools/calculator.rs` | 25+ math and statistics functions |
| file_read | File read | `src/tools/file_read.rs` | Paginated file reading |
| file_write | File write | `src/tools/file_write.rs` | Create/overwrite files |
| file_edit | File edit | `src/tools/file_edit.rs` | Text-replacement editing |
| bash | Shell execution | `src/tools/bash.rs` | Command execution with safety guards |
| http_request | HTTP request | `src/tools/http_request.rs` | API requests |
| web_fetch | Web fetch | `src/tools/web_fetch.rs` | HTML content extraction |
## Tool Mechanism Enhancements
### SchemaCleanr
Cross-provider JSON Schema normalization, supporting:
- Gemini (strictest)
- Anthropic (moderate)
- OpenAI (most permissive)
- Conservative
### Tool Attributes
```rust
fn read_only(&self) -> bool { false }   // read-only (no side effects)?
fn concurrency_safe(&self) -> bool {    // safe to run in parallel?
    self.read_only() && !self.exclusive()
}
fn exclusive(&self) -> bool { false }   // requires exclusive execution?
```
## Running Tests
```bash
cargo test --lib                 # all tests
cargo test --lib tools::schema   # SchemaCleanr
```
## Tool Registration
Tools are registered in the `default_tools()` function in `src/gateway/session.rs`:
```rust
fn default_tools() -> ToolRegistry {
    let mut registry = ToolRegistry::new();
    registry.register(CalculatorTool::new());
    registry.register(FileReadTool::new());
    registry.register(FileWriteTool::new());
    registry.register(FileEditTool::new());
    registry.register(BashTool::new());
    registry.register(HttpRequestTool::new(
        vec!["*".to_string()], // allow all domains
        1_000_000,             // max_response_size
        30,                    // timeout_secs
        false,                 // allow_private_hosts
    ));
    registry.register(WebFetchTool::new(50_000, 30));
    registry
}
```
SessionManager creates AgentLoop instances with these tools, making all of them automatically available to the LLM.
## Update Log
| Date | Commit | Change |
|------|--------|--------|
| 2026-04-07 | `d5b6cd2` | feat: add SchemaCleanr |
| 2026-04-07 | `a9e7aab` | feat: add file_read tool |
| 2026-04-07 | `16b052b` | feat: add file_write tool |
| 2026-04-07 | `f3187ce` | feat: add file_edit tool |
| 2026-04-07 | `68e3663` | feat: add bash tool |
| 2026-04-07 | `1581732` | feat: add http_request tool |
| 2026-04-07 | `8936e70` | feat: add web_fetch tool |
| 2026-04-07 | `b13bb8c` | docs: add implementation log |
| 2026-04-08 | `98bc973` | feat: register all tools in SessionManager |
Per-tool test commands:
```bash
cargo test --lib tools::file_read
cargo test --lib tools::file_write
cargo test --lib tools::file_edit
cargo test --lib tools::bash
cargo test --lib tools::http_request
cargo test --lib tools::web_fetch
```

View File

@ -1,33 +1,72 @@
use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::tools::ToolRegistry;
use std::io::Read;
use std::sync::Arc;
/// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
let mut blocks = Vec::new();
// Add text block if there's text
if !text.is_empty() {
blocks.push(ContentBlock::text(text));
}
// Add image blocks for media paths
for path in media_paths {
if let Ok((mime_type, base64_data)) = encode_image_to_base64(path) {
let url = format!("data:{};base64,{}", mime_type, base64_data);
blocks.push(ContentBlock::image_url(url));
}
}
// If nothing, add empty text block
if blocks.is_empty() {
blocks.push(ContentBlock::text(""));
}
blocks
}
/// Encode an image file to base64 data URL
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{Engine as _, engine::general_purpose::STANDARD};
let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let mime = mime_guess::from_path(path)
.first_or_octet_stream()
.to_string();
let encoded = STANDARD.encode(&buffer);
Ok((mime, encoded))
}
/// Stateless AgentLoop - history is managed externally by SessionManager
pub struct AgentLoop {
provider: Box<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
max_iterations: u32,
}
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let provider = create_provider(provider_config)
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
tools: Arc::new(ToolRegistry::new()),
})
Self::with_tools(provider_config, Arc::new(ToolRegistry::new()))
}
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let provider = create_provider(provider_config)
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider,
tools,
max_iterations: provider_config.max_iterations,
})
}
@ -37,18 +76,16 @@ impl AgentLoop {
/// Process a message using the provided conversation history.
/// History management is handled externally by SessionManager.
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
let messages_for_llm: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
})
.collect();
/// Returns (final_response, complete_message_history) where the history includes
/// all tool calls and results for proper session continuity.
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<(ChatMessage, Vec<ChatMessage>), AgentError> {
let mut messages = messages;
let mut final_content: String = String::new();
tracing::debug!(history_len = messages.len(), "Sending request to LLM");
for iteration in 0..self.max_iterations {
tracing::debug!(iteration, history_len = messages.len(), "Starting iteration");
let messages_for_llm = self.build_messages_for_llm(&messages);
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
@ -76,11 +113,10 @@ impl AgentLoop {
);
if !response.tool_calls.is_empty() {
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
tracing::info!(count = response.tool_calls.len(), iteration, tools = ?response.tool_calls.iter().map(|tc| &tc.name).collect::<Vec<_>>(), "Tool calls detected, executing tools");
let mut updated_messages = messages.clone();
let assistant_message = ChatMessage::assistant(response.content.clone());
updated_messages.push(assistant_message.clone());
messages.push(assistant_message);
let tool_results = self.execute_tools(&response.tool_calls).await;
@ -90,61 +126,104 @@ impl AgentLoop {
tool_call.name.clone(),
result.clone(),
);
updated_messages.push(tool_message);
messages.push(tool_message);
}
return self.continue_with_tool_results(updated_messages).await;
tracing::debug!(iteration, "Tool execution completed, continuing to next iteration");
continue;
}
let assistant_message = ChatMessage::assistant(response.content);
Ok(assistant_message)
tracing::debug!(iteration, "No tool calls in response, agent loop ending");
final_content = response.content;
break;
}
async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
let messages_for_llm: Vec<Message> = messages
if final_content.is_empty() {
tracing::warn!(iterations = self.max_iterations, "Max iterations reached without final response");
final_content = format!("Error: Max iterations ({}) reached without final response", self.max_iterations);
}
let final_message = ChatMessage::assistant(final_content);
// Return both the final message and the complete history for session persistence
Ok((final_message, messages))
}
fn build_messages_for_llm(&self, messages: &[ChatMessage]) -> Vec<Message> {
messages
.iter()
.map(|m| Message {
.map(|m| {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
#[cfg(debug_assertions)]
tracing::debug!(media_refs = ?m.media_refs, "Building content blocks with media");
build_content_blocks(&m.content, &m.media_refs)
};
#[cfg(debug_assertions)]
tracing::debug!(role = %m.role, content_len = %m.content.len(), media_refs_len = %m.media_refs.len(), "ChatMessage converted to LLM Message");
Message {
role: m.role.clone(),
content: m.content.clone(),
content,
tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(),
}
})
.collect();
let tools = if self.tools.has_tools() {
Some(self.tools.get_definitions())
} else {
None
};
let request = ChatCompletionRequest {
messages: messages_for_llm,
temperature: None,
max_tokens: None,
tools,
};
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM continuation request failed");
AgentError::LlmError(e.to_string())
})?;
let assistant_message = ChatMessage::assistant(response.content);
Ok(assistant_message)
.collect()
}
async fn execute_tools(&self, tool_calls: &[ToolCall]) -> Vec<String> {
let batches = self.partition_tool_batches(tool_calls);
let mut results = Vec::with_capacity(tool_calls.len());
for tool_call in tool_calls {
let result = self.execute_tool(tool_call).await;
results.push(result);
for batch in batches {
if batch.len() == 1 {
// Single tool — run directly (no spawn overhead)
results.push(self.execute_tool(&batch[0]).await);
} else {
// Multiple tools — run in parallel via join_all
use futures_util::future::join_all;
let futures = batch.iter().map(|tc| self.execute_tool(tc));
let batch_results = join_all(futures).await;
results.extend(batch_results);
}
}
results
}
/// Partition tool calls into batches based on concurrency safety.
///
/// `concurrency_safe` tools are grouped together; each `exclusive` tool
/// runs in its own batch. This matches the approach used in Nanobot's
/// `_partition_tool_batches` and Zeroclaw's `parallel_tools` config.
fn partition_tool_batches(&self, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
let mut batches: Vec<Vec<ToolCall>> = Vec::new();
let mut current: Vec<ToolCall> = Vec::new();
for tc in tool_calls {
let concurrency_safe = self
.tools
.get(&tc.name)
.map(|t| t.concurrency_safe())
.unwrap_or(false);
if concurrency_safe {
current.push(tc.clone());
} else {
if !current.is_empty() {
batches.push(std::mem::take(&mut current));
}
batches.push(vec![tc.clone()]);
}
}
if !current.is_empty() {
batches.push(current);
}
batches
}
async fn execute_tool(&self, tool_call: &ToolCall) -> String {
let tool = match self.tools.get(&tool_call.name) {
Some(t) => t,
@ -188,3 +267,140 @@ impl std::fmt::Display for AgentError {
}
impl std::error::Error for AgentError {}
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::ToolCall;
use crate::tools::ToolRegistry;
use crate::tools::CalculatorTool;
use crate::tools::BashTool;
use crate::tools::FileReadTool;
use std::sync::Arc;
use serde_json::json;
fn make_tc(name: &str, args: serde_json::Value) -> ToolCall {
ToolCall {
id: format!("tc_{}", name),
name: name.to_string(),
arguments: args,
}
}
/// Verify that partition_tool_batches groups concurrency-safe tools together
/// and isolates exclusive tools, matching the nanobot/zeroclaw approach.
#[test]
fn test_partition_batches_mixes_safe_and_exclusive() {
let registry = Arc::new({
let mut r = ToolRegistry::new();
r.register(CalculatorTool::new()); // concurrency_safe = true
r.register(BashTool::new()); // concurrency_safe = false (exclusive)
r.register(FileReadTool::new()); // concurrency_safe = true
r
});
// agent_loop needs a provider to construct; test the partitioning logic directly
let tcs = vec![
make_tc("calculator", json!({})),
make_tc("bash", json!({"command": "ls"})),
make_tc("file_read", json!({"path": "/tmp/foo"})),
make_tc("calculator", json!({})),
];
// Expected:
// batch 1: calculator (safe, first run)
// batch 2: bash (exclusive, alone)
// batch 3: file_read, calculator (both safe, run together)
let batches = partition_for_test(&registry, &tcs);
assert_eq!(batches.len(), 3);
assert_eq!(batches[0].len(), 1);
assert_eq!(batches[0][0].name, "calculator");
assert_eq!(batches[1].len(), 1);
assert_eq!(batches[1][0].name, "bash");
assert_eq!(batches[2].len(), 2);
assert_eq!(batches[2][0].name, "file_read");
assert_eq!(batches[2][1].name, "calculator");
}
/// All-safe tool calls should produce a single batch (parallel execution).
#[test]
fn test_partition_batches_all_safe_single_batch() {
let registry = Arc::new({
let mut r = ToolRegistry::new();
r.register(CalculatorTool::new());
r.register(FileReadTool::new());
r
});
let tcs = vec![
make_tc("calculator", json!({})),
make_tc("file_read", json!({"path": "/tmp/foo"})),
];
let batches = partition_for_test(&registry, &tcs);
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].len(), 2);
}
/// All-exclusive tool calls should each get their own batch (sequential execution).
#[test]
fn test_partition_batches_all_exclusive_separate_batches() {
let registry = Arc::new({
let mut r = ToolRegistry::new();
r.register(BashTool::new());
r
});
let tcs = vec![
make_tc("bash", json!({"command": "ls"})),
make_tc("bash", json!({"command": "pwd"})),
];
let batches = partition_for_test(&registry, &tcs);
assert_eq!(batches.len(), 2);
assert_eq!(batches[0].len(), 1);
assert_eq!(batches[1].len(), 1);
}
/// Unknown tools (not in registry) default to non-concurrency-safe (single batch).
#[test]
fn test_partition_batches_unknown_tool_gets_own_batch() {
let registry = Arc::new(ToolRegistry::new());
let tcs = vec![
make_tc("calculator", json!({})),
make_tc("unknown_tool", json!({})),
];
let batches = partition_for_test(&registry, &tcs);
assert_eq!(batches.len(), 2);
}
/// Expose partition logic for testing without needing a full AgentLoop.
fn partition_for_test(registry: &Arc<ToolRegistry>, tool_calls: &[ToolCall]) -> Vec<Vec<ToolCall>> {
let mut batches: Vec<Vec<ToolCall>> = Vec::new();
let mut current: Vec<ToolCall> = Vec::new();
for tc in tool_calls {
let concurrency_safe = registry
.get(&tc.name)
.map(|t| t.concurrency_safe())
.unwrap_or(false);
if concurrency_safe {
current.push(tc.clone());
} else {
if !current.is_empty() {
batches.push(std::mem::take(&mut current));
}
batches.push(vec![tc.clone()]);
}
}
if !current.is_empty() {
batches.push(current);
}
batches
}
}

View File

@ -31,6 +31,7 @@ impl OutboundDispatcher {
loop {
let msg = self.bus.consume_outbound().await;
#[cfg(debug_assertions)]
tracing::debug!(
channel = %msg.channel,
chat_id = %msg.chat_id,

View File

@ -2,7 +2,60 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
// ============================================================================
// ChatMessage - Legacy type used by AgentLoop for LLM conversation history
// ContentBlock - Multimodal content representation (OpenAI-style)
// ============================================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlBlock },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlBlock {
pub url: String,
}
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text { text: content.into() }
}
pub fn image_url(url: impl Into<String>) -> Self {
Self::ImageUrl {
image_url: ImageUrlBlock { url: url.into() },
}
}
}
// ============================================================================
// MediaItem - Media metadata for messages
// ============================================================================
#[derive(Debug, Clone)]
pub struct MediaItem {
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub mime_type: Option<String>,
pub original_key: Option<String>, // Feishu file_key for download
}
impl MediaItem {
pub fn new(path: impl Into<String>, media_type: impl Into<String>) -> Self {
Self {
path: path.into(),
media_type: media_type.into(),
mime_type: None,
original_key: None,
}
}
}
// ============================================================================
// ChatMessage - Used by AgentLoop for LLM conversation history
// ============================================================================
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -10,6 +63,7 @@ pub struct ChatMessage {
pub id: String,
pub role: String,
pub content: String,
pub media_refs: Vec<String>, // Paths to media files for context
pub timestamp: i64,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
@ -23,6 +77,19 @@ impl ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
}
}
pub fn user_with_media(content: impl Into<String>, media_refs: Vec<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "user".to_string(),
content: content.into(),
media_refs,
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
@ -34,6 +101,7 @@ impl ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
@ -45,6 +113,7 @@ impl ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "system".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
@ -56,6 +125,7 @@ impl ChatMessage {
id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()),
@ -74,8 +144,11 @@ pub struct InboundMessage {
pub chat_id: String,
pub content: String,
pub timestamp: i64,
pub media: Vec<String>,
pub media: Vec<MediaItem>,
/// Channel-specific data used internally by the channel (not forwarded).
pub metadata: HashMap<String, String>,
/// Data forwarded from inbound to outbound (copied to OutboundMessage.metadata by gateway).
pub forwarded_metadata: HashMap<String, String>,
}
impl InboundMessage {
@ -94,7 +167,7 @@ pub struct OutboundMessage {
pub chat_id: String,
pub content: String,
pub reply_to: Option<String>,
pub media: Vec<String>,
pub media: Vec<MediaItem>,
pub metadata: HashMap<String, String>,
}

View File

@ -2,7 +2,7 @@ pub mod dispatcher;
pub mod message;
pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, InboundMessage, OutboundMessage};
pub use message::{ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
@ -33,6 +33,8 @@ impl MessageBus {
/// Publish an inbound message (Channel -> Bus)
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
#[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, content_len = %msg.content.len(), media_count = %msg.media.len(), "Bus: publishing inbound message");
self.inbound_tx
.send(msg)
.await
@ -41,16 +43,21 @@ impl MessageBus {
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
self.inbound_rx
let msg = self.inbound_rx
.lock()
.await
.recv()
.await
.expect("bus inbound closed")
.expect("bus inbound closed");
#[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, sender = %msg.sender_id, chat = %msg.chat_id, "Bus: consuming inbound message");
msg
}
/// Publish an outbound message (Agent -> Bus)
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
#[cfg(debug_assertions)]
tracing::debug!(channel = %msg.channel, chat_id = %msg.chat_id, content_len = %msg.content.len(), "Bus: publishing outbound message");
self.outbound_tx
.send(msg)
.await

View File

@ -62,37 +62,18 @@ pub trait Channel: Send + Sync + 'static {
async fn handle_and_publish(
&self,
bus: &Arc<MessageBus>,
sender_id: &str,
chat_id: &str,
content: &str,
msg: &InboundMessage,
) -> Result<(), ChannelError> {
if !self.is_allowed(sender_id) {
if !self.is_allowed(&msg.sender_id) {
tracing::warn!(
channel = %self.name(),
sender = %sender_id,
sender = %msg.sender_id,
"Access denied"
);
return Ok(());
}
let msg = InboundMessage {
channel: self.name().to_string(),
sender_id: sender_id.to_string(),
chat_id: chat_id.to_string(),
content: content.to_string(),
timestamp: current_timestamp(),
media: vec![],
metadata: std::collections::HashMap::new(),
};
bus.publish_inbound(msg).await?;
bus.publish_inbound(msg.clone()).await?;
Ok(())
}
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64
}

View File

@ -1,17 +1,33 @@
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::{broadcast, RwLock};
use serde::Deserialize;
use futures_util::{SinkExt, StreamExt};
use prost::{Message as ProstMessage, bytes::Bytes};
use serde::Deserialize;
use tokio::sync::{broadcast, RwLock};
use crate::bus::{MessageBus, OutboundMessage};
use crate::bus::{MessageBus, MediaItem, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
const FEISHU_WS_BASE: &str = "https://open.feishu.cn";
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
/// If no binary frame (pong or event) is received within this window, reconnect.
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
/// Refresh tenant token this many seconds before the announced expiry.
const TOKEN_REFRESH_SKEW: Duration = Duration::from_secs(120);
/// Default tenant token TTL when `expire`/`expires_in` is absent.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(7200);
/// Feishu API business code for expired/invalid tenant access token.
const INVALID_ACCESS_TOKEN_CODE: i32 = 99991663;
/// Dedup cache TTL (30 minutes).
const DEDUP_CACHE_TTL: Duration = Duration::from_secs(30 * 60);
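A sketch of the proactive-refresh arithmetic these constants drive (`refresh_delay` is a hypothetical helper, not code from this crate): a cached token is treated as stale `TOKEN_REFRESH_SKEW` before its announced expiry, and `saturating_sub` keeps a very short TTL from underflowing.

```rust
use std::time::Duration;

// Mirrors the constants above.
const TOKEN_REFRESH_SKEW: Duration = Duration::from_secs(120);
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(7200);

/// How long a freshly fetched token may be served from cache.
fn refresh_delay(announced_ttl: Option<Duration>) -> Duration {
    announced_ttl
        .unwrap_or(DEFAULT_TOKEN_TTL)
        .saturating_sub(TOKEN_REFRESH_SKEW)
}

fn main() {
    // A 7200 s TTL is refreshed 120 s early.
    assert_eq!(refresh_delay(Some(Duration::from_secs(7200))).as_secs(), 7080);
    // A TTL shorter than the skew saturates to zero: refetch on next use.
    assert_eq!(refresh_delay(Some(Duration::from_secs(60))).as_secs(), 0);
    // A missing `expire` field falls back to the default TTL.
    assert_eq!(refresh_delay(None).as_secs(), 7080);
}
```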
// ─────────────────────────────────────────────────────────────────────────────
// Protobuf types for Feishu WebSocket protocol (pbbp2.proto)
// ─────────────────────────────────────────────────────────────────────────────
@ -124,6 +140,13 @@ struct LarkMessage {
// ─────────────────────────────────────────────────────────────────────────────
/// Cached tenant token with proactive refresh metadata.
#[derive(Clone)]
struct CachedTenantToken {
value: String,
refresh_after: Instant,
}
#[derive(Clone)]
pub struct FeishuChannel {
config: FeishuChannelConfig,
@ -131,6 +154,10 @@ pub struct FeishuChannel {
running: Arc<RwLock<bool>>,
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
connected: Arc<RwLock<bool>>,
/// Cached tenant access token with proactive refresh.
tenant_token: Arc<RwLock<Option<CachedTenantToken>>>,
/// Dedup cache: WS message_ids seen in the last ~30 min.
seen_message_ids: Arc<RwLock<HashMap<String, Instant>>>,
}
/// Parsed message data from a Feishu frame
@ -139,6 +166,7 @@ struct ParsedMessage {
open_id: String,
chat_id: String,
content: String,
media: Option<MediaItem>,
}
impl FeishuChannel {
@ -152,6 +180,8 @@ impl FeishuChannel {
running: Arc::new(RwLock::new(false)),
shutdown_tx: Arc::new(RwLock::new(None)),
connected: Arc::new(RwLock::new(false)),
tenant_token: Arc::new(RwLock::new(None)),
seen_message_ids: Arc::new(RwLock::new(HashMap::new())),
})
}
@ -188,8 +218,42 @@ impl FeishuChannel {
Ok((ep.url, client_config))
}
/// Get tenant access token (cached with proactive refresh).
async fn get_tenant_access_token(&self) -> Result<String, ChannelError> {
// 1. Check cache
{
let cached = self.tenant_token.read().await;
if let Some(ref token) = *cached {
if Instant::now() < token.refresh_after {
return Ok(token.value.clone());
}
}
}
// 2. Fetch new token
let (token, ttl) = self.fetch_new_token().await?;
// 3. Cache with proactive refresh time (refresh 120 s before expiry)
let refresh_after = Instant::now() + ttl.saturating_sub(TOKEN_REFRESH_SKEW);
{
let mut cached = self.tenant_token.write().await;
*cached = Some(CachedTenantToken {
value: token.clone(),
refresh_after,
});
}
Ok(token)
}
/// Invalidate cached token (called when API reports expired tenant token).
async fn invalidate_token(&self) {
let mut cached = self.tenant_token.write().await;
*cached = None;
}
/// Fetch a new tenant access token from Feishu.
async fn fetch_new_token(&self) -> Result<(String, Duration), ChannelError> {
let resp = self.http_client
.post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE))
.header("Content-Type", "application/json")
@ -205,6 +269,7 @@ impl FeishuChannel {
struct TokenResponse {
code: i32,
tenant_access_token: Option<String>,
expire: Option<i64>,
}
let token_resp: TokenResponse = resp
@ -216,15 +281,401 @@ impl FeishuChannel {
return Err(ChannelError::Other("Auth failed".to_string()));
}
let token = token_resp.tenant_access_token
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))?;
let ttl = token_resp.expire
.and_then(|v| u64::try_from(v).ok())
.map(Duration::from_secs)
.unwrap_or(DEFAULT_TOKEN_TTL);
Ok((token, ttl))
}
/// Check if message_id has been seen (dedup), and mark it as seen if not.
/// Returns true if the message was already processed.
/// Note: GC of stale entries is handled in the heartbeat timeout_check loop.
async fn is_message_seen(&self, message_id: &str) -> bool {
let mut seen = self.seen_message_ids.write().await;
let now = Instant::now();
if seen.contains_key(message_id) {
true
} else {
seen.insert(message_id.to_string(), now);
false
}
}
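The dedup behaviour of `is_message_seen` plus the GC pass run from the heartbeat loop can be sketched standalone (the `DedupCache` type here is illustrative; the real channel keeps the map behind an `Arc<RwLock<…>>`):

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

const DEDUP_CACHE_TTL: Duration = Duration::from_secs(30 * 60);

struct DedupCache {
    seen: HashMap<String, Instant>,
}

impl DedupCache {
    fn new() -> Self {
        Self { seen: HashMap::new() }
    }

    /// Returns true if `id` was already seen; otherwise records it.
    fn check_and_insert(&mut self, id: &str) -> bool {
        if self.seen.contains_key(id) {
            true
        } else {
            self.seen.insert(id.to_string(), Instant::now());
            false
        }
    }

    /// GC pass, as run from the heartbeat `timeout_check` loop.
    fn gc(&mut self) {
        let now = Instant::now();
        self.seen.retain(|_, ts| now.duration_since(*ts) < DEDUP_CACHE_TTL);
    }
}

fn main() {
    let mut cache = DedupCache::new();
    assert!(!cache.check_and_insert("om_123")); // first delivery: process
    assert!(cache.check_and_insert("om_123")); // WS redelivery: skip
    cache.gc(); // fresh entries survive GC
    assert!(cache.check_and_insert("om_123"));
}
```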
/// Download media and save locally, return (description, media_item)
async fn download_media(
&self,
msg_type: &str,
content_json: &serde_json::Value,
message_id: &str,
) -> Result<(String, Option<MediaItem>), ChannelError> {
let media_dir = Path::new(&self.config.media_dir);
tokio::fs::create_dir_all(media_dir).await
.map_err(|e| ChannelError::Other(format!("Failed to create media dir: {}", e)))?;
match msg_type {
"image" => self.download_image(content_json, message_id, media_dir).await,
"audio" | "file" | "media" => self.download_file(content_json, message_id, media_dir, msg_type).await,
_ => Ok((format!("[unsupported media type: {}]", msg_type), None)),
}
}
/// Download image from Feishu
async fn download_image(
&self,
content_json: &serde_json::Value,
message_id: &str,
media_dir: &Path,
) -> Result<(String, Option<MediaItem>), ChannelError> {
let image_key = content_json.get("image_key")
.and_then(|v| v.as_str())
.ok_or_else(|| ChannelError::Other("No image_key in message".to_string()))?;
let token = self.get_tenant_access_token().await?;
// Use message resource API for downloading message images
let url = format!("{}/im/v1/messages/{}/resources/{}?type=image", FEISHU_API_BASE, message_id, image_key);
#[cfg(debug_assertions)]
tracing::debug!(url = %url, image_key = %image_key, message_id = %message_id, "Downloading image from Feishu via message resource API");
let resp = self.http_client
.get(&url)
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)))?;
let status = resp.status();
#[cfg(debug_assertions)]
tracing::debug!(status = %status, "Image download response status");
if !status.is_success() {
let error_text = resp.text().await.unwrap_or_default();
return Err(ChannelError::Other(format!("Image download failed {}: {}", status, error_text)));
}
let data = resp.bytes().await
.map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))?
.to_vec();
#[cfg(debug_assertions)]
tracing::debug!(data_len = %data.len(), "Downloaded image data");
let filename = format!("{}_{}.jpg", message_id, &image_key[..8.min(image_key.len())]);
let file_path = media_dir.join(&filename);
tokio::fs::write(&file_path, &data).await
.map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?;
let media_item = MediaItem::new(
file_path.to_string_lossy().to_string(),
"image",
);
tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image");
Ok((format!("[image: {}]", filename), Some(media_item)))
}
/// Download file/audio from Feishu
async fn download_file(
&self,
content_json: &serde_json::Value,
message_id: &str,
media_dir: &Path,
file_type: &str,
) -> Result<(String, Option<MediaItem>), ChannelError> {
let file_key = content_json.get("file_key")
.and_then(|v| v.as_str())
.ok_or_else(|| ChannelError::Other("No file_key in message".to_string()))?;
let token = self.get_tenant_access_token().await?;
// Use message resource API for downloading message files
let url = format!("{}/im/v1/messages/{}/resources/{}?type=file", FEISHU_API_BASE, message_id, file_key);
#[cfg(debug_assertions)]
tracing::debug!(url = %url, file_key = %file_key, message_id = %message_id, "Downloading file from Feishu via message resource API");
let resp = self.http_client
.get(&url)
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)))?;
let status = resp.status();
if !status.is_success() {
let error_text = resp.text().await.unwrap_or_default();
return Err(ChannelError::Other(format!("File download failed {}: {}", status, error_text)));
}
let data = resp.bytes().await
.map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))?
.to_vec();
let extension = match file_type {
"audio" => "mp3",
"video" => "mp4",
_ => "bin",
};
let filename = format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension);
let file_path = media_dir.join(&filename);
tokio::fs::write(&file_path, &data).await
.map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?;
let media_item = MediaItem::new(
file_path.to_string_lossy().to_string(),
file_type,
);
tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file");
Ok((format!("[{}: {}]", file_type, filename), Some(media_item)))
}
/// Upload image to Feishu and return the image_key
async fn upload_image(&self, file_path: &str) -> Result<String, ChannelError> {
let token = self.get_tenant_access_token().await?;
let mime = mime_guess::from_path(file_path)
.first_or_octet_stream()
.to_string();
let file_name = std::path::Path::new(file_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("image.jpg");
let file_data = tokio::fs::read(file_path).await
.map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?;
let part = reqwest::multipart::Part::bytes(file_data)
.file_name(file_name.to_string())
.mime_str(&mime)
.map_err(|e| ChannelError::Other(format!("Invalid mime type: {}", e)))?;
let form = reqwest::multipart::Form::new()
.text("image_type", "message".to_string())
.part("image", part);
let resp = self.http_client
.post(format!("{}/im/v1/images/upload", FEISHU_API_BASE))
.header("Authorization", format!("Bearer {}", token))
.multipart(form)
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)))?;
#[derive(Deserialize)]
struct UploadResp {
code: i32,
msg: Option<String>,
data: Option<UploadData>,
}
#[derive(Deserialize)]
struct UploadData {
image_key: String,
}
let result: UploadResp = resp.json().await
.map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?;
if result.code != 0 {
return Err(ChannelError::Other(format!(
"Upload image failed: code={} msg={}",
result.code,
result.msg.as_deref().unwrap_or("unknown")
)));
}
result.data
.map(|d| d.image_key)
.ok_or_else(|| ChannelError::Other("No image_key in response".to_string()))
}
/// Upload file to Feishu and return the file_key
async fn upload_file(&self, file_path: &str) -> Result<String, ChannelError> {
let token = self.get_tenant_access_token().await?;
let file_name = std::path::Path::new(file_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("file.bin");
let extension = std::path::Path::new(file_path)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("")
.to_lowercase();
let file_type = match extension.as_str() {
"mp3" | "m4a" | "wav" | "ogg" => "audio",
"mp4" | "mov" | "avi" | "mkv" => "video",
"pdf" | "doc" | "docx" | "xls" | "xlsx" => "doc",
_ => "file",
};
let file_data = tokio::fs::read(file_path).await
.map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?;
let part = reqwest::multipart::Part::bytes(file_data)
.file_name(file_name.to_string())
.mime_str("application/octet-stream")
.map_err(|e| ChannelError::Other(format!("Invalid mime type: {}", e)))?;
let form = reqwest::multipart::Form::new()
.text("file_type", file_type.to_string())
.text("file_name", file_name.to_string())
.part("file", part);
let resp = self.http_client
.post(format!("{}/im/v1/files", FEISHU_API_BASE))
.header("Authorization", format!("Bearer {}", token))
.multipart(form)
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Upload file HTTP error: {}", e)))?;
#[derive(Deserialize)]
struct UploadResp {
code: i32,
msg: Option<String>,
data: Option<UploadData>,
}
#[derive(Deserialize)]
struct UploadData {
file_key: String,
}
let result: UploadResp = resp.json().await
.map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?;
if result.code != 0 {
return Err(ChannelError::Other(format!(
"Upload file failed: code={} msg={}",
result.code,
result.msg.as_deref().unwrap_or("unknown")
)));
}
result.data
.map(|d| d.file_key)
.ok_or_else(|| ChannelError::Other("No file_key in response".to_string()))
}
/// Add a reaction emoji to a message and store the reaction_id for later removal.
/// Returns the reaction_id if successful, None otherwise.
async fn add_reaction(&self, message_id: &str) -> Result<Option<String>, ChannelError> {
let emoji = self.config.reaction_emoji.as_str();
let token = self.get_tenant_access_token().await?;
let resp = self.http_client
.post(format!("{}/im/v1/messages/{}/reactions", FEISHU_API_BASE, message_id))
.header("Authorization", format!("Bearer {}", token))
.json(&serde_json::json!({
"reaction_type": { "emoji_type": emoji }
}))
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Add reaction HTTP error: {}", e)))?;
#[derive(Deserialize)]
struct ReactionResp {
code: i32,
msg: Option<String>,
data: Option<ReactionData>,
}
#[derive(Deserialize)]
struct ReactionData {
reaction_id: Option<String>,
}
let result: ReactionResp = resp.json().await
.map_err(|e| ChannelError::Other(format!("Parse reaction response error: {}", e)))?;
if result.code != 0 {
tracing::warn!(
"Failed to add reaction to message {}: code={} msg={}",
message_id,
result.code,
result.msg.as_deref().unwrap_or("unknown")
);
return Ok(None);
}
let reaction_id = result.data.and_then(|d| d.reaction_id);
Ok(reaction_id)
}
/// Remove reaction using feishu metadata propagated through OutboundMessage.
/// Reads feishu.message_id and feishu.reaction_id from metadata.
async fn remove_reaction_from_metadata(&self, metadata: &std::collections::HashMap<String, String>) {
let (message_id, reaction_id) = match (
metadata.get("feishu.message_id"),
metadata.get("feishu.reaction_id"),
) {
(Some(msg_id), Some(rid)) => (msg_id.clone(), rid.clone()),
_ => return,
};
if let Err(e) = self.remove_reaction(&message_id, &reaction_id).await {
tracing::debug!(error = %e, message_id = %message_id, "Failed to remove reaction");
}
}
/// Remove a reaction emoji from a message.
async fn remove_reaction(&self, message_id: &str, reaction_id: &str) -> Result<(), ChannelError> {
let token = self.get_tenant_access_token().await?;
let resp = self.http_client
.delete(format!("{}/im/v1/messages/{}/reactions/{}", FEISHU_API_BASE, message_id, reaction_id))
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Remove reaction HTTP error: {}", e)))?;
#[derive(Deserialize)]
struct ReactionResp {
code: i32,
msg: Option<String>,
}
let result: ReactionResp = resp.json().await
.map_err(|e| ChannelError::Other(format!("Parse remove reaction response error: {}", e)))?;
if result.code != 0 {
tracing::debug!(
"Failed to remove reaction {} from message {}: code={} msg={}",
reaction_id,
message_id,
result.code,
result.msg.as_deref().unwrap_or("unknown")
);
}
Ok(())
}
/// Send a text message to Feishu chat (implements Channel trait)
async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
let token = self.get_tenant_access_token().await?;
// Feishu text messages have content limits (~64KB).
// Truncate if content is too long to avoid API error 230001.
const MAX_TEXT_LENGTH: usize = 60_000;
let truncated = if content.len() > MAX_TEXT_LENGTH {
// Back off to a char boundary so slicing cannot panic mid-UTF-8 character.
let mut end = MAX_TEXT_LENGTH;
while !content.is_char_boundary(end) {
end -= 1;
}
format!("{}...\n\n[Content truncated due to length limit]", &content[..end])
} else {
content.to_string()
};
let text_content = serde_json::json!({ "text": truncated }).to_string();
let resp = self.http_client
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
@ -285,10 +736,15 @@ impl FeishuChannel {
let payload = frame.payload.as_deref()
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
#[cfg(debug_assertions)]
tracing::debug!(payload_len = %payload.len(), "Received frame payload");
let event: LarkEvent = serde_json::from_slice(payload)
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
let event_type = event.header.event_type.as_str();
#[cfg(debug_assertions)]
tracing::debug!(event_type = %event_type, "Received event type");
if event_type != "im.message.receive_v1" {
return Ok(None);
}
@ -303,22 +759,74 @@ impl FeishuChannel {
let message_id = payload_data.message.message_id.clone();
// Deduplication check
if self.is_message_seen(&message_id).await {
#[cfg(debug_assertions)]
tracing::debug!(message_id = %message_id, "Duplicate message, skipping");
return Ok(None);
}
#[cfg(debug_assertions)]
tracing::debug!(message_id = %message_id, "Received Feishu message");
let open_id = payload_data.sender.sender_id.open_id
.ok_or_else(|| ChannelError::Other("No open_id".to_string()))?;
let msg = payload_data.message;
let chat_id = msg.chat_id.clone();
let msg_type = msg.message_type.as_str();
let raw_content = msg.content.clone();
#[cfg(debug_assertions)]
tracing::debug!(msg_type = %msg_type, chat_id = %chat_id, open_id = %open_id, "Parsing message content");
let (content, media) = self.parse_and_download_message(msg_type, &raw_content, &message_id).await?;
#[cfg(debug_assertions)]
if let Some(ref m) = media {
tracing::debug!(media_type = %m.media_type, media_path = %m.path, "Media downloaded successfully");
}
Ok(Some(ParsedMessage {
message_id,
open_id,
chat_id,
content,
media,
}))
}
/// Parse message content and download media if needed
async fn parse_and_download_message(
&self,
msg_type: &str,
content: &str,
message_id: &str,
) -> Result<(String, Option<MediaItem>), ChannelError> {
match msg_type {
"text" => {
let text = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string()
} else {
content.to_string()
};
Ok((text, None))
}
"post" => {
let text = parse_post_content(content);
Ok((text, None))
}
"image" | "audio" | "file" | "media" => {
if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
self.download_media(msg_type, &content_json, message_id).await
} else {
Ok((format!("[{}: content unavailable]", msg_type), None))
}
}
_ => Ok((content.to_string(), None)),
}
}
/// Send acknowledgment for a message
async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> {
let mut ack = frame.clone();
@ -366,16 +874,20 @@ impl FeishuChannel {
let ping_interval = client_config.ping_interval.unwrap_or(120).max(10);
let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval));
let mut timeout_check = tokio::time::interval(tokio::time::Duration::from_secs(10));
let mut seq: u64 = 1;
let mut last_recv = Instant::now();
// Consume the immediate tick
ping_interval_tok.tick().await;
timeout_check.tick().await;
loop {
tokio::select! {
msg = read.next() => {
match msg {
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
last_recv = Instant::now();
let bytes: Bytes = data;
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
match self.handle_frame(&frame).await {
@ -385,12 +897,49 @@ impl FeishuChannel {
tracing::error!(error = %e, "Failed to send ACK to Feishu");
}
// Add reaction emoji (await so we get the reaction_id for later removal)
let message_id = parsed.message_id.clone();
let reaction_id = match self.add_reaction(&message_id).await {
Ok(Some(rid)) => Some(rid),
Ok(None) => None,
Err(e) => {
tracing::debug!(error = %e, message_id = %message_id, "Failed to add reaction");
None
}
};
// forwarded_metadata is copied to OutboundMessage.metadata by the gateway.
let mut forwarded_metadata = std::collections::HashMap::new();
forwarded_metadata.insert("feishu.message_id".to_string(), message_id.clone());
if let Some(ref rid) = reaction_id {
forwarded_metadata.insert("feishu.reaction_id".to_string(), rid.clone());
}
// Publish to bus asynchronously
let channel = self.clone();
let bus = bus.clone();
tokio::spawn(async move {
let media_count = if parsed.media.is_some() { 1 } else { 0 };
#[cfg(debug_assertions)]
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, content_len = %parsed.content.len(), media_count = %media_count, "Publishing message to bus");
let msg = crate::bus::InboundMessage {
channel: "feishu".to_string(),
sender_id: parsed.open_id.clone(),
chat_id: parsed.chat_id.clone(),
content: parsed.content.clone(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64,
media: parsed.media.map(|m| vec![m]).unwrap_or_default(),
metadata: std::collections::HashMap::new(),
forwarded_metadata,
};
if let Err(e) = channel.handle_and_publish(&bus, &msg).await {
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus");
} else {
#[cfg(debug_assertions)]
tracing::debug!(open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Message published to bus successfully");
}
});
}
@ -402,6 +951,7 @@ impl FeishuChannel {
}
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Ping(data))) => {
last_recv = Instant::now();
let pong = PbFrame {
seq_id: seq.wrapping_add(1),
log_id: 0,
@ -415,7 +965,11 @@ impl FeishuChannel {
};
let _ = write.send(tokio_tungstenite::tungstenite::Message::Binary(pong.encode_to_vec().into())).await;
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Pong(_))) => {
last_recv = Instant::now();
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
#[cfg(debug_assertions)]
tracing::debug!("Feishu WebSocket closed");
break;
}
@ -444,6 +998,16 @@ impl FeishuChannel {
break;
}
}
_ = timeout_check.tick() => {
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
tracing::warn!("Feishu WebSocket heartbeat timeout, reconnecting");
break;
}
// GC dedup cache: remove entries older than TTL (matches zeroclaw pattern)
let now = Instant::now();
let mut seen = self.seen_message_ids.write().await;
seen.retain(|_, ts| now.duration_since(*ts) < DEDUP_CACHE_TTL);
}
_ = shutdown_rx.recv() => {
tracing::info!("Feishu channel shutdown signal received");
break;
@ -456,16 +1020,7 @@ impl FeishuChannel {
}
}
fn parse_post_content(content: &str) -> String {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(content) {
let mut texts = vec![];
if let Some(post) = parsed.get("post") {
@ -492,9 +1047,6 @@ fn parse_message_content(msg_type: &str, content: &str) -> String {
content.to_string()
}
}
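The post-content flattening above walks nested arrays (lines of inline segments) and keeps only the text pieces. A stdlib-only sketch with plain nested vectors instead of `serde_json` values, assuming space-joined segments (the real separator is in the elided hunk):

```rust
/// Post content is a list of lines, each a list of inline segments; keep
/// only the text segments (`Some`) and join them. `None` stands in for a
/// non-text segment such as an @mention.
fn flatten_post(lines: &[Vec<Option<&str>>]) -> String {
    lines
        .iter()
        .flat_map(|line| line.iter().filter_map(|seg| *seg))
        .collect::<Vec<_>>()
        .join(" ")
}

fn main() {
    let post = vec![
        vec![Some("hello"), None, Some("world")],
        vec![Some("bye")],
    ];
    assert_eq!(flatten_post(&post), "hello world bye");
}
```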
#[async_trait]
impl Channel for FeishuChannel {
@ -574,6 +1126,119 @@ impl Channel for FeishuChannel {
let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { msg.reply_to.as_ref().unwrap_or(&msg.chat_id) };
let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
// If no media, send text only
if msg.media.is_empty() {
let result = self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await;
// Remove pending reaction after sending (using metadata propagated from inbound)
self.remove_reaction_from_metadata(&msg.metadata).await;
return result;
}
// Handle multimodal message - send with media
let token = self.get_tenant_access_token().await?;
// Build content with media references
let mut content_parts = Vec::new();
// Add text content if present (truncate if too long for Feishu)
if !msg.content.is_empty() {
const MAX_TEXT_LENGTH: usize = 60_000;
let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH {
// Back off to a char boundary so slicing cannot panic mid-UTF-8 character.
let mut end = MAX_TEXT_LENGTH;
while !msg.content.is_char_boundary(end) {
end -= 1;
}
format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..end])
} else {
msg.content.clone()
};
content_parts.push(serde_json::json!({
"tag": "text",
"text": truncated_text
}));
}
// Upload and add media
for media_item in &msg.media {
let path = &media_item.path;
match media_item.media_type.as_str() {
"image" => {
match self.upload_image(path).await {
Ok(image_key) => {
content_parts.push(serde_json::json!({
"tag": "image",
"image_key": image_key
}));
}
Err(e) => {
tracing::warn!(error = %e, path = %path, "Failed to upload image");
}
}
}
"audio" | "file" | "video" => {
match self.upload_file(path).await {
Ok(file_key) => {
content_parts.push(serde_json::json!({
"tag": "file",
"file_key": file_key
}));
}
Err(e) => {
tracing::warn!(error = %e, path = %path, "Failed to upload file");
}
}
}
_ => {
tracing::warn!(media_type = %media_item.media_type, "Unsupported media type for sending");
}
}
}
// If no content parts after processing, just send empty text
if content_parts.is_empty() {
let result = self.send_message_to_feishu(receive_id, receive_id_type, "").await;
// Remove pending reaction after sending (using metadata propagated from inbound)
self.remove_reaction_from_metadata(&msg.metadata).await;
return result;
}
// Determine message type
let has_image = msg.media.iter().any(|m| m.media_type == "image");
let msg_type = if has_image && msg.content.is_empty() {
"image"
} else {
"post"
};
let content = serde_json::json!({
"content": content_parts
}).to_string();
let resp = self.http_client
.post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.json(&serde_json::json!({
"receive_id": receive_id,
"msg_type": msg_type,
"content": content
}))
.send()
.await
.map_err(|e| ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)))?;
#[derive(Deserialize)]
struct SendResp {
code: i32,
msg: String,
}
let send_resp: SendResp = resp.json().await
.map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?;
if send_resp.code != 0 {
return Err(ChannelError::Other(format!("Send multimodal message failed: code={} msg={}", send_resp.code, send_resp.msg)));
}
// Remove pending reaction after successfully sending
self.remove_reaction_from_metadata(&msg.metadata).await;
Ok(())
}
}
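The 60,000 limit above is a byte budget (Feishu rejects ~64 KB payloads with error 230001), but slicing at a raw byte index panics inside a multi-byte UTF-8 character, the same bug the char-slicing fix in this branch addresses elsewhere. A boundary-safe sketch (hypothetical helper name):

```rust
/// Truncate `text` to at most `max_bytes` bytes without splitting a
/// multi-byte UTF-8 character; `&text[..max_bytes]` would panic mid-char.
fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> &str {
    if text.len() <= max_bytes {
        return text;
    }
    let mut end = max_bytes;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

fn main() {
    // ASCII: cut exactly at the byte budget.
    assert_eq!(truncate_at_char_boundary("hello", 4), "hell");
    // "你" is 3 bytes in UTF-8; a 4-byte budget must back off to 3 bytes.
    assert_eq!(truncate_at_char_boundary("你好", 4), "你");
    // Short text passes through untouched.
    assert_eq!(truncate_at_char_boundary("ok", 100), "ok");
}
```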

View File

@ -35,6 +35,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
input.write_output(&format!("Error: {}", message)).await?;
}
WsOutbound::SessionEstablished { session_id } => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id, "Session established");
input.write_output(&format!("Session: {}\n", session_id)).await?;
}

View File

@ -28,12 +28,26 @@ pub struct FeishuChannelConfig {
pub allow_from: Vec<String>,
#[serde(default)]
pub agent: String,
#[serde(default = "default_media_dir")]
pub media_dir: String,
/// Emoji type for message reactions (e.g. "THUMBSUP", "OK", "EYES").
#[serde(default = "default_reaction_emoji")]
pub reaction_emoji: String,
}
fn default_allow_from() -> Vec<String> {
vec!["*".to_string()]
}
fn default_media_dir() -> String {
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
home.join(".picobot/media/feishu").to_string_lossy().to_string()
}
fn default_reaction_emoji() -> String {
"Typing".to_string()
}
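A hedged example of how these serde defaults surface in a channel config (the `[channels.feishu]` table name and all values are illustrative; only the field names come from `FeishuChannelConfig` above):

```toml
[channels.feishu]
allow_from = ["*"]
agent = "default"
# Optional: falls back to ~/.picobot/media/feishu when omitted.
media_dir = "/var/lib/picobot/media/feishu"
# Optional: defaults to "Typing"; other Feishu emoji types include
# "THUMBSUP", "OK", "EYES".
reaction_emoji = "EYES"
```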
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProviderConfig {
#[serde(rename = "type")]
@ -59,6 +73,12 @@ pub struct ModelConfig {
pub struct AgentConfig {
pub provider: String,
pub model: String,
#[serde(default = "default_max_iterations")]
pub max_iterations: u32,
}
fn default_max_iterations() -> u32 {
15
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -118,6 +138,7 @@ pub struct LLMProviderConfig {
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub model_extra: HashMap<String, serde_json::Value>,
pub max_iterations: u32,
}
fn get_default_config_path() -> PathBuf {
@ -177,6 +198,7 @@ impl Config {
temperature: model.temperature,
max_tokens: model.max_tokens,
model_extra: model.extra.clone(),
max_iterations: agent.max_iterations,
})
}
}

View File

@ -53,11 +53,22 @@ impl GatewayState {
tracing::info!("Inbound processor started");
loop {
let inbound = bus_for_inbound.consume_inbound().await;
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
sender = %inbound.sender_id,
content = %inbound.content,
media_count = %inbound.media.len(),
"Processing inbound message"
);
if !inbound.media.is_empty() {
for (i, m) in inbound.media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
}
}
}
// Process via session manager
match session_manager.handle_message(
@ -65,15 +76,19 @@ impl GatewayState {
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
).await {
Ok(response_content) => {
// Forward channel-specific metadata from inbound to outbound.
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
// without gateway needing channel-specific code.
let outbound = crate::bus::OutboundMessage {
channel: inbound.channel.clone(),
chat_id: inbound.chat_id.clone(),
content: response_content,
reply_to: None,
media: vec![],
metadata: inbound.forwarded_metadata,
};
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
tracing::error!(error = %e, "Failed to publish outbound");
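The metadata forwarding above keeps the gateway channel-agnostic: it copies `forwarded_metadata` verbatim into the outbound message without inspecting keys like `feishu.message_id`. A sketch with hypothetical minimal shapes mirroring the bus types:

```rust
use std::collections::HashMap;

// Hypothetical minimal stand-ins for InboundMessage / OutboundMessage.
struct Inbound {
    forwarded_metadata: HashMap<String, String>,
}
struct Outbound {
    metadata: HashMap<String, String>,
}

/// The gateway copies channel-specific context verbatim; it never inspects
/// the keys, so channels stay decoupled from gateway code.
fn to_outbound(inbound: Inbound) -> Outbound {
    Outbound { metadata: inbound.forwarded_metadata }
}

fn main() {
    let mut meta = HashMap::new();
    meta.insert("feishu.message_id".to_string(), "om_abc".to_string());
    meta.insert("feishu.reaction_id".to_string(), "rx_123".to_string());
    let out = to_outbound(Inbound { forwarded_metadata: meta });
    // The channel later reads these keys back to remove its pending reaction.
    assert_eq!(out.metadata.get("feishu.message_id").map(String::as_str), Some("om_abc"));
    assert_eq!(out.metadata.get("feishu.reaction_id").map(String::as_str), Some("rx_123"));
}
```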

View File

@ -7,7 +7,10 @@ use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError};
use crate::protocol::WsOutbound;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, ToolRegistry, WebFetchTool,
};
/// Sessions are isolated per channel: one Session per channel.
/// Histories are isolated per chat_id and managed by the Session.
@ -56,6 +59,12 @@ impl Session {
history.push(ChatMessage::user(content));
}
/// Append a user message with media to the history for the given chat_id.
pub fn add_user_message_with_media(&mut self, chat_id: &str, content: &str, media_refs: Vec<String>) {
let history = self.get_or_create_history(chat_id);
history.push(ChatMessage::user_with_media(content, media_refs));
}
/// Append the assistant response to the history for the given chat_id.
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
@ -68,6 +77,7 @@ impl Session {
if let Some(history) = self.chat_histories.get_mut(chat_id) {
let len = history.len();
history.clear();
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
}
}
@ -76,6 +86,7 @@ impl Session {
pub fn clear_all_history(&mut self) {
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
self.chat_histories.clear();
#[cfg(debug_assertions)]
tracing::debug!(previous_total = total, "All chat histories cleared");
}
@ -106,6 +117,17 @@ struct SessionManagerInner {
fn default_tools() -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()], // allow all domains; restrict this in real deployments
1_000_000, // max_response_size
30, // timeout_secs
false, // allow_private_hosts
));
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
registry
}
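The name-keyed registry pattern `default_tools` relies on can be sketched with a hypothetical minimal `Registry` (the real `ToolRegistry` lives in `crate::tools` and its API is assumed here):

```rust
use std::collections::HashMap;

// Hypothetical minimal trait/registry shapes for illustration only.
trait Tool {
    fn name(&self) -> &str;
}

struct Calculator;
impl Tool for Calculator {
    fn name(&self) -> &str {
        "calculator"
    }
}

#[derive(Default)]
struct Registry {
    tools: HashMap<String, Box<dyn Tool>>,
}

impl Registry {
    /// Keyed by tool name; in this sketch, later registrations overwrite
    /// earlier ones because `HashMap::insert` replaces existing keys.
    fn register<T: Tool + 'static>(&mut self, tool: T) {
        self.tools.insert(tool.name().to_string(), Box::new(tool));
    }

    fn get(&self, name: &str) -> Option<&dyn Tool> {
        self.tools.get(name).map(|b| b.as_ref())
    }
}

fn main() {
    let mut reg = Registry::default();
    reg.register(Calculator);
    assert!(reg.get("calculator").is_some());
    assert!(reg.get("bash").is_none());
}
```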
@ -139,6 +161,7 @@ impl SessionManager {
false
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
@ -184,13 +207,21 @@ impl SessionManager {
_sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<crate::bus::MediaItem>,
) -> Result<String, AgentError> {
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content.len(),
media_count = %media.len(),
"Routing message to agent"
);
for (i, m) in media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
}
}
// Ensure the session exists (it may need to be recreated)
self.ensure_session(channel_name).await?;
@ -209,7 +240,14 @@ impl SessionManager {
let mut session_guard = session.lock().await;
// Add the user message to the history
if media.is_empty() {
session_guard.add_user_message(chat_id, content);
} else {
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
session_guard.add_user_message_with_media(chat_id, content, media_refs);
}
// Fetch the full history
let history = session_guard.get_or_create_history(chat_id).clone();
@ -224,6 +262,7 @@ impl SessionManager {
response
};
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,

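Commit a4399037ac in this series replaced byte-index slicing with char-based truncation, and 075b92f231 added a length cap for Feishu's content limit. A minimal standalone sketch of that pattern (the helper name and marker text are ours, not the repo's):

```rust
// Char-safe truncation sketch: byte slicing like &text[..n] panics when n
// lands inside a multi-byte UTF-8 character, so count and cut by chars.
fn truncate_chars(text: &str, max_chars: usize) -> String {
    if text.chars().count() <= max_chars {
        return text.to_string();
    }
    let mut out: String = text.chars().take(max_chars).collect();
    out.push_str("\n[content truncated]");
    out
}

fn main() {
    // "你好世界" is 4 chars but 12 bytes; &s[..2] would panic mid-codepoint.
    let s = "你好世界";
    assert_eq!(truncate_chars(s, 2), "你好\n[content truncated]");
    assert_eq!(truncate_chars("short", 100), "short");
    println!("ok");
}
```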

@ -62,6 +62,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) {
if ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id_for_sender, "WebSocket send error");
break;
}
@ -91,6 +92,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
}
}
Ok(WsMessage::Close(_)) | Err(_) => {
#[cfg(debug_assertions)]
tracing::debug!(session_id = %session_id, "WebSocket closed");
break;
}
@ -145,6 +147,7 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
match agent.process(history).await {
Ok(response) => {
#[cfg(debug_assertions)]
tracing::debug!(chat_id = %chat_id, "Agent response sent");
// Add the assistant response to the history
session_guard.add_assistant_message(&chat_id, response.clone());


@ -3,9 +3,55 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage;
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// Serialize as a JSON array of content blocks. Stringifying the array
// would make the API treat it as one literal text payload, breaking
// the image-block support this change introduces.
serializer.collect_seq(blocks)
}
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => {
convert_image_url_to_anthropic(&image_url.url)
}
}).collect()
}
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
// data:image/png;base64,... -> Anthropic image block
if let Some(caps) = regex::Regex::new(r"data:(image/\w+);base64,(.+)")
.ok()
.and_then(|re| re.captures(url))
{
let media_type = caps.get(1).map(|m| m.as_str()).unwrap_or("image/png");
let data = caps.get(2).map(|d| d.as_str()).unwrap_or("");
return serde_json::json!({
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data
}
});
}
// Regular URL -> Anthropic image block with url source
serde_json::json!({
"type": "image",
"source": {
"type": "url",
"url": url
}
})
}
pub struct AnthropicProvider {
client: Client,
name: String,
@ -58,7 +104,8 @@ struct AnthropicRequest {
#[derive(Serialize)]
struct AnthropicMessage {
role: String,
content: String,
#[serde(serialize_with = "serialize_content_blocks")]
content: Vec<serde_json::Value>,
}
#[derive(Serialize)]
@ -122,7 +169,7 @@ impl LLMProvider for AnthropicProvider {
.iter()
.map(|m| AnthropicMessage {
role: m.role.clone(),
content: m.content.clone(),
content: convert_content_blocks(&m.content),
})
.collect(),
max_tokens,

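The data-URL branch of `convert_image_url_to_anthropic` above compiles a regex on every call; the same split can be done with plain string ops. A regex-free sketch under the same assumptions (helper name is ours):

```rust
// Split "data:image/png;base64,AAAA" into (media type, base64 payload)
// without a regex; non-data URLs and non-image media types return None.
fn split_data_url(url: &str) -> Option<(&str, &str)> {
    let rest = url.strip_prefix("data:")?;
    let (media_type, data) = rest.split_once(";base64,")?;
    if media_type.starts_with("image/") {
        Some((media_type, data))
    } else {
        None
    }
}

fn main() {
    let (mt, data) = split_data_url("data:image/png;base64,iVBORw0K").unwrap();
    assert_eq!(mt, "image/png");
    assert_eq!(data, "iVBORw0K");
    // Regular URLs fall through to the url-source branch instead
    assert!(split_data_url("https://example.com/a.png").is_none());
    println!("ok");
}
```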

@ -1,12 +1,27 @@
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use serde_json::{json, Value};
use std::collections::HashMap;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 {
if let ContentBlock::Text { text } = &blocks[0] {
return Value::String(text.clone());
}
}
Value::Array(blocks.iter().map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
}).collect())
}
pub struct OpenAIProvider {
client: Client,
name: String,
@ -107,14 +122,14 @@ impl LLMProvider for OpenAIProvider {
if m.role == "tool" {
json!({
"role": m.role,
"content": m.content,
"content": convert_content_blocks(&m.content),
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else {
json!({
"role": m.role,
"content": m.content
"content": convert_content_blocks(&m.content)
})
}
}).collect::<Vec<_>>(),
@ -131,6 +146,30 @@ impl LLMProvider for OpenAIProvider {
body["tools"] = json!(tools);
}
// Debug: Log LLM request summary (only in debug builds)
#[cfg(debug_assertions)]
{
// Log messages summary
let msg_count = body["messages"].as_array().map(|a| a.len()).unwrap_or(0);
tracing::debug!(msg_count = msg_count, "LLM request messages count");
// Log only the first 20 chars of each image URL (never the full base64 payload)
if let Some(msgs) = body["messages"].as_array() {
for (i, msg) in msgs.iter().enumerate() {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 chars shown)");
}
}
}
}
}
}
}
let mut req_builder = self
.client
.post(&url)
@ -146,6 +185,13 @@ impl LLMProvider for OpenAIProvider {
let status = resp.status();
let text = resp.text().await?;
// Debug: Log LLM response (only in debug builds)
#[cfg(debug_assertions)]
{
let resp_preview: String = text.chars().take(100).collect();
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
}
if !status.is_success() {
return Err(format!("API error {}: {}", status, text).into());
}

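The collapse rule in the OpenAI provider's `convert_content_blocks` (a lone text block becomes a plain string; anything else becomes an array of typed blocks) can be isolated in a small sketch. The enum names here are illustrative stand-ins for the repo's `ContentBlock` and `serde_json::Value` types:

```rust
// Sketch of the single-text-block collapse: most OpenAI-compatible servers
// expect plain-string content for text-only messages.
#[derive(Debug, PartialEq)]
enum Block {
    Text(String),
    ImageUrl(String),
}

#[derive(Debug, PartialEq)]
enum Content {
    Plain(String),       // "content": "hi"
    Blocks(Vec<Block>),  // "content": [{...}, {...}]
}

fn convert(blocks: Vec<Block>) -> Content {
    if blocks.len() == 1 {
        if let Block::Text(t) = &blocks[0] {
            return Content::Plain(t.clone());
        }
    }
    Content::Blocks(blocks)
}

fn main() {
    assert_eq!(convert(vec![Block::Text("hi".into())]), Content::Plain("hi".into()));
    let mixed = convert(vec![Block::Text("a".into()), Block::ImageUrl("u".into())]);
    assert!(matches!(mixed, Content::Blocks(_)));
    println!("ok");
}
```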

@ -1,16 +1,64 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
pub content: Vec<ContentBlock>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl Message {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: "user".to_string(),
content: vec![ContentBlock::text(content)],
tool_call_id: None,
name: None,
}
}
pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
Self {
role: "user".to_string(),
content,
tool_call_id: None,
name: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: "assistant".to_string(),
content: vec![ContentBlock::text(content)],
tool_call_id: None,
name: None,
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
role: "system".to_string(),
content: vec![ContentBlock::text(content)],
tool_call_id: None,
name: None,
}
}
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: "tool".to_string(),
content: vec![ContentBlock::text(content)],
tool_call_id: Some(tool_call_id.into()),
name: Some(tool_name.into()),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]

src/tools/bash.rs (new file, 315 lines)

@ -0,0 +1,315 @@
use std::path::Path;
use std::process::Stdio;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::json;
use tokio::io::AsyncReadExt;
use tokio::process::Command;
use tokio::time::timeout;
use crate::tools::traits::{Tool, ToolResult};
const MAX_TIMEOUT_SECS: u64 = 600;
const MAX_OUTPUT_CHARS: usize = 50_000;
pub struct BashTool {
timeout_secs: u64,
working_dir: Option<String>,
deny_patterns: Vec<String>,
}
impl BashTool {
pub fn new() -> Self {
Self {
timeout_secs: 60,
working_dir: None,
deny_patterns: vec![
r"\brm\s+-[rf]{1,2}\b".to_string(),
r"\bdel\s+/[fq]\b".to_string(),
r"\brmdir\s+/s\b".to_string(),
r":\(\)\s*\{.*\};\s*:".to_string(),
],
}
}
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
self.timeout_secs = timeout_secs;
self
}
pub fn with_working_dir(mut self, dir: String) -> Self {
self.working_dir = Some(dir);
self
}
fn guard_command(&self, command: &str) -> Option<String> {
let lower = command.to_lowercase();
for pattern in &self.deny_patterns {
if regex::Regex::new(pattern)
.ok()
.map(|re| re.is_match(&lower))
.unwrap_or(false)
{
return Some(format!(
"Command blocked by safety guard (dangerous pattern: {})",
pattern
));
}
}
None
}
fn truncate_output(&self, output: &str) -> String {
let char_count = output.chars().count();
if char_count <= MAX_OUTPUT_CHARS {
return output.to_string();
}
// Truncate by chars, not bytes: byte indexing like &output[..half] can
// panic inside a multi-byte UTF-8 character.
let half = MAX_OUTPUT_CHARS / 2;
let head: String = output.chars().take(half).collect();
let tail: String = output.chars().skip(char_count - half).collect();
format!(
"{}...\n\n(... {} chars truncated ...)\n\n{}",
head,
char_count - MAX_OUTPUT_CHARS,
tail
)
}
}
impl Default for BashTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for BashTool {
fn name(&self) -> &str {
"bash"
}
fn description(&self) -> &str {
"Execute a bash shell command and return its output. Use with caution."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
},
"timeout": {
"type": "integer",
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
"minimum": 1,
"maximum": MAX_TIMEOUT_SECS
}
},
"required": ["command"]
})
}
fn exclusive(&self) -> bool {
true // Shell commands should not run concurrently
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let command = match args.get("command").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: command".to_string()),
});
}
};
// Safety check
if let Some(error) = self.guard_command(command) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let timeout_secs = args
.get("timeout")
.and_then(|v| v.as_u64())
.unwrap_or(self.timeout_secs)
.min(MAX_TIMEOUT_SECS);
let cwd = self
.working_dir
.as_ref()
.map(|d| Path::new(d))
.unwrap_or_else(|| Path::new("."));
let result = timeout(
Duration::from_secs(timeout_secs),
self.run_command(command, cwd),
)
.await;
match result {
Ok(Ok(output)) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Ok(Err(e)) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Command timed out after {} seconds",
timeout_secs
)),
}),
}
}
}
impl BashTool {
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
let mut cmd = Command::new("bash");
cmd.args(["-c", command])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(cwd);
let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
// wait_with_output drains stdout and stderr concurrently while waiting,
// avoiding the deadlock risk of reading the two pipes one after the other
// (the process can block writing to whichever pipe is not being read).
let out = child
.wait_with_output()
.await
.map_err(|e| format!("Failed to run command: {}", e))?;
let stdout = out.stdout;
let stderr = out.stderr;
let status = out.status;
let mut output = String::new();
if !stdout.is_empty() {
let stdout_str = String::from_utf8_lossy(&stdout);
output.push_str(&stdout_str);
}
if !stderr.is_empty() {
let stderr_str = String::from_utf8_lossy(&stderr);
if !stderr_str.trim().is_empty() {
if !output.is_empty() {
output.push_str("\n");
}
output.push_str("STDERR:\n");
output.push_str(&stderr_str);
}
}
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
Ok(self.truncate_output(&output))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_simple_command() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": "echo 'Hello World'" }))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Hello World"));
}
#[tokio::test]
async fn test_pwd_command() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": "pwd" }))
.await
.unwrap();
assert!(result.success);
}
#[tokio::test]
async fn test_ls_command() {
let tool = BashTool::new();
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
assert!(result.success);
}
#[tokio::test]
async fn test_dangerous_rm() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": "rm -rf /" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("blocked"));
}
#[tokio::test]
async fn test_dangerous_fork_bomb() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": ":(){ :|:& };:" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("blocked"));
}
#[tokio::test]
async fn test_missing_command() {
let tool = BashTool::new();
let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("command"));
}
#[tokio::test]
async fn test_timeout() {
let tool = BashTool::new();
let result = tool
.execute(json!({
"command": "sleep 10",
"timeout": 1
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("timed out"));
}
}


@ -92,6 +92,10 @@ impl Tool for CalculatorTool {
})
}
fn read_only(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let function = match args.get("function").and_then(|v| v.as_str()) {
Some(f) => f,

src/tools/file_edit.rs (new file, 381 lines)

@ -0,0 +1,381 @@
use std::path::Path;
use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct FileEditTool {
allowed_dir: Option<String>,
}
impl FileEditTool {
pub fn new() -> Self {
Self { allowed_dir: None }
}
pub fn with_allowed_dir(dir: String) -> Self {
Self {
allowed_dir: Some(dir),
}
}
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
let p = Path::new(path);
let resolved = if p.is_absolute() {
p.to_path_buf()
} else {
std::env::current_dir()
.map_err(|e| format!("Failed to get current directory: {}", e))?
.join(p)
};
// Check directory restriction
if let Some(ref allowed) = self.allowed_dir {
let allowed_path = Path::new(allowed);
if !resolved.starts_with(allowed_path) {
return Err(format!(
"Path '{}' is outside allowed directory '{}'",
path, allowed
));
}
}
Ok(resolved)
}
fn find_match(&self, content: &str, old_text: &str) -> Option<(String, usize)> {
// Try exact match first
if content.contains(old_text) {
let count = content.matches(old_text).count();
return Some((old_text.to_string(), count));
}
// Try line-based matching for minor differences
let old_lines: Vec<&str> = old_text.lines().collect();
if old_lines.is_empty() {
return None;
}
let content_lines: Vec<&str> = content.lines().collect();
if content_lines.len() < old_lines.len() {
return None;
}
// Inclusive bound so a match at the final window position is not missed.
for i in 0..=content_lines.len() - old_lines.len() {
let window = &content_lines[i..i + old_lines.len()];
let stripped_old: Vec<&str> = old_lines.iter().map(|l| l.trim()).collect();
let stripped_window: Vec<&str> = window.iter().map(|l| l.trim()).collect();
if stripped_old == stripped_window {
let matched_text = window.join("\n");
return Some((matched_text, 1));
}
}
None
}
}
impl Default for FileEditTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for FileEditTool {
fn name(&self) -> &str {
"file_edit"
}
fn description(&self) -> &str {
"Edit a file by replacing old_text with new_text. Supports minor whitespace differences. Set replace_all=true to replace every occurrence."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to edit"
},
"old_text": {
"type": "string",
"description": "The text to find and replace"
},
"new_text": {
"type": "string",
"description": "The text to replace with"
},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default false)",
"default": false
}
},
"required": ["path", "old_text", "new_text"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path = match args.get("path").and_then(|v| v.as_str()) {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: path".to_string()),
});
}
};
let old_text = match args.get("old_text").and_then(|v| v.as_str()) {
Some(t) => t,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: old_text".to_string()),
});
}
};
let new_text = match args.get("new_text").and_then(|v| v.as_str()) {
Some(t) => t,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: new_text".to_string()),
});
}
};
let replace_all = args
.get("replace_all")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let resolved = match self.resolve_path(path) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
if !resolved.exists() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("File not found: {}", path)),
});
}
if !resolved.is_file() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Not a file: {}", path)),
});
}
// Read file content
let content = match std::fs::read_to_string(&resolved) {
Ok(c) => c,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to read file: {}", e)),
});
}
};
// Detect line ending style
let uses_crlf = content.contains("\r\n");
let norm_content = content.replace("\r\n", "\n");
let norm_old = old_text.replace("\r\n", "\n");
// Find match
let (matched_text, count) = match self.find_match(&norm_content, &norm_old) {
Some(m) => m,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"old_text not found in {}. Verify the file content.",
path
)),
});
}
};
// Warn if multiple matches but replace_all is false
if count > 1 && !replace_all {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"old_text appears {} times. Provide more context to make it unique, or set replace_all=true.",
count
)),
});
}
// Perform replacement
let norm_new = new_text.replace("\r\n", "\n");
let new_content = if replace_all {
norm_content.replace(&matched_text, &norm_new)
} else {
norm_content.replacen(&matched_text, &norm_new, 1)
};
// Restore line endings if needed
let final_content = if uses_crlf {
new_content.replace("\n", "\r\n")
} else {
new_content
};
// Write back
match std::fs::write(&resolved, &final_content) {
Ok(_) => {
let replacements = if replace_all { count } else { 1 };
Ok(ToolResult {
success: true,
output: format!(
"Successfully edited {} ({} replacement{} made)",
resolved.display(),
replacements,
if replacements == 1 { "" } else { "s" }
),
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to write file: {}", e)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use std::io::Write;
#[tokio::test]
async fn test_edit_simple() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "Hello World").unwrap();
writeln!(file, "Test line").unwrap();
let tool = FileEditTool::new();
let result = tool
.execute(json!({
"path": file.path().to_str().unwrap(),
"old_text": "Hello World",
"new_text": "Hello Universe"
}))
.await
.unwrap();
assert!(result.success);
let content = std::fs::read_to_string(file.path()).unwrap();
assert!(content.contains("Hello Universe"));
assert!(!content.contains("Hello World"));
}
#[tokio::test]
async fn test_edit_replace_all() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "foo bar").unwrap();
writeln!(file, "foo baz").unwrap();
writeln!(file, "foo qux").unwrap();
let tool = FileEditTool::new();
let result = tool
.execute(json!({
"path": file.path().to_str().unwrap(),
"old_text": "foo",
"new_text": "bar",
"replace_all": true
}))
.await
.unwrap();
assert!(result.success);
let content = std::fs::read_to_string(file.path()).unwrap();
assert!(content.contains("bar bar"));
assert!(content.contains("bar baz"));
assert!(content.contains("bar qux"));
}
#[tokio::test]
async fn test_edit_file_not_found() {
let tool = FileEditTool::new();
let result = tool
.execute(json!({
"path": "/nonexistent/file.txt",
"old_text": "old",
"new_text": "new"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
}
#[tokio::test]
async fn test_edit_old_text_not_found() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "Hello World").unwrap();
let tool = FileEditTool::new();
let result = tool
.execute(json!({
"path": file.path().to_str().unwrap(),
"old_text": "NonExistent",
"new_text": "New"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
}
#[tokio::test]
async fn test_edit_multiline() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "Line 1").unwrap();
writeln!(file, "Line 2").unwrap();
writeln!(file, "Line 3").unwrap();
let tool = FileEditTool::new();
let result = tool
.execute(json!({
"path": file.path().to_str().unwrap(),
"old_text": "Line 1\nLine 2",
"new_text": "New Line 1\nNew Line 2"
}))
.await
.unwrap();
assert!(result.success);
let content = std::fs::read_to_string(file.path()).unwrap();
assert!(content.contains("New Line 1"));
assert!(content.contains("New Line 2"));
}
}
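The whitespace-tolerant fallback in `FileEditTool::find_match` (exact match first, then trimmed line-by-line window comparison) can be sketched standalone; the function name below is illustrative:

```rust
// Find old_text in content: exact match first, then a fallback that compares
// windows of lines after trimming leading/trailing whitespace, returning the
// text as it actually appears in the file.
fn fuzzy_find(content: &str, old_text: &str) -> Option<String> {
    if content.contains(old_text) {
        return Some(old_text.to_string());
    }
    let old: Vec<&str> = old_text.lines().map(str::trim).collect();
    let lines: Vec<&str> = content.lines().collect();
    if old.is_empty() || lines.len() < old.len() {
        return None;
    }
    for i in 0..=lines.len() - old.len() {
        let window = &lines[i..i + old.len()];
        if window.iter().map(|l| l.trim()).eq(old.iter().copied()) {
            return Some(window.join("\n"));
        }
    }
    None
}

fn main() {
    let content = "fn main() {\n    println!(\"hi\");\n}";
    // old_text uses different indentation but matches after trimming
    let m = fuzzy_find(content, "fn main() {\nprintln!(\"hi\");").unwrap();
    assert_eq!(m, "fn main() {\n    println!(\"hi\");");
    assert!(fuzzy_find(content, "nope").is_none());
    println!("ok");
}
```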

src/tools/file_read.rs (new file, 321 lines)

@ -0,0 +1,321 @@
use std::io::Read;
use std::path::Path;
use async_trait::async_trait;
use serde_json::json;
use crate::bus::message::ContentBlock;
use crate::tools::traits::{Tool, ToolResult};
const MAX_CHARS: usize = 128_000;
const DEFAULT_LIMIT: usize = 2000;
pub struct FileReadTool {
allowed_dir: Option<String>,
}
impl FileReadTool {
pub fn new() -> Self {
Self { allowed_dir: None }
}
pub fn with_allowed_dir(dir: String) -> Self {
Self {
allowed_dir: Some(dir),
}
}
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
let p = Path::new(path);
let resolved = if p.is_absolute() {
p.to_path_buf()
} else {
std::env::current_dir()
.map_err(|e| format!("Failed to get current directory: {}", e))?
.join(p)
};
// Check directory restriction
if let Some(ref allowed) = self.allowed_dir {
let allowed_path = Path::new(allowed);
if !resolved.starts_with(allowed_path) {
return Err(format!(
"Path '{}' is outside allowed directory '{}'",
path, allowed
));
}
}
Ok(resolved)
}
}
impl Default for FileReadTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for FileReadTool {
fn name(&self) -> &str {
"file_read"
}
fn description(&self) -> &str {
"Read the contents of a file. Returns numbered lines. Use offset and limit to paginate through large files."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to read"
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, default 1)",
"minimum": 1
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (default 2000)",
"minimum": 1
}
},
"required": ["path"]
})
}
fn read_only(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path = match args.get("path").and_then(|v| v.as_str()) {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: path".to_string()),
});
}
};
let offset = args
.get("offset")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
.unwrap_or(1);
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
.unwrap_or(DEFAULT_LIMIT);
let resolved = match self.resolve_path(path) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
if !resolved.exists() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("File not found: {}", path)),
});
}
if !resolved.is_file() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Not a file: {}", path)),
});
}
// Try to read as text
match std::fs::read_to_string(&resolved) {
Ok(content) => {
let all_lines: Vec<&str> = content.lines().collect();
let total = all_lines.len();
if offset < 1 {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("offset must be at least 1, got {}", offset)),
});
}
if offset > total {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"offset {} is beyond end of file ({} lines)",
offset, total
)),
});
}
let start = offset - 1;
let end = std::cmp::min(start + limit, total);
let lines: Vec<String> = all_lines[start..end]
.iter()
.enumerate()
.map(|(i, line)| format!("{}| {}", start + i + 1, line))
.collect();
let mut result = lines.join("\n");
// Truncate if too long
if result.len() > MAX_CHARS {
let original_len = result.len();
let mut truncated_chars = 0;
let mut end_idx = 0;
for (i, line) in lines.iter().enumerate() {
truncated_chars += line.len() + 1;
if truncated_chars > MAX_CHARS {
end_idx = i;
break;
}
end_idx = i + 1;
}
result = lines[..end_idx].join("\n");
// Report the cut relative to the pre-truncation length; computing
// `result.len() - MAX_CHARS` after reassigning `result` would
// underflow, since the truncated result is shorter than MAX_CHARS.
result.push_str(&format!(
"\n\n... ({} chars truncated) ...",
original_len - result.len()
));
}
if end < total {
result.push_str(&format!(
"\n\n(Showing lines {}-{} of {}. Use offset={} to continue.)",
offset,
end,
total,
end + 1
));
} else {
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
}
Ok(ToolResult {
success: true,
output: result,
error: None,
})
}
Err(e) => {
// Try to read as binary and encode as base64
match std::fs::read(&resolved) {
Ok(bytes) => {
use base64::{engine::general_purpose::STANDARD, Engine};
let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
.to_string();
Ok(ToolResult {
success: true,
output: format!(
"(Binary file: {}, {} bytes, base64 encoded)\n{}",
mime,
bytes.len(),
encoded
),
error: None,
})
}
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to read file: {}", e)),
}),
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use std::io::Write;
#[tokio::test]
async fn test_read_simple_file() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "Line 1").unwrap();
writeln!(file, "Line 2").unwrap();
writeln!(file, "Line 3").unwrap();
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": file.path().to_str().unwrap() }))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Line 1"));
assert!(result.output.contains("Line 2"));
assert!(result.output.contains("Line 3"));
}
#[tokio::test]
async fn test_read_with_offset_limit() {
let mut file = NamedTempFile::new().unwrap();
for i in 1..=10 {
writeln!(file, "Line {}", i).unwrap();
}
let tool = FileReadTool::new();
let result = tool
.execute(json!({
"path": file.path().to_str().unwrap(),
"offset": 3,
"limit": 2
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Line 3"));
assert!(result.output.contains("Line 4"));
assert!(!result.output.contains("Line 2"));
}
#[tokio::test]
async fn test_file_not_found() {
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": "/nonexistent/file.txt" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("not found"));
}
#[tokio::test]
async fn test_is_directory() {
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": "." }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file"));
}
}
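`FileReadTool`'s pagination contract (1-indexed `offset`, numbered output lines, continuation hint when more lines remain) can be reduced to a small sketch; the function name is ours:

```rust
// Paginate text into numbered lines: offset is 1-indexed, limit caps the
// line count, and a continuation hint is appended when lines remain.
fn paginate(text: &str, offset: usize, limit: usize) -> Result<String, String> {
    let all: Vec<&str> = text.lines().collect();
    let total = all.len();
    if offset < 1 || offset > total {
        return Err(format!("offset {} out of range (1..={})", offset, total));
    }
    let start = offset - 1;
    let end = (start + limit).min(total);
    let mut out: Vec<String> = all[start..end]
        .iter()
        .enumerate()
        .map(|(i, line)| format!("{}| {}", start + i + 1, line))
        .collect();
    if end < total {
        out.push(format!(
            "(Showing lines {}-{} of {}. Use offset={} to continue.)",
            offset, end, total, end + 1
        ));
    }
    Ok(out.join("\n"))
}

fn main() {
    let text = "a\nb\nc\nd";
    let page = paginate(text, 2, 2).unwrap();
    assert!(page.contains("2| b"));
    assert!(page.contains("3| c"));
    assert!(page.contains("offset=4"));
    assert!(paginate(text, 9, 2).is_err());
    println!("ok");
}
```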

src/tools/file_write.rs (new file, 242 lines)

@ -0,0 +1,242 @@
use std::path::Path;
use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct FileWriteTool {
allowed_dir: Option<String>,
}
impl FileWriteTool {
pub fn new() -> Self {
Self { allowed_dir: None }
}
pub fn with_allowed_dir(dir: String) -> Self {
Self {
allowed_dir: Some(dir),
}
}
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
let p = Path::new(path);
let resolved = if p.is_absolute() {
p.to_path_buf()
} else {
std::env::current_dir()
.map_err(|e| format!("Failed to get current directory: {}", e))?
.join(p)
};
// Check directory restriction
if let Some(ref allowed) = self.allowed_dir {
let allowed_path = Path::new(allowed);
if !resolved.starts_with(allowed_path) {
return Err(format!(
"Path '{}' is outside allowed directory '{}'",
path, allowed
));
}
}
Ok(resolved)
}
}
impl Default for FileWriteTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for FileWriteTool {
fn name(&self) -> &str {
"file_write"
}
fn description(&self) -> &str {
"Write content to a file at the given path. Creates parent directories if needed."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The file path to write to"
},
"content": {
"type": "string",
"description": "The content to write"
}
},
"required": ["path", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path = match args.get("path").and_then(|v| v.as_str()) {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: path".to_string()),
});
}
};
let content = match args.get("content").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: content".to_string()),
});
}
};
let resolved = match self.resolve_path(path) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
// Create parent directories if needed
if let Some(parent) = resolved.parent() {
if !parent.exists() {
if let Err(e) = std::fs::create_dir_all(parent) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create parent directory: {}", e)),
});
}
}
}
match std::fs::write(&resolved, content) {
Ok(_) => Ok(ToolResult {
success: true,
output: format!(
"Successfully wrote {} bytes to {}",
content.len(),
resolved.display()
),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to write file: {}", e)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_write_simple_file() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.txt");
let tool = FileWriteTool::new();
let result = tool
.execute(json!({
"path": file_path.to_str().unwrap(),
"content": "Hello, World!"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Successfully wrote"));
// Verify content
let read_content = std::fs::read_to_string(&file_path).unwrap();
assert_eq!(read_content, "Hello, World!");
}
#[tokio::test]
async fn test_write_creates_parent_dirs() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("subdir1/subdir2/test.txt");
let tool = FileWriteTool::new();
let result = tool
.execute(json!({
"path": file_path.to_str().unwrap(),
"content": "Nested content"
}))
.await
.unwrap();
assert!(result.success);
// Verify content
let read_content = std::fs::read_to_string(&file_path).unwrap();
assert_eq!(read_content, "Nested content");
}
#[tokio::test]
async fn test_write_missing_path() {
let tool = FileWriteTool::new();
let result = tool
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("path"));
}
#[tokio::test]
async fn test_write_missing_content() {
let tool = FileWriteTool::new();
let result = tool
.execute(json!({ "path": "/tmp/test.txt" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("content"));
}
#[tokio::test]
async fn test_overwrite_file() {
let temp_dir = TempDir::new().unwrap();
let file_path = temp_dir.path().join("test.txt");
// Write initial content
std::fs::write(&file_path, "Initial content").unwrap();
let tool = FileWriteTool::new();
let result = tool
.execute(json!({
"path": file_path.to_str().unwrap(),
"content": "New content"
}))
.await
.unwrap();
assert!(result.success);
// Verify overwritten
let read_content = std::fs::read_to_string(&file_path).unwrap();
assert_eq!(read_content, "New content");
}
}

src/tools/http_request.rs (new file, 444 lines)
@@ -0,0 +1,444 @@
use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct HttpRequestTool {
allowed_domains: Vec<String>,
max_response_size: usize,
timeout_secs: u64,
allow_private_hosts: bool,
}
impl HttpRequestTool {
pub fn new(
allowed_domains: Vec<String>,
max_response_size: usize,
timeout_secs: u64,
allow_private_hosts: bool,
) -> Self {
Self {
allowed_domains: normalize_domains(allowed_domains),
max_response_size,
timeout_secs,
allow_private_hosts,
}
}
fn validate_url(&self, url: &str) -> Result<String, String> {
let url = url.trim();
if url.is_empty() {
return Err("URL cannot be empty".to_string());
}
if url.chars().any(char::is_whitespace) {
return Err("URL cannot contain whitespace".to_string());
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err("Only http:// and https:// URLs are allowed".to_string());
}
let host = extract_host(url)?;
if !self.allow_private_hosts && is_private_host(&host) {
return Err(format!("Blocked local/private host: {}", host));
}
if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!(
"Host '{}' is not in allowed_domains",
host
));
}
Ok(url.to_string())
}
fn validate_method(&self, method: &str) -> Result<reqwest::Method, String> {
match method.to_uppercase().as_str() {
"GET" => Ok(reqwest::Method::GET),
"POST" => Ok(reqwest::Method::POST),
"PUT" => Ok(reqwest::Method::PUT),
"DELETE" => Ok(reqwest::Method::DELETE),
"PATCH" => Ok(reqwest::Method::PATCH),
_ => Err(format!(
"Unsupported HTTP method: {}. Supported: GET, POST, PUT, DELETE, PATCH",
method
)),
}
}
fn parse_headers(&self, headers: &serde_json::Value) -> HeaderMap {
let mut header_map = HeaderMap::new();
if let Some(obj) = headers.as_object() {
for (key, value) in obj {
if let Some(str_val) = value.as_str() {
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
if let Ok(val) =
reqwest::header::HeaderValue::from_str(str_val)
{
header_map.insert(name, val);
}
}
}
}
}
header_map
}
fn truncate_response(&self, text: &str) -> String {
if self.max_response_size == 0 {
return text.to_string();
}
if text.len() > self.max_response_size {
// Back up to a char boundary so slicing cannot panic inside a
// multi-byte UTF-8 character (see commit a4399037ac for the same bug).
let mut end = self.max_response_size;
while !text.is_char_boundary(end) {
end -= 1;
}
format!(
"{}\n\n... [Response truncated due to size limit] ...",
&text[..end]
)
} else {
text.to_string()
}
}
}
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
let mut normalized: Vec<String> = domains
.into_iter()
.filter_map(|d| normalize_domain(&d))
.collect();
normalized.sort_unstable();
normalized.dedup();
normalized
}
fn normalize_domain(raw: &str) -> Option<String> {
let mut d = raw.trim().to_lowercase();
if d.is_empty() {
return None;
}
if let Some(stripped) = d.strip_prefix("https://") {
d = stripped.to_string();
} else if let Some(stripped) = d.strip_prefix("http://") {
d = stripped.to_string();
}
if let Some((host, _)) = d.split_once('/') {
d = host.to_string();
}
d = d.trim_start_matches('.').trim_end_matches('.').to_string();
if let Some((host, _)) = d.split_once(':') {
d = host.to_string();
}
if d.is_empty() || d.chars().any(char::is_whitespace) {
return None;
}
Some(d)
}
fn extract_host(url: &str) -> Result<String, String> {
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if authority.is_empty() {
return Err("URL must include a host".to_string());
}
if authority.contains('@') {
return Err("URL userinfo is not allowed".to_string());
}
if authority.starts_with('[') {
return Err("IPv6 hosts are not supported".to_string());
}
let host = authority
.split(':')
.next()
.unwrap_or_default()
.trim()
.trim_end_matches('.')
.to_lowercase();
if host.is_empty() {
return Err("URL must include a valid host".to_string());
}
Ok(host)
}
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
if allowed_domains.iter().any(|domain| domain == "*") {
return true;
}
allowed_domains.iter().any(|domain| {
host == domain
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
})
}
fn is_private_host(host: &str) -> bool {
// Check localhost
if host == "localhost" || host.ends_with(".localhost") {
return true;
}
// Check .local TLD
if host.rsplit('.').next().is_some_and(|label| label == "local") {
return true;
}
// Try to parse as IP
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return is_private_ip(&ip);
}
false
}
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback()
|| v4.is_private()
|| v4.is_link_local()
|| v4.is_unspecified()
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
}
}
#[async_trait]
impl Tool for HttpRequestTool {
fn name(&self) -> &str {
"http_request"
}
fn description(&self) -> &str {
"Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH methods. Security: domain allowlist, no local/private hosts."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "HTTP or HTTPS URL to request"
},
"method": {
"type": "string",
"description": "HTTP method (GET, POST, PUT, DELETE, PATCH)",
"default": "GET"
},
"headers": {
"type": "object",
"description": "Optional HTTP headers as key-value pairs"
},
"body": {
"type": "string",
"description": "Optional request body"
}
},
"required": ["url"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let url = match args.get("url").and_then(|v| v.as_str()) {
Some(u) => u,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: url".to_string()),
});
}
};
let method_str = args
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str());
let url = match self.validate_url(url) {
Ok(u) => u,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
let method = match self.validate_method(method_str) {
Ok(m) => m,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
let headers = self.parse_headers(&headers_val);
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs))
.build()
{
Ok(c) => c,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create HTTP client: {}", e)),
});
}
};
let mut request = client.request(method, &url).headers(headers);
if let Some(body_str) = body {
request = request.body(body_str.to_string());
}
match request.send().await {
Ok(response) => {
let status = response.status();
let status_code = status.as_u16();
let response_text = response
.text()
.await
.map(|t| self.truncate_response(&t))
.unwrap_or_else(|_| "[Failed to read response body]".to_string());
let output = format!(
"Status: {} {}\n\nResponse Body:\n{}",
status_code,
status.canonical_reason().unwrap_or("Unknown"),
response_text
);
Ok(ToolResult {
success: status.is_success(),
output,
error: if status.is_client_error() || status.is_server_error() {
Some(format!("HTTP {}", status_code))
} else {
None
},
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("HTTP request failed: {}", e)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_tool(domains: Vec<&str>) -> HttpRequestTool {
HttpRequestTool::new(
domains.into_iter().map(String::from).collect(),
1_000_000,
30,
false,
)
}
#[tokio::test]
async fn test_validate_url_success() {
let tool = test_tool(vec!["example.com"]);
let result = tool.validate_url("https://example.com/docs");
assert!(result.is_ok());
}
#[tokio::test]
async fn test_validate_url_rejects_private() {
let tool = test_tool(vec!["example.com"]);
let result = tool.validate_url("https://localhost:8080");
assert!(result.is_err());
assert!(result.unwrap_err().contains("local/private"));
}
#[tokio::test]
async fn test_validate_url_rejects_whitespace() {
let tool = test_tool(vec!["example.com"]);
let result = tool.validate_url("https://example.com/hello world");
assert!(result.is_err());
assert!(result.unwrap_err().contains("whitespace"));
}
#[tokio::test]
async fn test_validate_url_requires_allowlist() {
let tool = HttpRequestTool::new(vec![], 1_000_000, 30, false);
let result = tool.validate_url("https://example.com");
assert!(result.is_err());
assert!(result.unwrap_err().contains("allowed_domains"));
}
#[tokio::test]
async fn test_validate_method() {
let tool = test_tool(vec!["example.com"]);
assert!(tool.validate_method("GET").is_ok());
assert!(tool.validate_method("POST").is_ok());
assert!(tool.validate_method("PUT").is_ok());
assert!(tool.validate_method("DELETE").is_ok());
assert!(tool.validate_method("PATCH").is_ok());
assert!(tool.validate_method("INVALID").is_err());
}
#[tokio::test]
async fn test_blocks_loopback() {
assert!(is_private_host("127.0.0.1"));
assert!(is_private_host("localhost"));
}
#[tokio::test]
async fn test_blocks_private_ranges() {
assert!(is_private_host("10.0.0.1"));
assert!(is_private_host("172.16.0.1"));
assert!(is_private_host("192.168.1.1"));
}
#[tokio::test]
async fn test_blocks_local_tld() {
assert!(is_private_host("service.local"));
}
}

@@ -1,7 +1,21 @@
pub mod bash;
pub mod calculator;
pub mod file_edit;
pub mod file_read;
pub mod file_write;
pub mod http_request;
pub mod registry;
pub mod schema;
pub mod traits;
pub mod web_fetch;
pub use bash::BashTool;
pub use calculator::CalculatorTool;
pub use file_edit::FileEditTool;
pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
pub use http_request::HttpRequestTool;
pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use traits::{Tool, ToolResult};
pub use web_fetch::WebFetchTool;

src/tools/schema.rs (new file, 721 lines)
@@ -0,0 +1,721 @@
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
//!
//! Different providers support different subsets of JSON Schema. This module
//! normalizes tool schemas to improve cross-provider compatibility while
//! preserving semantic intent.
//!
//! ## What this module does
//!
//! 1. Removes unsupported keywords per provider strategy
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
//! 4. Strips nullable variants from unions and `type` arrays
//! 5. Converts `const` to single-value `enum`
//! 6. Detects circular references and stops recursion safely
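//!
//! ## Example
//!
//! Illustrative only (fenced as `ignore`; the import path assumes this
//! crate's module layout):
//!
//! ```ignore
//! use serde_json::json;
//! use crate::tools::schema::SchemaCleanr;
//!
//! let schema = json!({
//!     "type": "string",
//!     "minLength": 1,              // rejected by Gemini, stripped below
//!     "description": "A name"      // metadata, preserved
//! });
//! let cleaned = SchemaCleanr::clean_for_gemini(schema);
//! assert!(cleaned.get("minLength").is_none());
//! assert_eq!(cleaned["description"], "A name");
//! ```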
use serde_json::{Map, Value, json};
use std::collections::{HashMap, HashSet};
/// Keywords that Gemini rejects for tool schemas.
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
// Schema composition
"$ref",
"$schema",
"$id",
"$defs",
"definitions",
// Property constraints
"additionalProperties",
"patternProperties",
// String constraints
"minLength",
"maxLength",
"pattern",
"format",
// Number constraints
"minimum",
"maximum",
"multipleOf",
// Array constraints
"minItems",
"maxItems",
"uniqueItems",
// Object constraints
"minProperties",
"maxProperties",
// Non-standard
"examples",
];
/// Keywords that should be preserved during cleaning (metadata).
const SCHEMA_META_KEYWORDS: &[&str] = &["description", "title", "default"];
/// Schema cleaning strategies for different LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleaningStrategy {
/// Gemini (Google AI / Vertex AI) - Most restrictive
Gemini,
/// Anthropic Claude - Moderately permissive
Anthropic,
/// OpenAI GPT - Most permissive
OpenAI,
/// Conservative: Remove only universally unsupported keywords
Conservative,
}
impl CleaningStrategy {
/// Get the list of unsupported keywords for this strategy.
pub fn unsupported_keywords(self) -> &'static [&'static str] {
match self {
Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
Self::Anthropic => &["$ref", "$defs", "definitions"],
Self::OpenAI => &[],
Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"],
}
}
}
/// JSON Schema cleaner optimized for LLM tool calling.
pub struct SchemaCleanr;
impl SchemaCleanr {
/// Clean schema for Gemini compatibility (strictest).
pub fn clean_for_gemini(schema: Value) -> Value {
Self::clean(schema, CleaningStrategy::Gemini)
}
/// Clean schema for Anthropic compatibility.
pub fn clean_for_anthropic(schema: Value) -> Value {
Self::clean(schema, CleaningStrategy::Anthropic)
}
/// Clean schema for OpenAI compatibility (most permissive).
pub fn clean_for_openai(schema: Value) -> Value {
Self::clean(schema, CleaningStrategy::OpenAI)
}
/// Clean schema with specified strategy.
pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
let defs = if let Some(obj) = schema.as_object() {
Self::extract_defs(obj)
} else {
HashMap::new()
};
Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
}
/// Validate that a schema is suitable for LLM tool calling.
pub fn validate(schema: &Value) -> anyhow::Result<()> {
let obj = schema
.as_object()
.ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;
if !obj.contains_key("type") {
anyhow::bail!("Schema missing required 'type' field");
}
if let Some(Value::String(t)) = obj.get("type") {
if t == "object" && !obj.contains_key("properties") {
tracing::warn!("Object schema without 'properties' field may cause issues");
}
}
Ok(())
}
/// Extract $defs and definitions into a flat map for reference resolution.
fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
let mut defs = HashMap::new();
if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
for (key, value) in defs_obj {
defs.insert(key.clone(), value.clone());
}
}
if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
for (key, value) in defs_obj {
defs.insert(key.clone(), value.clone());
}
}
defs
}
/// Recursively clean a schema value.
fn clean_with_defs(
schema: Value,
defs: &HashMap<String, Value>,
strategy: CleaningStrategy,
ref_stack: &mut HashSet<String>,
) -> Value {
match schema {
Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
Value::Array(arr) => Value::Array(
arr.into_iter()
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
.collect(),
),
other => other,
}
}
/// Clean an object schema.
fn clean_object(
obj: Map<String, Value>,
defs: &HashMap<String, Value>,
strategy: CleaningStrategy,
ref_stack: &mut HashSet<String>,
) -> Value {
// Handle $ref resolution
if let Some(Value::String(ref_value)) = obj.get("$ref") {
return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
}
// Handle anyOf/oneOf simplification
if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
return simplified;
}
}
// Build cleaned object
let mut cleaned = Map::new();
let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");
for (key, value) in obj {
// Skip unsupported keywords
if unsupported.contains(key.as_str()) {
continue;
}
match key.as_str() {
// Convert const to enum
"const" => {
cleaned.insert("enum".to_string(), json!([value]));
}
// Skip type if we have anyOf/oneOf
"type" if has_union => {}
// Handle type arrays (remove null)
"type" if matches!(value, Value::Array(_)) => {
let cleaned_value = Self::clean_type_array(value);
cleaned.insert(key, cleaned_value);
}
// Recursively clean nested schemas
"properties" => {
let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
cleaned.insert(key, cleaned_value);
}
"items" => {
let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
cleaned.insert(key, cleaned_value);
}
"anyOf" | "oneOf" | "allOf" => {
let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
cleaned.insert(key, cleaned_value);
}
_ => {
let cleaned_value = match value {
Value::Object(_) | Value::Array(_) => {
Self::clean_with_defs(value, defs, strategy, ref_stack)
}
other => other,
};
cleaned.insert(key, cleaned_value);
}
}
}
Value::Object(cleaned)
}
/// Resolve a $ref to its definition.
fn resolve_ref(
ref_value: &str,
obj: &Map<String, Value>,
defs: &HashMap<String, Value>,
strategy: CleaningStrategy,
ref_stack: &mut HashSet<String>,
) -> Value {
// Prevent circular references
if ref_stack.contains(ref_value) {
tracing::warn!("Circular $ref detected: {}", ref_value);
return Self::preserve_meta(obj, Value::Object(Map::new()));
}
if let Some(def_name) = Self::parse_local_ref(ref_value) {
if let Some(definition) = defs.get(def_name.as_str()) {
ref_stack.insert(ref_value.to_string());
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
ref_stack.remove(ref_value);
return Self::preserve_meta(obj, cleaned);
}
}
tracing::warn!("Cannot resolve $ref: {}", ref_value);
Self::preserve_meta(obj, Value::Object(Map::new()))
}
/// Parse a local JSON Pointer ref (#/$defs/Name or #/definitions/Name).
fn parse_local_ref(ref_value: &str) -> Option<String> {
ref_value
.strip_prefix("#/$defs/")
.or_else(|| ref_value.strip_prefix("#/definitions/"))
.map(Self::decode_json_pointer)
}
/// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`).
fn decode_json_pointer(segment: &str) -> String {
if !segment.contains('~') {
return segment.to_string();
}
let mut decoded = String::with_capacity(segment.len());
let mut chars = segment.chars().peekable();
while let Some(ch) = chars.next() {
if ch == '~' {
match chars.peek().copied() {
Some('0') => {
chars.next();
decoded.push('~');
}
Some('1') => {
chars.next();
decoded.push('/');
}
_ => decoded.push('~'),
}
} else {
decoded.push(ch);
}
}
decoded
}
/// Try to simplify anyOf/oneOf to a simpler form.
fn try_simplify_union(
obj: &Map<String, Value>,
defs: &HashMap<String, Value>,
strategy: CleaningStrategy,
ref_stack: &mut HashSet<String>,
) -> Option<Value> {
let union_key = if obj.contains_key("anyOf") {
"anyOf"
} else if obj.contains_key("oneOf") {
"oneOf"
} else {
return None;
};
let variants = obj.get(union_key)?.as_array()?;
let cleaned_variants: Vec<Value> = variants
.iter()
.map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
.collect();
// Strip null variants
let non_null: Vec<Value> = cleaned_variants
.into_iter()
.filter(|v| !Self::is_null_schema(v))
.collect();
// If only one variant remains after stripping nulls, return it
if non_null.len() == 1 {
return Some(Self::preserve_meta(obj, non_null[0].clone()));
}
// Try to flatten to enum if all variants are literals
if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
return Some(Self::preserve_meta(obj, enum_value));
}
None
}
/// Check if a schema represents null type.
fn is_null_schema(value: &Value) -> bool {
if let Some(obj) = value.as_object() {
if let Some(Value::Null) = obj.get("const") {
return true;
}
if let Some(Value::Array(arr)) = obj.get("enum") {
if arr.len() == 1 && matches!(arr[0], Value::Null) {
return true;
}
}
if let Some(Value::String(t)) = obj.get("type") {
if t == "null" {
return true;
}
}
}
false
}
/// Try to flatten anyOf/oneOf with only literal values to enum.
fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
if variants.is_empty() {
return None;
}
let mut all_values = Vec::new();
let mut common_type: Option<String> = None;
for variant in variants {
let obj = variant.as_object()?;
let literal_value = if let Some(const_val) = obj.get("const") {
const_val.clone()
} else if let Some(Value::Array(arr)) = obj.get("enum") {
if arr.len() == 1 {
arr[0].clone()
} else {
return None;
}
} else {
return None;
};
let variant_type = obj.get("type")?.as_str()?;
match &common_type {
None => common_type = Some(variant_type.to_string()),
Some(t) if t != variant_type => return None,
_ => {}
}
all_values.push(literal_value);
}
common_type.map(|t| {
json!({
"type": t,
"enum": all_values
})
})
}
/// Clean type array, removing null.
fn clean_type_array(value: Value) -> Value {
if let Value::Array(types) = value {
let non_null: Vec<Value> = types
.into_iter()
.filter(|v| v.as_str() != Some("null"))
.collect();
match non_null.len() {
0 => Value::String("null".to_string()),
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
_ => Value::Array(non_null),
}
} else {
value
}
}
/// Clean properties object.
fn clean_properties(
value: Value,
defs: &HashMap<String, Value>,
strategy: CleaningStrategy,
ref_stack: &mut HashSet<String>,
) -> Value {
if let Value::Object(props) = value {
let cleaned: Map<String, Value> = props
.into_iter()
.map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
.collect();
Value::Object(cleaned)
} else {
value
}
}
/// Clean union (anyOf/oneOf/allOf).
fn clean_union(
value: Value,
defs: &HashMap<String, Value>,
strategy: CleaningStrategy,
ref_stack: &mut HashSet<String>,
) -> Value {
if let Value::Array(variants) = value {
let cleaned: Vec<Value> = variants
.into_iter()
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
.collect();
Value::Array(cleaned)
} else {
value
}
}
/// Preserve metadata (description, title, default) from source to target.
fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
if let Value::Object(target_obj) = &mut target {
for &key in SCHEMA_META_KEYWORDS {
if let Some(value) = source.get(key) {
target_obj.insert(key.to_string(), value.clone());
}
}
}
target
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_remove_unsupported_keywords() {
let schema = json!({
"type": "string",
"minLength": 1,
"maxLength": 100,
"pattern": "^[a-z]+$",
"description": "A lowercase string"
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["type"], "string");
assert_eq!(cleaned["description"], "A lowercase string");
assert!(cleaned.get("minLength").is_none());
assert!(cleaned.get("maxLength").is_none());
assert!(cleaned.get("pattern").is_none());
}
#[test]
fn test_resolve_ref() {
let schema = json!({
"type": "object",
"properties": {
"age": {
"$ref": "#/$defs/Age"
}
},
"$defs": {
"Age": {
"type": "integer",
"minimum": 0
}
}
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["properties"]["age"]["type"], "integer");
assert!(cleaned["properties"]["age"].get("minimum").is_none());
assert!(cleaned.get("$defs").is_none());
}
#[test]
fn test_flatten_literal_union() {
let schema = json!({
"anyOf": [
{ "const": "admin", "type": "string" },
{ "const": "user", "type": "string" },
{ "const": "guest", "type": "string" }
]
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["type"], "string");
assert!(cleaned["enum"].is_array());
let enum_values = cleaned["enum"].as_array().unwrap();
assert_eq!(enum_values.len(), 3);
assert!(enum_values.contains(&json!("admin")));
assert!(enum_values.contains(&json!("user")));
assert!(enum_values.contains(&json!("guest")));
}
#[test]
fn test_strip_null_from_union() {
let schema = json!({
"oneOf": [
{ "type": "string" },
{ "type": "null" }
]
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["type"], "string");
assert!(cleaned.get("oneOf").is_none());
}
#[test]
fn test_const_to_enum() {
let schema = json!({
"const": "fixed_value",
"description": "A constant"
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["enum"], json!(["fixed_value"]));
assert_eq!(cleaned["description"], "A constant");
assert!(cleaned.get("const").is_none());
}
#[test]
fn test_preserve_metadata() {
let schema = json!({
"$ref": "#/$defs/Name",
"description": "User's name",
"title": "Name Field",
"default": "Anonymous",
"$defs": {
"Name": {
"type": "string"
}
}
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["type"], "string");
assert_eq!(cleaned["description"], "User's name");
assert_eq!(cleaned["title"], "Name Field");
assert_eq!(cleaned["default"], "Anonymous");
}
#[test]
fn test_circular_ref_prevention() {
let schema = json!({
"type": "object",
"properties": {
"parent": {
"$ref": "#/$defs/Node"
}
},
"$defs": {
"Node": {
"type": "object",
"properties": {
"child": {
"$ref": "#/$defs/Node"
}
}
}
}
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["properties"]["parent"]["type"], "object");
}
#[test]
fn test_validate_schema() {
let valid = json!({
"type": "object",
"properties": {
"name": { "type": "string" }
}
});
assert!(SchemaCleanr::validate(&valid).is_ok());
let invalid = json!({
"properties": {
"name": { "type": "string" }
}
});
assert!(SchemaCleanr::validate(&invalid).is_err());
}
#[test]
fn test_strategy_differences() {
let schema = json!({
"type": "string",
"minLength": 1,
"description": "A string field"
});
// Gemini: Most restrictive (removes minLength)
let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
assert!(gemini.get("minLength").is_none());
assert_eq!(gemini["type"], "string");
assert_eq!(gemini["description"], "A string field");
// OpenAI: Most permissive (keeps minLength)
let openai = SchemaCleanr::clean_for_openai(schema.clone());
assert_eq!(openai["minLength"], 1);
assert_eq!(openai["type"], "string");
}
#[test]
fn test_nested_properties() {
let schema = json!({
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"name": {
"type": "string",
"minLength": 1
}
},
"additionalProperties": false
}
}
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert!(
cleaned["properties"]["user"]["properties"]["name"]
.get("minLength")
.is_none()
);
assert!(
cleaned["properties"]["user"]
.get("additionalProperties")
.is_none()
);
}
#[test]
fn test_type_array_null_removal() {
let schema = json!({
"type": ["string", "null"]
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert_eq!(cleaned["type"], "string");
}
#[test]
fn test_skip_type_when_non_simplifiable_union_exists() {
let schema = json!({
"type": "object",
"oneOf": [
{
"type": "object",
"properties": {
"a": { "type": "string" }
}
},
{
"type": "object",
"properties": {
"b": { "type": "number" }
}
}
]
});
let cleaned = SchemaCleanr::clean_for_gemini(schema);
assert!(cleaned.get("type").is_none());
assert!(cleaned.get("oneOf").is_some());
}
}

@@ -13,4 +13,19 @@ pub trait Tool: Send + Sync + 'static {
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
/// Whether this tool is side-effect free and safe to parallelize.
fn read_only(&self) -> bool {
false
}
/// Whether this tool can run alongside other concurrency-safe tools.
fn concurrency_safe(&self) -> bool {
self.read_only() && !self.exclusive()
}
/// Whether this tool should run alone even if concurrency is enabled.
fn exclusive(&self) -> bool {
false
}
}

src/tools/web_fetch.rs (new file, 411 lines)
@@ -0,0 +1,411 @@
use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
pub struct WebFetchTool {
max_response_size: usize,
timeout_secs: u64,
user_agent: String,
}
impl WebFetchTool {
pub fn new(max_response_size: usize, timeout_secs: u64) -> Self {
Self {
max_response_size,
timeout_secs,
user_agent: "Mozilla/5.0 (compatible; Picobot/1.0)".to_string(),
}
}
fn validate_url(&self, url: &str) -> Result<String, String> {
let url = url.trim();
if url.is_empty() {
return Err("URL cannot be empty".to_string());
}
if url.chars().any(char::is_whitespace) {
return Err("URL cannot contain whitespace".to_string());
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err("Only http:// and https:// URLs are allowed".to_string());
}
let host = extract_host(url)?;
if is_private_host(&host) {
return Err(format!("Blocked local/private host: {}", host));
}
Ok(url.to_string())
}
fn truncate_response(&self, text: &str) -> String {
if self.max_response_size == 0 {
return text.to_string();
}
if text.len() > self.max_response_size {
// Back up to a char boundary so slicing cannot panic inside a
// multi-byte UTF-8 character (see commit a4399037ac for the same bug).
let mut end = self.max_response_size;
while !text.is_char_boundary(end) {
end -= 1;
}
format!(
"{}\n\n... [Response truncated due to size limit] ...",
&text[..end]
)
} else {
text.to_string()
}
}
async fn fetch_content(&self, url: &str) -> Result<String, String> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
self.user_agent.parse().unwrap(),
);
let response = client
.get(url)
.headers(headers)
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Handle HTML content
if content_type.contains("text/html") {
let html = response
.text()
.await
.map_err(|e| format!("Failed to read response: {}", e))?;
return Ok(self.extract_text_from_html(&html));
}
// Handle JSON content
if content_type.contains("application/json") {
let text = response
.text()
.await
.map_err(|e| format!("Failed to read response: {}", e))?;
// Pretty print JSON
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
return Ok(serde_json::to_string_pretty(&parsed).unwrap_or(text));
}
return Ok(text);
}
// For other content types, return raw text
response
.text()
.await
.map_err(|e| format!("Failed to read response: {}", e))
}
fn extract_text_from_html(&self, html: &str) -> String {
let mut text = html.to_string();
// Remove script and style tags with content using simple replacements
text = strip_tag(&text, "script");
text = strip_tag(&text, "style");
// Remove all HTML tags
text = strip_all_tags(&text);
// Decode HTML entities
text = self.decode_html_entities(&text);
// Clean up whitespace
let mut cleaned = String::new();
let mut last_was_space = true;
for c in text.chars() {
if c.is_whitespace() {
if !last_was_space {
cleaned.push(' ');
last_was_space = true;
}
} else {
cleaned.push(c);
last_was_space = false;
}
}
cleaned.trim().to_string()
}
fn decode_html_entities(&self, text: &str) -> String {
let entities = [
("&nbsp;", " "),
("&lt;", "<"),
("&gt;", ">"),
("&amp;", "&"),
("&quot;", "\""),
("&apos;", "'"),
("&mdash;", "—"),
("&ndash;", "–"),
("&copy;", "©"),
("&reg;", "®"),
("&trade;", "™"),
];
let mut result = text.to_string();
for (entity, replacement) in entities {
result = result.replace(entity, replacement);
}
result
}
}
fn strip_tag(s: &str, tag_name: &str) -> String {
// Remove `<tag ...>...</tag>` blocks, tolerating attributes on the opening
// tag. Matching is case-insensitive on an ASCII-lowercased copy, which
// preserves byte offsets for slicing the original string.
let open = format!("<{}", tag_name);
let close = format!("</{}>", tag_name);
let mut result = s.to_string();
loop {
let lower = result.to_ascii_lowercase();
let Some(start) = lower.find(&open) else { break };
let Some(rel) = lower[start..].find(&close) else { break };
let end_pos = start + rel + close.len();
result = format!("{}{}", &result[..start], &result[end_pos..]);
}
result
}
fn strip_all_tags(s: &str) -> String {
let mut result = String::new();
let mut in_tag = false;
for c in s.chars() {
if c == '<' {
in_tag = true;
} else if c == '>' {
in_tag = false;
result.push(' ');
} else if !in_tag {
result.push(c);
}
}
result
}
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
let s_lower = s.to_lowercase();
let entities = [
("&nbsp;", ' '),
("&lt;", '<'),
("&gt;", '>'),
("&amp;", '&'),
("&quot;", '"'),
("&apos;", '\''),
("&mdash;", '—'),
("&ndash;", ''),
("&copy;", '©'),
("&reg;", '®'),
("&trade;", '™'),
];
for (entity, replacement) in entities {
if s_lower.starts_with(&entity.to_lowercase()) {
return Some((replacement, entity.len()));
}
}
    // Handle numeric entities: &#NNN; (decimal) and &#xHHH; (hex)
    if let Some(rest) = s.strip_prefix("&#") {
        let (digits, radix) = match rest.strip_prefix(['x', 'X']) {
            Some(hex) => (hex, 16),
            None => (rest, 10),
        };
        if let Some(semi) = digits.find(';') {
            if let Ok(code) = u32::from_str_radix(&digits[..semi], radix) {
                if let Some(c) = char::from_u32(code) {
                    // Consumed length: "&#", optional "x", the digits, and ";".
                    return Some((c, s.len() - digits.len() + semi + 1));
                }
            }
        }
    }
None
}
fn extract_host(url: &str) -> Result<String, String> {
let rest = url
.strip_prefix("http://")
.or_else(|| url.strip_prefix("https://"))
.ok_or_else(|| "Only http:// and https:// URLs are allowed".to_string())?;
let authority = rest
.split(['/', '?', '#'])
.next()
.ok_or_else(|| "Invalid URL".to_string())?;
if authority.is_empty() {
return Err("URL must include a host".to_string());
}
let host = authority
.split(':')
.next()
.unwrap_or_default()
.trim()
.to_lowercase();
if host.is_empty() {
return Err("URL must include a valid host".to_string());
}
Ok(host)
}
fn is_private_host(host: &str) -> bool {
if host == "localhost" || host.ends_with(".localhost") {
return true;
}
if host.rsplit('.').next().is_some_and(|label| label == "local") {
return true;
}
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
            std::net::IpAddr::V6(v6) => {
                // Also reject fc00::/7 (unique local) and fe80::/10 (link
                // local); the dedicated helpers are unstable on older
                // toolchains, so check the first segment directly.
                v6.is_loopback()
                    || v6.is_unspecified()
                    || v6.is_multicast()
                    || (v6.segments()[0] & 0xfe00) == 0xfc00
                    || (v6.segments()[0] & 0xffc0) == 0xfe80
            }
};
}
false
}
#[async_trait]
impl Tool for WebFetchTool {
fn name(&self) -> &str {
"web_fetch"
}
fn description(&self) -> &str {
"Fetch a URL and extract readable text content. Supports HTML and JSON. Automatically extracts plain text from HTML, removes scripts and styles."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL to fetch"
}
},
"required": ["url"]
})
}
fn read_only(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let url = match args.get("url").and_then(|v| v.as_str()) {
Some(u) => u,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: url".to_string()),
});
}
};
let url = match self.validate_url(url) {
Ok(u) => u,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
});
}
};
match self.fetch_content(&url).await {
Ok(content) => Ok(ToolResult {
success: true,
output: self.truncate_response(&content),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_tool() -> WebFetchTool {
WebFetchTool::new(50_000, 30)
}
    #[test]
    fn test_validate_url_success() {
let tool = test_tool();
let result = tool.validate_url("https://example.com");
assert!(result.is_ok());
}
    #[test]
    fn test_validate_url_rejects_private() {
let tool = test_tool();
let result = tool.validate_url("https://localhost:8080");
assert!(result.is_err());
assert!(result.unwrap_err().contains("local/private"));
}
    #[test]
    fn test_validate_url_rejects_whitespace() {
let tool = test_tool();
let result = tool.validate_url("https://example.com/hello world");
assert!(result.is_err());
assert!(result.unwrap_err().contains("whitespace"));
}
    #[test]
    fn test_extract_text_simple() {
let tool = test_tool();
let html = "<html><body><p>Hello World</p></body></html>";
let text = tool.extract_text_from_html(html);
assert!(text.contains("Hello World"));
assert!(!text.contains("<"));
}
    #[test]
    fn test_extract_text_removes_scripts() {
let tool = test_tool();
let html = "<html><body><script>alert('bad');</script><p>Good</p></body></html>";
let text = tool.extract_text_from_html(html);
assert!(text.contains("Good"));
assert!(!text.contains("alert"));
}
    #[test]
    fn test_extract_text_removes_styles() {
let tool = test_tool();
let html = "<html><head><style>.x { color: red; }</style></head><body><p>Content</p></body></html>";
let text = tool.extract_text_from_html(html);
assert!(text.contains("Content"));
assert!(!text.contains("color"));
}
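    // Additional coverage (a sketch beyond the original commit's four tests):
    // exercise the free helpers `extract_host` and `is_private_host` directly,
    // plus basic entity decoding. These rely only on behavior already present
    // in this file.
    #[test]
    fn test_extract_host_parses_and_lowercases() {
        assert_eq!(
            extract_host("https://Example.COM:8080/path").unwrap(),
            "example.com"
        );
        assert!(extract_host("ftp://example.com").is_err());
        assert!(extract_host("https:///path").is_err());
    }
    #[test]
    fn test_is_private_host_classification() {
        assert!(is_private_host("localhost"));
        assert!(is_private_host("127.0.0.1"));
        assert!(is_private_host("10.1.2.3"));
        assert!(is_private_host("printer.local"));
        assert!(!is_private_host("example.com"));
    }
    #[test]
    fn test_decode_basic_entities() {
        let tool = test_tool();
        assert_eq!(
            tool.decode_html_entities("&lt;b&gt; &amp; &quot;q&quot;"),
            "<b> & \"q\""
        );
    }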
}