Compare commits

..

10 Commits

97 changed files with 5928 additions and 7939 deletions

28
.dockerignore Normal file
View File

@ -0,0 +1,28 @@
# Git
.git
.gitignore
# Build artifacts
target/
!target/release/picobot
# IDE
.vscode/
.idea/
*.swp
*.swo
# Docs and references
docs/
reference/
# Test files
tests/
# Misc
*.md
*.txt
.opencode/
CLAUDE.md
AGENTS.md
ARCHITECTURE_REVIEW.md

2
.gitignore vendored
View File

@ -1,6 +1,8 @@
/target /target
docker_build/
reference/** reference/**
.env .env
*.env *.env
Cargo.lock Cargo.lock
.worktrees/ .worktrees/
design

View File

@ -1,128 +0,0 @@
# 架构审查报告
> 生成时间: 2026-04-26
> 更新时间: 2026-04-26
## 审查摘要
本报告识别了当前代码库中的架构不合理、冗余和无效代码的问题。
---
## 问题清单
### 已修复
#### ✅ #1 OutboundDispatcher 重复维护 Channel 注册表
**修复方案**: `OutboundDispatcher` 现在从 `ChannelManager` 获取 channels而不是自己维护一份注册表。
**修改文件**:
- `src/bus/dispatcher.rs` - 移除 `channels` 字段,改用 `ChannelManager`
- `src/channels/manager.rs` - 添加 `register_channel` 方法
- `src/gateway/mod.rs` - 简化 dispatcher 初始化
---
#### ✅ #2 CliChatChannel 持有独立的 SessionStore
**修复方案**: `CliChatChannel``SessionStore` 通过依赖注入从 `ChannelManager` 获取,而不是独立持有。
**修改文件**:
- `src/channels/cli_chat.rs` - 添加 `set_store()` 方法
- `src/channels/manager.rs` - 添加 `cli_chat_channel` 字段
- `src/gateway/mod.rs` - 重构 channel 初始化流程
---
#### ✅ #3 MessageBus 被创建两次引用
**修复方案**: 移除 `GatewayState.bus` 字段,直接使用 `channel_manager.bus()`
**修改文件**:
- `src/gateway/mod.rs` - 移除冗余的 `bus` 字段
---
#### ✅ #4 GatewayState 同时持有 channel_manager 和 cli_chat_channel
**修复方案**: `cli_chat_channel` 只通过 `ChannelManager` 管理,`GatewayState` 不再单独持有。
**修改文件**:
- `src/gateway/mod.rs` - 移除 `cli_chat_channel` 字段,添加 `cli_chat_channel()` getter 方法
---
### 高优先级(待修复)
#### ❌ Session 每次重建都创建新的 LLM Provider
**文件**: `src/gateway/session.rs:349-361`
**问题**: 每当 session TTL 过期默认4小时就会销毁并重建 session同时创建新的 LLM provider 连接。
**建议**: Provider 应该池化复用,不随 session 销毁而重建。
---
#### ❌ CliChatChannel::send 广播给所有客户端
**文件**: `src/channels/cli_chat.rs:279-289`
**问题**: `OutboundMessage``chat_id` 字段用于路由,但实现广播给所有客户端,而不是只发给对应 chat_id 的客户端。
**建议**: 根据 `chat_id` 过滤客户端,只发送给对应的客户端。
---
### 中优先级(待修复)
#### ❌ default_tools() 每次调用创建新 ToolRegistry
**文件**: `src/gateway/session.rs:212-227`
**建议**: 如果工具列表是只读的,直接 clone Arc如果需要修改需要澄清设计意图。
---
### 低优先级(待修复)
#### ❌ FeishuChannel::new 接收未使用的 provider_config
**文件**: `src/channels/feishu.rs:175-178`
---
#### ❌ OutboundDispatcher::send_with_retry 永不执行的 unreachable
**文件**: `src/bus/dispatcher.rs:81`
---
#### ❌ Channel trait 的 `is_running` 使用 std::sync::Mutex
**文件**: `src/channels/base.rs:38` vs `src/channels/cli_chat.rs:265-267`
---
#### ❌ LoopDetector 硬编码在 AgentLoop 中
**文件**: `src/agent/agent_loop.rs:88-172`
---
#### ❌ InboundMessage 和 OutboundMessage 结构重复
**文件**: `src/bus/message.rs`
---
## 问题统计
| 状态 | 优先级 | 数量 |
|------|--------|------|
| ✅ 已修复 | - | 4 |
| ❌ 待修复 | 高 | 2 |
| ❌ 待修复 | 中 | 1 |
| ❌ 待修复 | 低 | 5 |
| **总计** | - | **12** |

View File

@ -1,6 +1,6 @@
[package] [package]
name = "picobot" name = "picobot"
version = "0.1.0" version = "1.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
@ -12,6 +12,8 @@ serde_json = "1.0"
async-trait = "0.1" async-trait = "0.1"
thiserror = "2.0.18" thiserror = "2.0.18"
tokio = { version = "1.52", features = ["full"] } tokio = { version = "1.52", features = ["full"] }
tokio-util = { version = "0.7", features = ["rt"] }
dashmap = "6.1"
uuid = { version = "1.23", features = ["v4"] } uuid = { version = "1.23", features = ["v4"] }
axum = { version = "0.8", features = ["ws"] } axum = { version = "0.8", features = ["ws"] }
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] } tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
@ -49,6 +51,7 @@ encoding_rs = "0.8"
zstd = "0.13" zstd = "0.13"
tar = "0.4" tar = "0.4"
fantoccini = { version = "0.22", default-features = false, features = ["rustls-tls"] } fantoccini = { version = "0.22", default-features = false, features = ["rustls-tls"] }
portable-pty = "0.9"
[build-dependencies] [build-dependencies]
zstd = "0.13" zstd = "0.13"

110
Dockerfile Normal file
View File

@ -0,0 +1,110 @@
# =============================================================================
# PicoBot Docker Image
# =============================================================================
# Build binary on host:
# cargo build --release
#
# Build image:
# docker build -t picobot .
#
# Run gateway: docker run -d -v ~/.picobot:/app/.picobot -p 19876:19876 picobot gateway
# Run chat: docker run -it -v ~/.picobot:/app/.picobot picobot chat
# =============================================================================
FROM debian:trixie-slim
LABEL org.opencontainers.image.title="PicoBot"
LABEL org.opencontainers.image.description="AI agent gateway and chat client"
LABEL org.opencontainers.image.source="https://github.com/your-repo/picobot"
# Avoid interactive prompts
ENV DEBIAN_FRONTEND=noninteractive
# Configure domestic mirrors for pip, uv, npm (China)
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
ENV UV_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
# Install base tools, Python, and uv in one layer to reduce duplication
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
tini \
curl \
gnupg \
git \
jq \
tree \
zip \
unzip \
sqlite3 \
openssh-client \
sshpass \
dnsutils \
poppler-utils \
fonts-wqy-zenhei \
fonts-wqy-microhei \
python3 \
python3-pip \
python3-venv \
&& rm -rf /var/lib/apt/lists/* \
&& pip3 install --no-cache-dir --break-system-packages uv
# Install Node.js and npx
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& npm config set registry https://registry.npmmirror.com \
&& npm cache clean --force \
&& rm -rf /var/lib/apt/lists/*
# Install himalaya (CLI email client) from local file
COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
&& chmod +x /usr/local/bin/himalaya \
&& rm -f /tmp/himalaya.tgz
# Install fd (alternative to find)
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \
tar -xz --strip-components=1 -C /usr/local/bin \
&& chmod +x /usr/local/bin/fd
# Install ripgrep (rg)
RUN curl -fsSL https://github.com/BurntSushi/ripgrep/releases/download/14.1.0/ripgrep-14.1.0-x86_64-unknown-linux-musl.tar.gz | \
tar -xz --strip-components=1 -C /usr/local/bin \
&& chmod +x /usr/local/bin/rg
# Install Chromium and chromedriver for browser automation
# Debian's chromium package is real (not a snap shim like Ubuntu 24.04)
RUN apt-get update && apt-get install -y --no-install-recommends \
chromium \
chromium-driver \
&& ln -sf /usr/bin/chromium /usr/local/bin/chrome \
&& ln -sf /usr/bin/chromedriver /usr/local/bin/chromedriver \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -m -s /bin/bash app
WORKDIR /app
# Copy pre-built binary from host
COPY target/release/picobot /app/picobot
# Copy config template
COPY resources/templates/config.example.json /app/config.json.example
# Create required directories
RUN mkdir -p /app/.picobot/workspace /app/.picobot/media /app/.picobot/tmp && \
chown -R app:app /app
USER app
ENV HOME=/app
# Environment variables for Chromium in containers
ENV CHROME_BIN=/usr/bin/chromium
ENV TMPDIR=/app/.picobot/tmp
ENTRYPOINT ["/app/picobot"]
CMD ["gateway"]
EXPOSE 19876
ENV RUST_LOG=info

537
README.md
View File

@ -1,143 +1,102 @@
# PicoBot # PicoBot
A multi-channel AI agent framework with a WebSocket gateway and TUI client, supporting OpenAI-compatible and Anthropic LLM providers, tool calling, session persistence, and cron-based scheduling. PicoBot is a Rust-based personal AI assistant runtime. It runs a local gateway, connects chat channels such as the terminal TUI and Feishu/Lark, persists sessions in SQLite, and gives the agent a tool system for files, shell commands, web access, memory, scheduling, skills, MCP tools, and delegated sub-agents.
## System Architecture ## What It Does
```mermaid - Runs as a gateway server on `127.0.0.1:19876` by default.
graph TB - Provides a Ratatui terminal client over WebSocket.
subgraph Clients - Supports Feishu/Lark messages, reactions, file upload/download, and media references.
TUI["🖥️ CLI Chat (TUI)"] - Calls OpenAI-compatible providers and Anthropic Messages API providers.
FS["📱 Feishu/Lark"] - Persists conversations, messages, memories, scheduled jobs, LLM call metadata, and background sub-agent tasks in SQLite.
end - Loads skills from workspace, user, and shared skill directories, with built-in skills installed on first use.
- Compresses long contexts and stores timeline summaries for later recall.
- Can register tools discovered from configured MCP servers.
subgraph Gateway["Gateway Server (127.0.0.1:19876)"] ## Architecture
HTTP["HTTP Endpoints<br/>GET /health<br/>GET /ws (WebSocket upgrade)"]
WS["WebSocket Handler"]
CD["ChannelManager"]
SP["SessionManager"]
AL["AgentLoop"]
end
subgraph Bus["MessageBus"] ```text
IB["Inbound Channel"] Channel -> MessageBus -> SessionManager -> AgentLoop -> LLM Provider
OB["Outbound Channel"] | |
CC["Control Channel"] | v
end | Tools
v
SQLite
subgraph Storage Control messages -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
SQLite[("SQLite<br/>picobot.db")]
end
subgraph AI["AI Providers"]
OpenAI["OpenAI / DashScope"]
Anthropic["Anthropic Claude"]
end
TUI <-->|WebSocket| WS
FS <-->|Webhook| HTTP
CD -->|InboundMessage| IB
IB -->|DialogEvent| SP
CC -->|ControlMessage| SP
SP <--> AL
AL -->|API Call| OpenAI
AL -->|API Call| Anthropic
AL -->|Tool Call| Tools
SP -->|OutboundMessage| OB
OB --> CD
SP --> SQLite
Tools --> SQLite
subgraph Tools
Bash["Bash"]
FileIO["File Read/Write/Edit"]
Web["HTTP Request / Web Fetch"]
Calc["Calculator"]
Skill["Get Skill"]
Msg["Send Message"]
Cron["Cron Jobs"]
end
``` ```
### Core Data Flow The main runtime boundary is:
```mermaid - `channels` only receive and send external messages.
sequenceDiagram - `bus` is an async queue, not a router.
participant Channel as Channel<br/>(CLI/Feishu) - `session` owns dialog lifecycle, persistence, memory recall, prompt assembly, compression, and task cancellation.
participant Bus as MessageBus - `agent` runs the stateless LLM/tool loop.
participant SM as SessionManager - `providers` are HTTP clients for model APIs.
participant AL as AgentLoop - `tools` execute agent actions and return string results.
participant LLM as LLM Provider - `storage` owns SQLite schema and CRUD.
participant Tool as Tools - `scheduler` polls due jobs and feeds prompts back into sessions.
Channel->>Bus: InboundMessage (user input)
Bus->>SM: DialogEvent
SM->>SM: Load/Resolve Session
SM->>AL: Process (session state)
AL->>LLM: ChatCompletionRequest
LLM-->>AL: response / tool_calls
alt Tool Calls
AL->>Tool: execute tool
Tool-->>AL: result
AL->>LLM: continue with tool result
end
AL-->>SM: AgentProcessResult (text + token count)
SM->>SM: Persist to SQLite
SM->>Bus: OutboundMessage
Bus->>Channel: response to user
```
## Features ## Features
### Multi-Channel Support ### Channels
- **CLI Chat Client** — Full TUI with session management, Markdown rendering, slash commands
- **Feishu (Lark)** — Webhook-based integration with typing indicators and media support
### Multi-Provider LLM - `cli_chat`: terminal TUI client connected through `/ws`.
- OpenAI-compatible API (GPT-4, DashScope, Volcengine, etc.) - `feishu`: Feishu/Lark channel with configurable allow list, media directory, and reaction emoji.
- Anthropic Messages API (Claude)
- Cross-provider JSON Schema normalization for tool calling compatibility
### Session Management ### LLM Providers
- Multi-session conversations per channel/chat
- Create, switch, rename, archive, delete dialogs via slash commands or WebSocket
- SQLite-persisted session history with automatic TTL-based cleanup
- Context compression for long conversations approaching token limits
### Tool System - OpenAI-compatible chat completions, including DashScope, Volcengine, and similar APIs.
| Tool | Description | - Anthropic Messages API.
|------|-------------| - Model-specific `input_type` metadata for text/image capability checks.
| `bash` | Execute shell commands in workspace | - JSON Schema cleanup for cross-provider tool compatibility.
| `file_read` | Read file contents |
| `file_write` | Create/overwrite files |
| `file_edit` | Precise string substitution in files |
| `http_request` | Make HTTP API requests |
| `web_fetch` | Fetch and parse web pages |
| `calculator` | Evaluate mathematical expressions |
| `get_skill` | Load agent skills from local skill files |
| `send_message` | Send messages to other channels |
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs |
### Scheduling ### Sessions And Memory
- Cron-based recurring jobs with optional timezone support
- One-shot (`at`) and interval (`every`) schedules
- Jobs trigger agent processing via specified channel/chat
### Skills System - Session IDs use `<channel>:<chat_id>:<dialog_id>`.
- Load Markdown skill files from `~/.picobot/skills` and `~/.agents/skills` - Each channel/chat can have multiple dialogs.
- Skills inject specialized system prompts for specific tasks - Dialog operations include create, list, switch, rename, delete, compact, dump, info, and stop.
- Automatic hot-reload on file changes - Session history is persisted to SQLite and can be incrementally restored after compression.
- Knowledge memories are recalled into the system prompt each turn.
- Timeline memories are produced by context compression and can be searched later.
### Observability ### Tools
- Observer pattern for agent and tool telemetry
- Events: `AgentStart`, `AgentEnd`, `ToolCallStart`, `ToolCall` Base tools registered for the agent:
- Structured JSON logging with file rotation
| Tool | Purpose |
|------|---------|
| `calculator` | Math expressions and statistics |
| `file_read` / `file_write` / `file_edit` | Workspace file operations |
| `file_search` / `content_search` | File and content search |
| `bash` | Run shell commands in the workspace |
| `http_request` | HTTP API requests |
| `web_fetch` | Fetch and extract web page text |
| `get_skill` | List or load local skills |
| `memory_store` / `memory_recall` / `timeline_recall` / `memory_forget` | Long-term memory operations |
| `delegate` | Run inline, background, or parallel sub-agents |
| `send_message` | Send outbound messages to configured channels |
| `chat_manager` | Inspect sessions, channels, and stored messages |
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs when scheduler is enabled |
| `browser` | Optional WebDriver browser automation when enabled |
| MCP tools | Dynamically registered from configured MCP servers |
### Skills
Skills are directories containing `SKILL.md`. Load priority is:
1. `{workspace}/skills`
2. `~/.picobot/skills`
3. `~/.agents/skills`
Same-name skills in higher-priority locations override lower-priority ones. Built-in skills from `resources/skills` are embedded into the binary and installed into `~/.picobot/skills` if missing.
## Quick Start ## Quick Start
### Prerequisites ### Prerequisites
- Rust nightly (edition 2024) — use `rustup` to install
- Rust toolchain with edition 2024 support.
- A configured LLM provider API key.
### Build ### Build
@ -147,276 +106,186 @@ cargo build
### Configure ### Configure
1. Create `config.json` (or `~/.picobot/config.json`): PicoBot loads `~/.picobot/config.json` first, then falls back to `./config.json`. On gateway startup, a template is released to `~/.picobot/config.example.json` if it does not exist. The source template is [resources/templates/config.example.json](/home/xiaoxixi/code/PicoBot/resources/templates/config.example.json).
Minimal example:
```json ```json
{ {
"providers": { "providers": {
"openai": { "openai": {
"type": "openai", "type": "openai",
"base_url": "https://api.openai.com/v1", "base_url": "https://api.openai.com/v1",
"api_key": "<OPENAI_API_KEY>" "api_key": "<OPENAI_API_KEY>",
} "extra_headers": {}
},
"models": {
"gpt-4o": {
"model_id": "gpt-4o",
"temperature": 0.7,
"max_tokens": 4096
}
},
"agents": {
"default": {
"provider": "openai",
"model": "gpt-4o",
"max_tool_iterations": 99,
"token_limit": 128000
}
} }
},
"models": {
"gpt-4o": {
"model_id": "gpt-4o",
"temperature": 0.7,
"max_tokens": 4096,
"input_type": ["text", "image"]
}
},
"agents": {
"default": {
"provider": "openai",
"model": "gpt-4o",
"max_tool_iterations": 99,
"token_limit": 128000
}
},
"workspace_dir": "~/.picobot/workspace"
} }
``` ```
2. Set API keys via `.env` file (one `KEY=VALUE` per line): The `.env` file in the current directory is parsed by PicoBot itself. Values like `<OPENAI_API_KEY>` in JSON are replaced from the process environment after `.env` is loaded.
```env
OPENAI_API_KEY=sk-xxxxx
```
### Run ### Run
**Start gateway server:**
```bash ```bash
cargo run -- gateway cargo run -- gateway
``` ```
Binds `127.0.0.1:19876` by default. Override with `--host` and `--port`. The gateway switches the process working directory to `workspace_dir` and stores `picobot.db` there by default.
**Connect CLI client:** In another terminal:
```bash ```bash
cargo run -- chat cargo run -- chat
``` ```
Connects to `ws://127.0.0.1:19876/ws`. Override with `--gateway-url`. The client connects to `ws://127.0.0.1:19876/ws` by default. Override with `--gateway-url`.
## Configuration Reference ## Configuration
Config load order: `~/.picobot/config.json``./config.json` (fallback). Top-level config fields:
### Full Config Structure | Field | Purpose |
|-------|---------|
| `providers` | Named LLM provider configs |
| `models` | Named model configs |
| `agents` | Agent-to-provider/model binding |
| `gateway` | Bind address, session DB path, cleanup, scheduler, background task limits |
| `client` | Default WebSocket URL for the TUI client |
| `channels` | Channel configs, currently Feishu/Lark |
| `memory` | Recall and consolidation settings |
| `mcp` | MCP server configs |
| `browser` | Optional WebDriver browser tool config |
| `workspace_dir` | Workspace used for file tools, shell commands, DB default, and workspace skills |
```mermaid Important defaults:
graph LR
Config["config.json"]
Config --> Providers["providers<br/>ProviderConfig{}"]
Config --> Models["models<br/>ModelConfig{}"]
Config --> Agents["agents<br/>AgentConfig{}"]
Config --> Gateway["gateway<br/>GatewayConfig"]
Config --> Client["client<br/>ClientConfig"]
Config --> Channels["channels<br/>ChannelConfig{}"]
Config --> Workspace["workspace_dir"]
Providers --> PT["type (openai / anthropic)<br/>base_url<br/>api_key<br/>extra_headers"] | Key | Default |
Models --> MT["model_id<br/>temperature<br/>max_tokens"] |-----|---------|
Agents --> AT["provider (ref)<br/>model (ref)<br/>max_tool_iterations<br/>token_limit"] | `gateway.host` | `127.0.0.1` |
Gateway --> GT["host / port<br/>session_db_path<br/>scheduler"] | `gateway.port` | `19876` |
Channels --> CT["feishu: app_id, app_secret<br/>allow_from, agent, media_dir"] | `gateway.max_concurrent_background_tasks` | `10` |
``` | `gateway.scheduler.enabled` | `true` if `scheduler` is omitted and defaulted |
| `client.gateway_url` | `ws://127.0.0.1:19876/ws` |
| `memory.recall_limit` | `5` |
| `memory.timeline_retention_days` | `90` |
| `mcp.tool_timeout_secs` | `180` |
| `browser.enabled` | `false` |
### Environment Variables MCP servers support `stdio`, `sse`, and `streamable-http` transports. Browser automation requires a compatible Chrome/Chromium and chromedriver/WebDriver endpoint.
The `.env` file in the working directory is loaded manually (not via dotenv crate). Placeholders in `config.json` written as `<VAR_NAME>` are substituted at load time.
### Gateway Config
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `host` | string | `127.0.0.1` | Bind address |
| `port` | u16 | `19876` | Listen port |
| `session_db_path` | string | workspace `picobot.db` | SQLite database path |
| `scheduler.enabled` | bool | `false` | Enable cron scheduler |
### Agent Config
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `provider` | string | — | Provider name (key in `providers`) |
| `model` | string | — | Model name (key in `models`) |
| `max_tool_iterations` | number | `99` | Max tool call iterations per turn |
| `token_limit` | number | `128000` | Context window token limit |
## Slash Commands ## Slash Commands
Available in CLI chat and Feishu: Available from CLI chat and channel text messages:
| Command | Alias | Description | | Command | Description |
|---------|-------|-------------| |---------|-------------|
| `/new` | `/刷新` | Create a new dialog | | `/new` | Create a new dialog |
| `/list` | `/对话列表` | List all dialogs | | `/sessions` | List recent dialogs |
| `/switch <id>` | — | Switch to a dialog | | `/switch <dialog_id>` | Switch dialog |
| `/rename <title>` | — | Rename current dialog | | `/rename <title>` | Rename current dialog |
| `/archive` | — | Archive current dialog | | `/delete` | Delete current dialog |
| `/delete` | — | Delete current dialog | | `/compact` | Manually trigger context compression |
| `/clear` | `/清空` | Clear current dialog history | | `/info` | Show current dialog information |
| `/dump` | Save current dialog as Markdown |
| `/?`, `/help` | Show help |
| `/mcp` | Show MCP server and tool status |
| `/stop` | Stop active tasks and clear queued messages |
## WebSocket Protocol ## WebSocket API
The gateway exposes a WebSocket endpoint at `/ws`. Messages use typed JSON with a `type` discriminator field. The gateway exposes:
### Client → Server (WsInbound)
| Type | Fields |
|------|--------|
| `user_input` | `content`, `channel?`, `chat_id?`, `sender_id?` |
| `create_session` | `title?` |
| `list_sessions` | `include_archived` |
| `load_session` | `session_id` |
| `rename_session` | `session_id?`, `title` |
| `archive_session` | `session_id?` |
| `delete_session` | `session_id?` |
| `clear_history` | `chat_id?`, `session_id?` |
| `get_slash_commands` | — |
| `ping` | — |
### Server → Client (WsOutbound)
| Type | Fields |
|------|--------|
| `assistant_response` | `session_id`, `response`, `tokens_used?`, `tool_calls?` |
| `session_list` | `sessions[]` |
| `session_loaded` | `session_id`, `messages[]` |
| `session_created` | `session_id`, `title` |
| `session_renamed` | `session_id`, `title` |
| `session_archived` | `session_id` |
| `session_deleted` | `session_id` |
| `slash_commands` | `commands[]` |
| `error` | `message` |
| `pong` | — |
## HTTP Endpoints
| Method | Path | Description | | Method | Path | Description |
|--------|------|-------------| |--------|------|-------------|
| `GET` | `/health` | Health check — returns `{"status":"ok","version":"x.y.z"}` | | `GET` | `/health` | Returns service health and version |
| `GET` | `/ws` | WebSocket upgrade for chat clients | | `GET` | `/ws` | WebSocket upgrade for chat clients |
Inbound WebSocket message types:
| Type | Main fields |
|------|-------------|
| `user_input` | `content`, optional `channel`, `chat_id`, `sender_id` |
| `clear_history` | optional `chat_id`, `session_id` |
| `create_session` | optional `title` |
| `list_sessions` | `include_archived` |
| `load_session` | `session_id` |
| `rename_session` | optional `session_id`, `title` |
| `archive_session` | optional `session_id` |
| `delete_session` | optional `session_id` |
| `get_slash_commands` | none |
| `ping` | none |
Outbound WebSocket message types include `assistant_response`, `error`, `session_established`, `session_created`, `session_list`, `session_loaded`, `session_renamed`, `session_archived`, `session_deleted`, `history_cleared`, `slash_commands_list`, `pong`, `command_executed`, and `system_notification`.
## Testing ## Testing
```bash ```bash
# Unit tests (no external dependencies) # Unit tests
cargo test --lib cargo test --lib
# Integration tests (require API keys) # Integration tests require real API keys in tests/test.env
cp tests/test.env.example tests/test.env cp tests/test.env.example tests/test.env
# Fill in your API keys in tests/test.env
cargo test --test test_integration -- --ignored cargo test --test test_integration -- --ignored
cargo test --test test_tool_calling -- --ignored cargo test --test test_tool_calling -- --ignored
cargo test --test test_request_format -- --ignored cargo test --test test_request_format -- --ignored
# Run all tests
cargo test -- --ignored
``` ```
Integration tests are `#[ignore]` by default because they make real API calls. Integration tests are ignored by default because they make real provider calls.
## Project Structure ## Project Layout
``` ```text
├── src/ src/
│ ├── main.rs # CLI entrypoint (clap-based subcommands) agent/ LLM loop, context compression, system prompts, media handling, sub-agents
│ ├── lib.rs # Module declarations bus/ Inbound, outbound, and control message queues
│ ├── gateway/ # HTTP/WS server, GatewayState initialization channels/ CLI chat and Feishu/Lark integrations
│ │ ├── mod.rs client/ Ratatui terminal UI
│ │ ├── http.rs # Health endpoint config/ Config loading, env substitution, path expansion
│ │ └── ws.rs # WebSocket handler gateway/ Axum HTTP/WebSocket server and GatewayState wiring
│ ├── client/ # TUI chat client mcp/ MCP client connections and tool wrappers
│ │ ├── mod.rs memory/ Memory manager and memory types
│ │ └── tui/ # Ratatui-based terminal UI observability/ Agent/tool telemetry observer interfaces
│ ├── channels/ # Channel integrations providers/ OpenAI-compatible and Anthropic clients
│ │ ├── base.rs # Channel trait scheduler/ Scheduled job runtime
│ │ ├── cli_chat.rs # CLI WebSocket channel session/ Session lifecycle, dialog commands, persistence integration
│ │ ├── feishu.rs # Feishu/Lark webhook channel skills/ Skill loading and embedded built-in skill installation
│ │ ├── manager.rs # ChannelManager storage/ SQLite schema and CRUD
│ │ └── slash_command.rs # Slash command parser tools/ Agent tool implementations
│ ├── bus/ # Async message bus resources/
│ │ ├── mod.rs # MessageBus (tokio mpsc channels) skills/ Built-in skills embedded at build time
│ │ ├── message.rs # Message types templates/ Config, AGENTS.md, and USER.md templates released on first run
│ │ └── dispatcher.rs # OutboundDispatcher tests/ Unit and ignored integration tests
│ ├── session/ # Session & dialog management reference/ Third-party reference code; do not modify as project source
│ │ ├── mod.rs
│ │ ├── session.rs # Session, SessionManager
│ │ ├── session_id.rs # UnifiedSessionId
│ │ ├── commands.rs # SessionCommand enum
│ │ └── events.rs # SessionEvent, DialogInfo
│ ├── agent/ # LLM interaction loop
│ │ ├── mod.rs
│ │ ├── agent_loop.rs # AgentLoop (stateless)
│ │ ├── context_compressor.rs # Token estimation & summarization
│ │ └── system_prompt.rs # System prompt builder
│ ├── providers/ # LLM API clients
│ │ ├── mod.rs # Factory: create_provider()
│ │ ├── traits.rs # LLMProvider trait
│ │ ├── openai.rs # OpenAI-compatible client
│ │ └── anthropic.rs # Anthropic Messages API client
│ ├── tools/ # Agent tools
│ │ ├── mod.rs # create_default_tools()
│ │ ├── registry.rs # ToolRegistry
│ │ ├── traits.rs # Tool trait, ToolResult
│ │ ├── schema.rs # Cross-provider JSON Schema cleaner
│ │ ├── bash.rs # Shell command execution
│ │ ├── calculator.rs # Math expression evaluator
│ │ ├── chat_manager.rs # Session management tool
│ │ ├── cron.rs # Cron job management tools
│ │ ├── file_read.rs # File reader
│ │ ├── file_write.rs # File writer
│ │ ├── file_edit.rs # File editor (string substitution)
│ │ ├── get_skill.rs # Skill loader tool
│ │ ├── http_request.rs # HTTP request tool
│ │ ├── send_message.rs # Cross-channel messaging
│ │ └── web_fetch.rs # Web page fetcher
│ ├── skills/ # Skills loading from markdown files
│ │ └── mod.rs # SkillsLoader, Skill
│ ├── storage/ # SQLite persistence
│ │ ├── mod.rs # Storage, schema init
│ │ ├── session.rs # Session CRUD operations
│ │ ├── message.rs # Message persistence
│ │ ├── scheduler.rs # ScheduledJob, JobRun storage
│ │ └── error.rs # StorageError
│ ├── scheduler/ # Cron scheduler runtime
│ │ ├── mod.rs # Scheduler, next_run_for_schedule()
│ │ └── types.rs # Schedule enum (At/Every/Cron)
│ ├── observability/ # Telemetry observer pattern
│ │ └── mod.rs # Observer trait, ObserverEvent, MultiObserver
│ ├── protocol.rs # WebSocket message types (WsInbound/WsOutbound)
│ ├── config/ # Config loading & env substitution
│ │ └── mod.rs # Config, LLMProviderConfig, load_env_file()
│ └── logging.rs # Tracing subscriber init with file rotation
├── tests/
│ ├── test_integration.rs # LLM provider integration tests
│ ├── test_tool_calling.rs # Tool calling integration tests
│ ├── test_request_format.rs # Request format tests
│ ├── test_scheduler.rs # Scheduler unit tests
│ ├── test.env.example # Test environment template
│ └── test.env # Actual test keys (gitignored)
├── reference/ # Third-party reference code (do not modify)
├── resources/ # Assets embedded in binary
│ └── templates/ # Templates released to ~/.picobot/ on first run
├── config.example.json # Full config example
└── Cargo.toml
``` ```
## Key Dependencies ## Key Dependencies
| Crate | Purpose | | Crate | Purpose |
|-------|---------| |-------|---------|
| `axum` + `tokio-tungstenite` | HTTP server & WebSocket | | `axum`, `tokio`, `tokio-tungstenite` | Gateway and WebSocket runtime |
| `sqlx` (SQLite) | Session/Message/Job persistence | | `sqlx` | SQLite persistence |
| `reqwest` (rustls) | LLM API & external HTTP calls | | `reqwest` | LLM and HTTP clients |
| `ratatui` + `crossterm` | Terminal UI | | `ratatui`, `crossterm`, `termimad` | Terminal UI |
| `clap` | CLI argument parsing | | `rmcp` | MCP client support |
| `tracing` + `tracing-subscriber` | Structured logging | | `fantoccini` | Optional browser automation |
| `cron` + `chrono-tz` | Cron schedule parsing | | `cron`, `chrono-tz` | Scheduling |
| `meval` | Mathematical expression evaluation | | `jieba-rs` | Chinese tokenization for memory search |
| `uuid` | Session/Dialog ID generation | | `zstd`, `tar` | Embedded built-in skill packaging |
| `dirs` | Platform config directory resolution |

16
docker-compose.yml Normal file
View File

@ -0,0 +1,16 @@
services:
picobot:
image: picobot:latest
container_name: picobot
restart: unless-stopped
ports:
- "19876:19876"
volumes:
- ~/.picobot/config.json:/app/.picobot/config.json:ro
- picobot_data:/app/.picobot
environment:
- RUST_LOG=info
command: gateway
volumes:
picobot_data:

View File

@ -0,0 +1,359 @@
# PicoBot 代码质量分析报告
审查日期2026-06-15
## 结论摘要
PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只做收发MessageBus 解耦输入输出SessionManager 管理会话AgentLoop 保持无状态并执行工具Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
当前主要质量风险集中在三类:
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
3. 工具和后台任务的资源边界偏弱文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
## 修复状态
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
## 主要发现
### 已修复CLI 会话路由会破坏会话连续性和多客户端隔离
位置:
- `src/channels/cli_chat.rs:113-126`
- `src/channels/cli_chat.rs:160-164`
- `src/channels/cli_chat.rs:225-249`
- `src/channels/cli_chat.rs:479-494`
- `src/session/session.rs:1305-1310`
问题:
`Client.current_session_id` 存的是完整 session id但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
影响:
- CLI 多轮对话可能落入不同 chat scope。
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
- 多个 CLI 客户端同时连接时存在串话。
建议:
- 将 client 状态拆成 `chat_id``current_session_id`,不要混用。
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
- `send()``OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
- `LoadSession` 应直接切换到指定 session或通过 `SwitchDialog` 使用其中的 `dialog_id`
- 为 CLI WebSocket 增加多客户端路由测试。
### 已修复Dialog 控制接口与协议承诺不一致
位置:
- `src/session/session.rs:996-997`
- `src/session/session.rs:1305-1310`
- `src/session/session.rs:1329-1349`
- `src/session/session.rs:1378-1384`
- `src/channels/cli_chat.rs:128-158`
问题:
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
- `/delete` 只创建新 session并没有删除当前 session。
- `get_current_dialog()` 固定返回 `Ok(None)`
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`
- `archive_dialog()` 是空操作。
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`
影响:
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
建议:
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`
- `list_dialogs()` 返回真实 current dialog并补上 archived 模型或移除 archived 参数。
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
位置:
- `src/tools/mod.rs:56-62`
- `src/tools/path_utils.rs:3-23`
- `src/tools/bash.rs:146-185`
问题:
文件工具默认通过 `FileReadTool::new()``FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize不能防御 `..`、符号链接等路径逃逸。
`bash` 默认工作目录是 `"."`Gateway 启动时切到 workspace这对相对路径有效但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
影响:
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP风险会放大。
建议:
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
- `resolve_path()` 使用 `std::fs::canonicalize``path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 中高优先级Session 锁内执行过多异步操作
位置:
- `src/session/session.rs:1001-1018`
- `src/session/session.rs:1604-1711`
问题:
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
影响:
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
建议:
- 锁内只做内存状态快照和必要的状态标记。
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
### 中优先级Bash 超时不会显式终止子进程
位置:
- `src/tools/bash.rs:150-174`
- `src/tools/bash.rs:180-207`
问题:
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
影响:
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
建议:
- 使用 `tokio::process::Child``kill_on_drop(true)`
- 超时分支显式 kill child 并 wait。
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
### 中优先级:文件读取对大二进制文件没有输出上限
位置:
- `src/tools/file_read.rs:121-131`
- `src/tools/file_read.rs:214-229`
问题:
`file_read``std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
影响:
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
建议:
- 先检查 metadata size超过阈值直接返回提示。
- 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。
- 对文本读取也改成流式按行读取,而不是整文件读入。
### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置:
- `src/tools/http_request.rs:31-59`
问题:
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
影响:
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
建议:
- 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
- 对 `http_request``web_fetch` 复用同一套 URL 安全策略。
### 中优先级:后台任务和主循环缺少监督与优雅关闭
位置:
- `src/bus/mod.rs:51-99`
- `src/gateway/mod.rs:187-244`
- `src/gateway/mod.rs:247-266`
问题:
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
影响:
- 某个后台 loop 异常退出后Gateway 不一定能发现。
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。
建议:
- 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。
- 用 `CancellationToken` 贯穿 Gateway 子任务。
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置:
- `src/scheduler/mod.rs:18-40`
问题:
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)``upcoming(tz)` 的当前时间。
影响:
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
建议:
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
- 增加固定时间输入的单元测试,避免受系统时间影响。
### 中低优先级:存在未接入或半接入代码,增加维护噪音
位置:
- `src/tools/pty.rs`
- `src/tools/mod.rs:1-20`
- `src/tools/mod.rs:49-88`
问题:
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty``create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
影响:
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
建议:
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
- 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。
## 架构评价
### 做得好的地方
- 模块分层方向清楚Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
- Provider 抽象简单直接OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
- Storage 集中初始化 schema便于部署单二进制应用。
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
### 主要架构债务
- SessionManager 承担过多职责会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
- 后台任务生命周期分散gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn缺少统一管理。
## 模块级分析
### gateway
`GatewayState::new()` 是清晰的装配中心配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
### channels
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel核心问题是会话身份混用和广播投递。
### bus
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
### session
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
- `manager.rs`SessionManager 状态和 dialog 生命周期
- `worker.rs`per-session agent worker 和 cancellation
- `commands.rs`slash command 执行
- `outbound.rs`OutboundMessenger 实现
- `restore.rs`storage 恢复与 tool call chain repair
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
### agent
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义`read_only()` 目前是工具自己声明副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard不应替代工具层的资源限制。
### providers
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response可能包含用户隐私、API 返回内容和文件内容。
### tools
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里默认值容易失控。
### storage
Storage schema 初始化实用但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”适合早期迭代不适合长期演进。建议引入 schema version 表或 sqlx migrations至少把每次迁移记录下来。
### skills
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
## 建议修复路线
### P0先修会话正确性
1. 修正 CLI `chat_id/current_session_id` 数据模型。
2. 修正 CLI 出站按 client/chat_id 投递。
3. 实现 `get_current_dialog()``list_dialogs()` current 返回。
4. 修正 `/delete``clear_history``archive` 的真实行为或从协议移除。
5. 增加 WebSocket session lifecycle 测试。
### P1收紧工具和资源边界
1. 文件工具默认限制 workspace路径 canonicalize。
2. bash 超时杀进程,必要时引入进程组。
3. file_read 增加文件大小上限和二进制输出上限。
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
5. 明确高危工具的配置开关。
### P2降低架构复杂度
1. 拆分 `session.rs``feishu.rs``storage/mod.rs``browser.rs`
2. 引入任务 supervisor 和统一 shutdown token。
3. 引入正式数据库迁移。
4. 增加工具注册快照测试,避免死代码和文档漂移。
## 建议测试补充
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
- `/delete` 后旧 session 软删除、新 session 成为 current。
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
- bash timeout 后检查子进程不存在。
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。

View File

@ -1,40 +0,0 @@
# 客户端代码整合设计
## 目标
将分散在 `src/cli/``src/client/` 的客户端代码整合到 `src/client/` 目录。
## 变更
### 目录结构
```
src/
├── client/ # 整合后的客户端模块
│ ├── mod.rs # 主程序入口 (run 函数)
│ ├── input.rs # InputHandler + InputCommand (从 cli/input.rs 合并)
│ └── channel.rs # CliChannel (从 cli/channel.rs 合并)
├── cli/ # 删除
└── protocol.rs # 保留
```
### 关键变更
| 变更 | 说明 |
|------|------|
| `InputEvent::Message(String)` | 简化为只携带文本内容,不再使用 `ChatMessage` |
| `cli` 模块删除 | 代码合并到 `client` |
| 解耦 | `client` 不再依赖 `bus::ChatMessage` |
## 实施步骤
1. 创建 `src/client/input.rs` - 从 `cli/input.rs` 合并,修改 `InputEvent::Message``String`
2. 创建 `src/client/channel.rs` - 从 `cli/channel.rs` 直接复制
3. 更新 `src/client/mod.rs` - 更新 import
4. 更新 `src/lib.rs` - 删除 `pub mod cli;`
5. 删除 `src/cli/` 目录
## 验证
- `cargo build` 通过
- 功能保持不变

View File

@ -1,877 +0,0 @@
# Phase 1: Storage 基础 实现计划
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 创建 `src/storage/` 模块,实现 SQLite 持久化层,为后续 Session 扩展提供 Storage 基础设施。
**Architecture:** 使用 `sqlx` + `sqlite`,通过 `SqlitePool` 实现异步连接池,所有 Storage 操作均为 async`Storage` 内部管理连接池的生命周期。
**Tech Stack:** `sqlx` (sqlite, tokio), `serde`, `chrono` (时间戳), `tokio::time::sleep` (重试退避)
---
## Task 1: 添加依赖
**Files:**
- Modify: `Cargo.toml:36` (在 `[dependencies]` 末尾添加)
**Step 1: 添加 sqlx + sqlite 依赖**
`Cargo.toml` 末尾添加:
```toml
sqlx = { version = "0.8", features = ["sqlite", "tokio", "macros", "chrono"] }
```
**Step 2: 运行 cargo check 验证依赖**
Run: `cargo check 2>&1`
Expected: 无报错,依赖解析成功
**Step 3: Commit**
```bash
git add Cargo.toml
git commit -m "deps: 添加 sqlx + sqlite 依赖"
```
---
## Task 2: 创建 Storage Error 类型
**Files:**
- Create: `src/storage/error.rs`
**Step 1: 编写 StorageError 枚举和测试**
```rust
use thiserror::Error;
#[derive(Error, Debug)]
pub enum StorageError {
#[error("session not found: {0}")]
NotFound(String),
#[error("session already exists: {0}")]
AlreadyExists(String),
#[error("database error: {0}")]
Database(#[from] sqlx::Error),
#[error("serialization error: {0}")]
Serialization(String),
}
```
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1 | head -30`
Expected: 报错 "cannot find module `storage`"(因为模块未创建),这是预期的
**Step 3: Commit**
```bash
git add src/storage/error.rs
git commit -m "feat(storage): 添加 StorageError 类型"
```
---
## Task 3: 创建 Storage 模块骨架
**Files:**
- Create: `src/storage/mod.rs`
- Create: `src/storage/session.rs`
- Create: `src/storage/message.rs`
**Step 1: 创建 `src/storage/mod.rs`**
```rust
pub mod error;
pub mod session;
pub mod message;
pub use error::StorageError;
```
**Step 2: 创建 `src/storage/session.rs`(空壳)**
```rust
// Session CRUD 操作占位符
```
**Step 3: 创建 `src/storage/message.rs`(空壳)**
```rust
// Message CRUD 操作占位符
```
**Step 4: 在 `src/lib.rs` 中添加 storage 模块**
`src/lib.rs` 末尾添加:
```rust
pub mod storage;
```
**Step 5: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功(空壳模块)
**Step 6: Commit**
```bash
git add src/storage/ src/lib.rs
git commit -m "feat(storage): 创建 storage 模块骨架"
```
---
## Task 4: 实现 Storage 主结构
**Files:**
- Modify: `src/storage/mod.rs`
**Step 1: 编写 Storage 结构和初始化逻辑**
```rust
use sqlx::{Pool, Sqlite, SqlitePool};
use std::path::Path;
pub struct Storage {
pool: Pool<Sqlite>,
}
impl Storage {
/// 打开或创建数据库
pub async fn new(db_path: &Path) -> Result<Self, StorageError> {
let database_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePool::connect(&database_url).await?;
let storage = Self { pool };
storage.init_schema().await?;
Ok(storage)
}
/// 初始化数据库 schema
async fn init_schema(&self) -> Result<(), StorageError> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
dialog_id TEXT NOT NULL,
title TEXT NOT NULL DEFAULT '新对话',
created_at INTEGER NOT NULL,
last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0,
routing_info TEXT,
deleted_at INTEGER,
UNIQUE(channel, chat_id, dialog_id)
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_sessions_chat
ON sessions(channel, chat_id, deleted_at)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
media_refs TEXT,
tool_call_id TEXT,
tool_name TEXT,
tool_calls TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
ON messages(session_id, seq)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
/// 获取连接池引用(供内部 CRUD 使用)
pub(crate) fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
}
```
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/mod.rs
git commit -m "feat(storage): 实现 Storage 主结构和初始化"
```
---
## Task 5: 定义 SessionMeta 和 MessageMeta 数据结构
**Files:**
- Modify: `src/storage/session.rs`
- Modify: `src/storage/message.rs`
**Step 1: 在 `session.rs` 中定义 SessionMeta**
```rust
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMeta {
pub id: String,
pub channel: String,
pub chat_id: String,
pub dialog_id: String,
pub title: String,
pub created_at: i64,
pub last_active_at: i64,
pub message_count: i64,
pub routing_info: Option<String>,
pub deleted_at: Option<i64>,
}
```
**Step 2: 在 `message.rs` 中定义 MessageMeta**
```rust
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMeta {
pub id: String,
pub session_id: String,
pub seq: i64,
pub role: String,
pub content: String,
pub media_refs: Option<String>,
pub tool_call_id: Option<String>,
pub tool_name: Option<String>,
pub tool_calls: Option<String>,
pub created_at: i64,
}
```
**Step 3: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 4: Commit**
```bash
git add src/storage/session.rs src/storage/message.rs
git commit -m "feat(storage): 定义 SessionMeta 和 MessageMeta 数据结构"
```
---
## Task 6: 实现 Session CRUD 操作
**Files:**
- Modify: `src/storage/session.rs`
**Step 1: 编写 upsert_session**
```rust
use sqlx::Row;
use super::SessionMeta;
impl Storage {
pub async fn upsert_session(&self, meta: &SessionMeta) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
title = excluded.title,
last_active_at = excluded.last_active_at,
message_count = excluded.message_count,
routing_info = excluded.routing_info,
deleted_at = excluded.deleted_at
"#,
)
.bind(&meta.id)
.bind(&meta.channel)
.bind(&meta.chat_id)
.bind(&meta.dialog_id)
.bind(&meta.title)
.bind(meta.created_at)
.bind(meta.last_active_at)
.bind(meta.message_count)
.bind(&meta.routing_info)
.bind(meta.deleted_at)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn get_session(&self, id: &str) -> Result<SessionMeta, StorageError> {
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions WHERE id = ? AND deleted_at IS NULL
"#,
)
.bind(id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
Ok(SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})
}
pub async fn list_sessions(
&self,
channel: &str,
chat_id: &str,
limit: i64,
) -> Result<Vec<SessionMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
ORDER BY last_active_at DESC
LIMIT ?
"#,
)
.bind(channel)
.bind(chat_id)
.bind(limit)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})
.collect())
}
pub async fn touch_session(
&self,
id: &str,
message_count: i64,
last_active_at: i64,
) -> Result<(), StorageError> {
sqlx::query(
r#"
UPDATE sessions SET message_count = ?, last_active_at = ?
WHERE id = ?
"#,
)
.bind(message_count)
.bind(last_active_at)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// 查找 channel:chat_id 下最近活跃且未过期的 session
pub async fn find_active_session(
&self,
channel: &str,
chat_id: &str,
ttl_millis: i64,
) -> Result<Option<SessionMeta>, StorageError> {
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
ORDER BY last_active_at DESC
LIMIT 1
"#,
)
.bind(channel)
.bind(chat_id)
.bind(cutoff)
.fetch_optional(self.pool())
.await?;
match row {
Some(row) => Ok(Some(SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})),
None => Ok(None),
}
}
}
```
> 注意:`Storage` 的 CRUD 方法需要能访问 `pool()`,但目前 `pool()``pub(crate)`。在 `mod.rs` 中为 `session.rs` 实现 `Storage` 的 CRUD所以同模块内可访问。
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/session.rs
git commit -m "feat(storage): 实现 Session CRUD 操作"
```
---
## Task 7: 实现 Message CRUD 操作
**Files:**
- Modify: `src/storage/message.rs`
**Step 1: 编写 Message CRUD**
```rust
use sqlx::Row;
use super::MessageMeta;
impl Storage {
pub async fn append_message(&self, session_id: &str, msg: &MessageMeta) -> Result<i64, StorageError> {
sqlx::query(
r#"
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&msg.id)
.bind(session_id)
.bind(msg.seq)
.bind(&msg.role)
.bind(&msg.content)
.bind(&msg.media_refs)
.bind(&msg.tool_call_id)
.bind(&msg.tool_name)
.bind(&msg.tool_calls)
.bind(msg.created_at)
.execute(self.pool())
.await?;
Ok(msg.seq)
}
pub async fn append_messages(
&self,
session_id: &str,
msgs: &[MessageMeta],
) -> Result<Vec<i64>, StorageError> {
let mut seqs = Vec::with_capacity(msgs.len());
for msg in msgs {
let seq = self.append_message(session_id, msg).await?;
seqs.push(seq);
}
Ok(seqs)
}
pub async fn load_messages(
&self,
session_id: &str,
from_seq: i64,
) -> Result<Vec<MessageMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at
FROM messages
WHERE session_id = ? AND seq >= ?
ORDER BY seq ASC
"#,
)
.bind(session_id)
.bind(from_seq)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| MessageMeta {
id: row.get("id"),
session_id: row.get("session_id"),
seq: row.get("seq"),
role: row.get("role"),
content: row.get("content"),
media_refs: row.get("media_refs"),
tool_call_id: row.get("tool_call_id"),
tool_name: row.get("tool_name"),
tool_calls: row.get("tool_calls"),
created_at: row.get("created_at"),
})
.collect())
}
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
.bind(session_id)
.execute(self.pool())
.await?;
Ok(())
}
}
```
> 注意:同样在 `mod.rs` 中实现,这样 `Storage` 的方法对 `message.rs` 中的 impl 可见。
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/message.rs
git commit -m "feat(storage): 实现 Message CRUD 操作"
```
---
## Task 8: 实现写入重试逻辑
**Files:**
- Modify: `src/storage/mod.rs`
**Step 1: 在 Storage 中添加带重试的 append_message**
`mod.rs``Storage` impl 块中添加:
```rust
use tokio::time::{sleep, Duration};
impl Storage {
/// 追加消息,带重试逻辑
/// 重试 3 次100/200/300ms 退避),仍失败返回错误
pub async fn append_message_with_retry(
&self,
session_id: &str,
msg: &MessageMeta,
) -> Result<i64, StorageError> {
let delays = [100, 200, 300];
for (i, delay) in delays.iter().enumerate() {
match self.append_message(session_id, msg).await {
Ok(seq) => return Ok(seq),
Err(e) if i < delays.len() - 1 => {
sleep(Duration::from_millis(*delay)).await;
tracing::warn!("Storage write failed, retrying: {}", e);
}
Err(e) => {
tracing::error!("Storage write failed after retries: {}", e);
return Err(e);
}
}
}
unreachable!()
}
}
```
> 注意:需要 `use sqlx::Row;``mod.rs` 中。
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/mod.rs
git commit -m "feat(storage): 添加写入重试逻辑"
```
---
## Task 9: 编写 Storage 单元测试
**Files:**
- Modify: `src/storage/mod.rs`(添加测试模块)
**Step 1: 编写 Storage 集成测试**
`src/storage/mod.rs` 末尾添加:
```rust
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
use std::path::Path;
async fn create_test_storage() -> (Storage, impl Fn()) {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let storage = Storage::new(&db_path).await.unwrap();
let cleanup = || {
drop(dir);
};
(storage, cleanup)
}
#[tokio::test]
async fn test_upsert_and_get_session() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试会话".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
let loaded = storage.get_session(&meta.id).await.unwrap();
assert_eq!(loaded.title, "测试会话");
assert_eq!(loaded.channel, "cli_chat");
}
#[tokio::test]
async fn test_get_nonexistent_session() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let result = storage.get_session("nonexistent").await;
assert!(result.is_err());
matches!(result.unwrap_err(), StorageError::NotFound(_));
}
#[tokio::test]
async fn test_list_sessions() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
for i in 0..5 {
let meta = SessionMeta {
id: format!("cli_chat:sid123:dialog{}", i),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: format!("dialog{}", i),
title: format!("会话{}", i),
created_at: i as i64 * 1000,
last_active_at: i as i64 * 1000,
message_count: i,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
assert_eq!(sessions.len(), 5);
// 按 last_active_at DESC 排序
assert_eq!(sessions[0].dialog_id, "dialog4");
}
#[tokio::test]
async fn test_soft_delete() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
storage.soft_delete_session(&meta.id).await.unwrap();
let result = storage.get_session(&meta.id).await;
assert!(result.is_err());
matches!(result.unwrap_err(), StorageError::NotFound(_));
}
#[tokio::test]
async fn test_append_and_load_messages() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let session_meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&session_meta).await.unwrap();
let msg = MessageMeta {
id: "msg1".to_string(),
session_id: session_meta.id.clone(),
seq: 1,
role: "user".to_string(),
content: "你好".to_string(),
media_refs: None,
tool_call_id: None,
tool_name: None,
tool_calls: None,
created_at: 1000,
};
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
assert_eq!(seq, 1);
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded[0].content, "你好");
}
#[tokio::test]
async fn test_touch_session() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
storage.touch_session(&meta.id, 5, 2000).await.unwrap();
let loaded = storage.get_session(&meta.id).await.unwrap();
assert_eq!(loaded.message_count, 5);
assert_eq!(loaded.last_active_at, 2000);
}
}
```
> 需要在 `Cargo.toml` 中添加 `tempfile` 依赖(已存在)。`defer` 宏可自己实现一个简单的:`fn defer<F: FnOnce()>(f: F) { f() }`
**Step 2: 运行测试**
Run: `cargo test storage::tests --lib 2>&1`
Expected: 所有 7 个测试 PASS
**Step 3: Commit**
```bash
git add src/storage/mod.rs
git commit -m "test(storage): 编写 Storage 单元测试"
```
---
## 汇总
| Task | 改动文件 | 关键交付物 |
|------|----------|-----------|
| 1 | `Cargo.toml` | sqlx 依赖 |
| 2 | `src/storage/error.rs` | StorageError |
| 3 | `src/storage/{mod.rs,session.rs,message.rs}`, `src/lib.rs` | 模块骨架 |
| 4 | `src/storage/mod.rs` | Storage 结构 + init_schema |
| 5 | `src/storage/session.rs`, `message.rs` | SessionMeta, MessageMeta |
| 6 | `src/storage/session.rs` | Session CRUD |
| 7 | `src/storage/message.rs` | Message CRUD |
| 8 | `src/storage/mod.rs` | append_message_with_retry |
| 9 | `src/storage/mod.rs` | 7 个单元测试 |
**Phase 1 完成后:** Storage 模块可独立使用,具备完整的持久化能力,可安全地集成到 Session 和 SessionManager 中。

View File

@ -1,278 +0,0 @@
# Session 持久化设计方案
## 概述
为 PicoBot 添加 SQLite 持久化层,实现 Session 数据的持久化、完整 Dialog 生命周期管理、消息实时落盘、以及基于 TTL 的自动内存清理。
## 核心概念
```
UnifiedSessionId = {channel}:{chat_id}:{dialog_id}
Session = Dialog两者等价不再分层
```
每个 Session 独立管理自己的消息历史、LLM 配置和路由信息。
## 数据库 Schema
### sessions 表
```sql
CREATE TABLE sessions (
id TEXT PRIMARY KEY,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
dialog_id TEXT NOT NULL,
title TEXT NOT NULL DEFAULT '新对话',
created_at INTEGER NOT NULL,
last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0,
routing_info TEXT,
deleted_at INTEGER,
UNIQUE(channel, chat_id, dialog_id)
);
CREATE INDEX idx_sessions_chat ON sessions(channel, chat_id, deleted_at);
```
### messages 表
```sql
CREATE TABLE messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
media_refs TEXT,
tool_call_id TEXT,
tool_name TEXT,
tool_calls TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
);
CREATE INDEX idx_messages_session_seq ON messages(session_id, seq);
```
## Storage API
### Session 操作
| 方法 | 说明 |
|------|------|
| `new(db_path) -> Storage` | 打开/创建数据库 |
| `upsert_session(meta) -> Result<(), StorageError>` | 插入或更新 session 元数据 |
| `get_session(id) -> Result<SessionMeta, StorageError>` | 获取单个 session |
| `list_sessions(channel, chat_id, limit) -> Result<Vec<SessionMeta>>` | 最近 N 条 |
| `touch_session(id, message_count, last_active_at)` | 更新计数和最后活跃时间 |
| `soft_delete_session(id) -> Result<(), StorageError>` | 软删除 |
### Message 操作
| 方法 | 说明 |
|------|------|
| `append_message(session_id, msg) -> Result<i64, StorageError>` | 追加单条消息,返回 seq |
| `append_messages(session_id, msgs) -> Result<Vec<i64>, StorageError>` | 批量追加 |
| `load_messages(session_id, from_seq) -> Result<Vec<MessageMeta>>` | 从指定 seq 加载 |
| `clear_messages(session_id) -> Result<(), StorageError>` | 清除消息(保留 session |
### 写入失败处理
重试 3 次100/200/300ms 退避),仍失败则发送系统通知告警。
## Session 结构
```rust
pub struct Session {
pub id: UnifiedSessionId,
pub title: String,
pub created_at: i64,
pub last_active_at: i64,
pub message_count: i64, // 用户消息计数
pub total_message_count: i64, // 含系统消息
messages: Vec<ChatMessage>, // 内存消息历史
seq_counter: i64, // 下一个消息的 seq
provider_config: LLMProviderConfig,
provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
compressor: ContextCompressor,
user_tx: mpsc::Sender<WsOutbound>,
storage: Arc<Storage>, // 持久化 sink
routing_info: String, // JSON 路由信息
}
```
### 初始化流程
```
new() 或 from_storage()
注入 storage 引用
创建 provider, tools, compressor
从 Storage 加载 messagesfrom_seq = 0
设置 seq_counter = messages.len() + 1
返回 Session 实例
```
## handle_message 流程
```
handle_message(channel, chat_id, sender_id, content, media)
├── 1. 确定 dialog_id
│ │
│ ├── 显式传入 dialog_id → 使用
│ └── 无 dialog_id
│ ├── 查找 channel:chat_id 下最近活跃且未过期的 session
│ ├── 找到 → 使用该 session
│ └── 未找到 → 创建新 sessiondialog_id = 新随机 ID
├── 2. 获取或创建 Session
│ 有 → 更新 session_timestamps
│ 无 → 从 Storage 恢复 或 创建新 Session
├── 3. 追加用户消息并持久化
│ seq = seq_counter; seq_counter += 1
│ Storage.append_message()(失败重试 → 告警)
│ messages.push(user_msg)
│ message_count += 1
├── 4. 检查 title 自动生成
│ message_count == 10 且 title == 默认值 → LLM 生成 → 更新 title → 写回 Storage
├── 5. 注入 skills_prompt
├── 6. 新 session 注入欢迎消息(系统消息,不计入 message_count
├── 7. 上下文压缩(如需要)
├── 8. 调用 AgentLoop
├── 9. 持久化 Agent 响应
└── 10. 返回响应
```
## Dialog 生命周期命令
| 命令 | 行为 |
|------|------|
| `/new [标题]` | 创建新 dialog新随机 dialog_id新建 Session |
| `/sessions` | 列出 channel:chat_id 下最近 10 条 session按 last_active_at 倒序) |
| `/switch <dialog_id>` | 切换到指定 session从 Storage 恢复或内存命中) |
| `/rename <新标题>` | 重命名当前 session |
| `/delete` | 软删除当前 session内存移除 + Storage 标记 deleted_at |
| `/info` | 显示当前 session 信息 |
| `/compact` | 手动触发上下文压缩 |
## 路由信息
每种 Channel 在创建 Session 时注入路由信息:
```rust
// CLI
routing_info = json!({"type": "cli", "ws_sender_id": "xxx"})
// Feishu
routing_info = json!({"type": "feishu", "open_conversation_id": "oc_xxx", "tenant_key": "xxx"})
```
## Title 自动生成
调用时机:
1. Session 首次创建时(初始 title = "新对话"
2. `message_count` 达到 10 且 title 仍为默认值时,自动更新
生成 Prompt
```
给定以下对话历史生成一个简短的会话标题5-15 个中文字符),
概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
历史:
{messages}
```
## TTL 清理
- 内存 session 超时 → 释放内存Storage 记录保留
- 用户切换回该 session → 从 Storage 重新加载到内存
- Storage 中的 session 记录通过 `deleted_at` 软删除,不会物理删除
## 文件结构
```
src/
├── storage/
│ ├── mod.rs # Storage 主模块
│ ├── session.rs # Session CRUD
│ ├── message.rs # Message CRUD
│ └── error.rs # StorageError
└── session/
├── mod.rs # 导出 Session, SessionManager
├── session.rs # Session, SessionManager 实现
├── session_id.rs # UnifiedSessionId
├── commands.rs # SessionCommand
├── events.rs # SessionEvent, DialogInfo
└── error.rs # SessionError
```
## 实现顺序
### Phase 1: Storage 基础
1. 添加 `sqlx` + `sqlite` 依赖
2. 实现 `Storage` 结构(连接池、初始化)
3. Session CRUD + Message CRUD
4. 写入重试逻辑
5. 单元测试
### Phase 2: Session 扩展
1. 扩展 `Session` 结构(添加 storage、routing_info、计数字段、seq_counter
2. `from_storage()` 恢复逻辑
3. `add_message` 持久化集成
4. `send_system_notification` 接口
5. Title 自动生成
### Phase 3: SessionManager 完善
1. 注入 `Arc<Storage>`
2. 实现 `list_dialogs()`
3. 实现 `switch_dialog()`
4. 实现 `delete_dialog()` / `rename_dialog()`
5. 后台 TTL 清理任务
6. 集成测试
### Phase 4: 斜杠命令
1. 实现 `/sessions`
2. 实现 `/switch`
3. 实现 `/rename`
4. 实现 `/delete`
5. 端到端测试
## 配置项
```json
{
"session": {
"ttl_hours": 24,
"cleanup_interval_minutes": 60,
"auto_title_after_n_messages": 10,
"storage_retry_delays_ms": [100, 200, 300]
}
}
```
## 与现有代码的冲突点
| 冲突 | 处理方式 |
|------|----------|
| `DialogInfo``archived_at` | 删除该字段,改用 `deleted_at` |
| `SessionCommand::ArchiveDialog` | 删除 |
| `/new` 现有行为 | 改为创建新 session新 dialog_id |
| 现有 `Session` 无 storage/routing_info | 扩展结构,新增 `from_storage()` |
| `SessionManager` 需注入 `Arc<Storage>` | 扩展构造方法 |
| stub 方法 | 实现 |

View File

@ -1,226 +0,0 @@
# PicoBot Memory System Design
Date: 2026-05-07
## 1. Overview
Introduce a memory system that allows PicoBot agents to remember user preferences, project context, facts, and conversation history across sessions. The memory system is **unified with the existing context compression pipeline**: compression automatically produces `timeline` memory entries and advances a `last_consolidated_at` pointer to avoid redundant reprocessing.
### Design Principles
- **Compression is memory** (inspired by nanobot): when old messages are compressed, the summary is persisted — not discarded
- **FTS5 only** (no vector embeddings): keyword search via SQLite FTS5, sufficient for current scale
- **Extend existing infrastructure**: reuse `Storage` connection pool, `ContextCompressor`, `SystemPromptBuilder`
- **YAGNI**: no knowledge graph, no response cache, no namespace isolation, no audit trail
## 2. Core Architecture
```
ContextCompressor (existing) MemoryManager (new)
│ │
│ compress_if_needed() │ store / recall / forget
│ ├─ LLM summary → inject │
│ └─ store(timeline entry) ──────┘
│ └─ advance last_consolidated_at
SystemPromptBuilder ── recall(knowledge, limit=5) ──→ inject into system prompt
AgentLoop ── after_turn ──→ memory_store / memory_recall / memory_forget tools
```
## 3. Memory Categories
| Category | Purpose | Written By | Retrieved By |
|----------|---------|-----------|--------------|
| `knowledge` | Long-term facts, preferences, patterns, insights | Agent via `memory_store` tool | FTS5 → injected into system prompt every turn |
| `timeline` | Compressed conversation summaries | ContextCompressor automatically | FTS5 + time-range queries |
## 4. Storage Schema
### New table: `memories`
Added to the existing `Storage` initialization in `src/storage/mod.rs`:
```sql
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
key TEXT NOT NULL UNIQUE,
content TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'knowledge',
importance REAL NOT NULL DEFAULT 0.5,
session_id TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
key,
content,
content=memories,
content_rowid=rowid
);
```
### Modified table: `sessions`
```sql
ALTER TABLE sessions ADD COLUMN last_consolidated_at INTEGER;
```
## 5. Unified Compression-Memory Pipeline
### Trigger Conditions
Compression/consolidation fires when **any** of these conditions is met:
| Condition | Value | Rationale |
|-----------|-------|-----------|
| Token budget exceeds 50% threshold | `context_window / 2` | Primary trigger — context is getting full |
| Accumulated N turns without consolidation | 3 (configurable) | Catch-up for short messages that don't hit token threshold |
| Session idle | 10 minutes (configurable) | Important for async channels like Feishu |
### Flow
```
compress_if_needed(history, session_id):
1. Read last_consolidated_at from session
→ Only compress messages after that timestamp
2. If no messages to compress → return history unchanged
3. FTS5 recall(user_input, limit=recall_limit, category=knowledge)
→ Inject relevant facts into system prompt
4. LLM summarization of old messages → [Context Summary]
→ Inject into current conversation
5. Store summary as timeline entry:
key: "ctx_{session_id}_{uuid}"
content: "[YYYY-MM-DD HH:MM] summary text..."
category: timeline
6. UPDATE sessions.last_consolidated_at = now()
7. Return compressed history
```
### timeline Entry Format
Each timeline entry follows nanobot's convention:
```
[2026-05-07 14:30] User asked about Rust async patterns. Discussed tokio::select!,
semaphore-based rate limiting, and backpressure strategies. No code was written.
```
This format is grep-friendly and human-readable.
## 6. Retrieval Strategy
### Automatic Retrieval (every turn)
`SystemPromptBuilder.build_system_prompt()` calls:
```rust
memory.recall(query=user_message, limit=recall_limit, category=knowledge)
```
Results sorted by FTS5 BM25 score, injected as:
```
## Memory Context
- user_prefers_rust: User prefers Rust for all backend projects
- project_picobot_stack: PicoBot uses Rust, axum, sqlx, ratatui, tokio
- user_workflow: User prefers TDD workflow with cargo test --lib
```
### Agent-Initiated Retrieval
Agent uses `memory_recall` tool with optional `category`, `since`, `until` parameters.
### Fallback
If FTS5 returns empty results, fallback to `LIKE '%keyword%'` on `key` and `content` columns.
## 7. Agent Tools
| Tool | Parameters | Description |
|------|-----------|-------------|
| `memory_store` | `key: str`, `content: str`, `category: str`, `importance?: f64` | Write or update a memory entry. Key is semantic identifier (e.g., "user_language_pref") |
| `memory_recall` | `query: str`, `category?: str`, `since?: i64`, `until?: i64`, `limit?: usize` | Search memories by keyword and optional filters |
| `memory_forget` | `key: str` | Delete a memory entry by key |
## 8. Error Handling & Degradation
| Scenario | Strategy |
|----------|----------|
| Consolidation LLM call fails | Log warning, increment failure counter, do NOT block main flow |
| Consecutive failures >= 3 | Degrade: append raw message dump to timeline with `[RAW]` prefix, reset counter |
| FTS5 recall returns empty | Fallback to `LIKE '%keyword%'` query |
| `memory.enabled = false` | ContextCompressor works normally, no memory writes |
| MemoryManager uninitialized | ContextCompressor works with feature-gated memory write path |
## 9. Configuration
```json
{
"memory": {
"enabled": true,
"consolidation_provider": "openai",
"consolidation_model": "gpt-4o-mini",
"recall_limit": 5,
"consolidation_turn_threshold": 3,
"idle_consolidation_minutes": 10,
"timeline_retention_days": 90,
"max_failures_before_degrade": 3
}
}
```
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `enabled` | bool | `false` | Master switch for memory system |
| `consolidation_provider` | string | — | Provider name for consolidation LLM calls |
| `consolidation_model` | string | — | Model name for consolidation |
| `recall_limit` | usize | `5` | Max knowledge entries injected into system prompt |
| `consolidation_turn_threshold` | usize | `3` | Turns before forced consolidation |
| `idle_consolidation_minutes` | u64 | `10` | Idle time before consolidation trigger |
| `timeline_retention_days` | u64 | `90` | Auto-cleanup age for timeline entries |
| `max_failures_before_degrade` | usize | `3` | Consecutive failures before raw archive fallback |
## 10. New Module Structure
```
src/
├── memory/
│ ├── mod.rs # MemoryManager, MemoryConfig
│ ├── types.rs # MemoryEntry, MemoryCategory, ConsolidationResult
│ └── consolidation.rs # Consolidation prompt + LLM call logic
├── storage/
│ └── memory.rs # SQLite CRUD for memories table + FTS5
├── tools/
│ ├── memory_store.rs # memory_store tool
│ ├── memory_recall.rs # memory_recall tool
│ └── memory_forget.rs # memory_forget tool
```
## 11. Integration Points (Existing Files Modified)
| File | Change |
|------|--------|
| `src/lib.rs` | Add `pub mod memory;` |
| `src/config/mod.rs` | Add `MemoryConfig` struct and deserialization |
| `src/storage/mod.rs` | Add `pub mod memory;`, init `memories` table and FTS5 in `init_schema()` |
| `src/storage/session.rs` | Add `last_consolidated_at` column read/write |
| `src/session/session.rs` | Add `last_consolidated_at: Option<i64>` field to Session |
| `src/agent/context_compressor.rs` | Add `memory: Option<Arc<MemoryManager>>` field, write timeline on compress |
| `src/agent/system_prompt.rs` | Add `memory_context` section via `MemoryManager::recall()` |
| `src/agent/agent_loop.rs` | No changes (tools registered via ToolRegistry) |
| `src/tools/mod.rs` | Register `memory_store`, `memory_recall`, `memory_forget` in `create_default_tools()` |
| `src/gateway/mod.rs` | Initialize `MemoryManager` in `GatewayState::new()`, pass to ContextCompressor |
## 12. Implementation Order
| # | Task | Dependencies |
|---|------|-------------|
| 1 | Types: `MemoryEntry`, `MemoryCategory`, `ConsolidationResult` | — |
| 2 | Config: `MemoryConfig` + deserialization | — |
| 3 | Storage: `memories` table + FTS5 + CRUD + search | #1 |
| 4 | `MemoryManager` API | #1, #2, #3 |
| 5 | Session: `last_consolidated_at` field | — |
| 6 | `ContextCompressor` memory integration | #4, #5 |
| 7 | `SystemPromptBuilder` memory context injection | #4 |
| 8 | Agent tools: `memory_store`, `memory_recall`, `memory_forget` | #4 |
| 9 | `GatewayState` initialization wiring | #4, #5, #6 |
| 10 | Unit tests | #1-#9 |

File diff suppressed because it is too large Load Diff

View File

@ -1,90 +0,0 @@
# 启动增量恢复设计
## 问题
PicoBot 重启后,`Session::from_storage()` 全量加载 `messages` 表,恢复的 history 可能直接超出上下文窗口,首次 LLM 调用即触发压缩,浪费 token。
## 设计
### 核心思路
`last_compressed_message_at` 标记最后压缩时刻。恢复时:
- 加载该标记之后的原始消息
- 用该 session 的 Timeline 条目替代已压缩部分
- `seq_counter` 统一从 SQLite 查 `MAX(seq) + 1`
```
messages 表 memories(timeline)
┌──────────────────────────┐ ┌───────────────────────────┐
│ created_at = T1..T5 │ ← 跳过 │ session = feishu:oc:dialog │
│ (压缩已覆盖用Timeline替代)│ │ created_at 降序 │
├──────────────────────────┤ ├───────────────────────────┤
│ created_at > T6 │ ← 加载 │ 只取最近 3 条 │
└──────────────────────────┘ └───────────────────────────┘
```
### 数据变更
**`sessions` 表加列:**
```sql
last_compressed_message_at INTEGER
```
**`SessionMeta` / `Session` 加字段:** `last_compressed_message_at: Option<i64>`
### Storage 层新增方法
| 方法 | SQL |
|------|-----|
| `get_max_message_seq(session_id)` | `SELECT MAX(seq) FROM messages WHERE session_id = ?` |
| `load_messages_after_timestamp(session_id, after_ts)` | `WHERE created_at > ?` |
| `load_session_timelines(session_id, limit)` | `WHERE session_id = ? AND category = 'timeline' ORDER BY created_at DESC LIMIT ?` |
### 压缩跟踪
`compress_if_needed()` 返回值改为 `CompressionResult { history, created_timelines: bool }`
`compress_once()` 中 LLM 摘要路径才置 `true`Tier 2Tier 1/3 不产生 Timeline。
**记录时机**`handle_message` 正常流、溢出重试流、`/compact` 统一):
```rust
if result.created_timelines {
session.last_compressed_message_at = Some(now());
session.persist_session_meta().await;
}
```
### Session::from_storage() 恢复逻辑
有压缩标记时:
1. `load_session_timelines(limit=4)` → 取 3 条给 LLM第 4 条判"有更多"
2. 有更多 → 插入提示 user 消息
3. 逐条插入 Timeline 为 `[Previous Context]` user 消息
4. `load_messages_after_timestamp(after_ts)` → 原始尾消息
5. `repair_tool_call_chains`
无压缩标记 → 全量加载(现有行为)。
统一:`seq_counter = MAX(seq) + 1`
### 系统提示词
`Session.last_compressed_message_at` 非空时追加:
```
## 历史会话
之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。
```
## 改动清单
| # | 文件 | 改动 |
|---|------|------|
| 1 | `storage/session.rs` | `SessionMeta``last_compressed_message_at` |
| 2 | `storage/mod.rs` | DDL migration + upsert/get_session 加列 |
| 3 | `storage/mod.rs` | 新增 `get_max_message_seq`, `load_messages_after_timestamp` |
| 4 | `storage/memory.rs` | 新增 `load_session_timelines` |
| 5 | `agent/context_compressor.rs` | 返回值改为 `CompressionResult``created_timelines` |
| 6 | `session/session.rs` | `Session` 加字段,`persist_session_meta` 加字段 |
| 7 | `session/session.rs` | `from_storage()` 重写恢复逻辑 |
| 8 | `session/session.rs` | `handle_message()` 压缩后记录标记 |
| 9 | `session/session.rs` | `/compact` 命令压缩后记录标记 |
| 10 | `session/session.rs` | `build_system_prompt()` 注入 `last_compressed_message_at` |

View File

@ -1,674 +0,0 @@
# 启动增量恢复 Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** PicoBot 重启后不再全量加载 messages 表,而是基于 `last_compressed_message_at` 标记增量恢复,用 Timeline 替代已压缩部分。
**Architecture:** 在 `sessions` 表加 `last_compressed_message_at` 列,`compress_if_needed` 返回值增加 `created_timelines` 标志,恢复时按时间戳增量加载消息并用近 3 条 Timeline 前置,`seq_counter` 统一从 SQLite 查 MAX(seq)。
**Tech Stack:** Rust, sqlx (SQLite), tokio
---
### Task 1: SessionMeta 和数据库 DDL 加列
**Files:**
- Modify: `src/storage/session.rs:15`
- Modify: `src/storage/mod.rs:44-45` (DDL), `:172-180` (migration)
- Modify: `src/storage/mod.rs:317-326` (upsert_session SQL + ON CONFLICT)
- Modify: `src/storage/mod.rs:345-369` (get_session SELECT + struct)
- Modify: `src/storage/mod.rs:380-406`, `:454-479`, `:564-588`, `:728`, `:754` (list_sessions 及测试 mock)
**Step 1: 在 `src/storage/session.rs` SessionMeta 加字段**
`last_consolidated_at: Option<i64>` 后加一行:
```rust
pub last_compressed_message_at: Option<i64>,
```
**Step 2: DDL schema 加列**
`src/storage/mod.rs` 的 CREATE TABLE sessions 中 (line 44)`last_consolidated_at INTEGER` 后加逗号和:
```sql
last_compressed_message_at INTEGER
```
**Step 3: migration 加列**
`src/storage/mod.rs` line 182 之后(现有 migration 的 `); .ok();` 之后),添加新 migration
```rust
// Migration: add last_compressed_message_at column if not exists
sqlx::query(
r#"ALTER TABLE sessions ADD COLUMN last_compressed_message_at INTEGER"#,
)
.execute(&self.pool)
.await
.ok();
```
**Step 4: upsert_session SQL 加列**
`src/storage/mod.rs` line 317: INSERT 列列表加 `last_compressed_message_at`VALUES 加 `?`ON CONFLICT DO UPDATE SET 加 `last_compressed_message_at = excluded.last_compressed_message_at`。line 338 后加 `.bind(meta.last_compressed_message_at)`
**Step 5: get_session SELECT 加列**
`src/storage/mod.rs` line 348: SELECT 列加 `last_compressed_message_at`。line 368 后加:
```rust
last_compressed_message_at: row.get("last_compressed_message_at"),
```
**Step 6: 其他 SELECT sessions 的地方list_sessions 多个变体)**
同样补 `last_compressed_message_at` 到 SELECT 列和 struct 构造。以及测试中的 mock SessionMeta 构造line 728, 754 等)。
**Step 7: 编译检查**
```bash
cargo check 2>&1
```
**Step 8: Commit**
```bash
git add src/storage/session.rs src/storage/mod.rs
git commit -m "feat(storage): add last_compressed_message_at column to sessions"
```
---
### Task 2: Storage 新增加载方法
**Files:**
- Modify: `src/storage/mod.rs` (在 load_messages 之后)
- Modify: `src/storage/memory.rs` (在 cleanup_old_timelines 之后)
**Step 1: `get_max_message_seq`**
`src/storage/mod.rs``load_messages` 函数后面添加:
```rust
pub async fn get_max_message_seq(&self, session_id: &str) -> Result<i64, StorageError> {
let row = sqlx::query(
"SELECT COALESCE(MAX(seq), 0) as max_seq FROM messages WHERE session_id = ?",
)
.bind(session_id)
.fetch_one(self.pool())
.await?;
Ok(row.get::<i64, _>("max_seq"))
}
```
**Step 2: `load_messages_after_timestamp`**
```rust
pub async fn load_messages_after_timestamp(
&self,
session_id: &str,
after_ts: i64,
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
FROM messages
WHERE session_id = ? AND created_at > ?
ORDER BY seq ASC
"#,
)
.bind(session_id)
.bind(after_ts)
.fetch_all(self.pool())
.await?;
Ok(rows.into_iter().map(|row| crate::storage::message::MessageMeta {
id: row.get("id"),
session_id: row.get("session_id"),
seq: row.get("seq"),
role: row.get("role"),
content: row.get("content"),
media_refs: row.get("media_refs"),
tool_call_id: row.get("tool_call_id"),
tool_name: row.get("tool_name"),
tool_calls: row.get("tool_calls"),
source: row.get("source"),
created_at: row.get("created_at"),
}).collect())
}
```
**Step 3: `load_session_timelines`**
`src/storage/memory.rs``cleanup_old_timelines` 之后line 252 的 `}` 之前)添加:
```rust
pub async fn load_session_timelines(
&self,
session_id: &str,
limit: usize,
) -> Result<Vec<MemoryEntry>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE session_id = ? AND category = 'timeline'
ORDER BY created_at DESC
LIMIT ?
"#,
)
.bind(session_id)
.bind(limit as i64)
.fetch_all(self.pool())
.await?;
parse_memory_rows(&rows)
}
```
**Step 4: 编译检查**
```bash
cargo check 2>&1
```
**Step 5: Commit**
```bash
git add src/storage/mod.rs src/storage/memory.rs
git commit -m "feat(storage): add load_messages_after_timestamp, load_session_timelines, get_max_message_seq"
```
---
### Task 3: ContextCompressor 引入 CompressionResult
**Files:**
- Modify: `src/agent/context_compressor.rs:172-274` (compress_if_needed)
- Modify: `src/agent/context_compressor.rs:320-402` (compress_once)
**Step 1: 定义 CompressionResult**
在 context_compressor.rs 中 `ContextCompressor` struct 定义之后添加:
```rust
pub struct CompressionResult {
pub history: Vec<ChatMessage>,
pub created_timelines: bool,
}
```
**Step 2: 修改 compress_if_needed 签名和返回**
`pub async fn compress_if_needed(&self, mut history: Vec<ChatMessage>) -> Result<Vec<ChatMessage>, AgentError>` 改为:
```rust
pub async fn compress_if_needed(
&self,
mut history: Vec<ChatMessage>,
) -> Result<CompressionResult, AgentError> {
```
内部的 `return Ok(history)` 改为 `return Ok(CompressionResult { history, created_timelines: false })`Tier 1 fast trim 和不需要压缩时)。
**Step 3: 修改 LLM summarization pass 部分**
在压缩循环中维护一个 `created_timelines` 标志:
```rust
let mut created_timelines = false;
for pass in 0..self.config.max_passes {
// ...
match self.compress_once(...).await {
Ok(Some(compressed)) => {
current_history = compressed;
created_timelines = true;
}
// ...
}
}
```
最后返回:
```rust
Ok(CompressionResult { history: current_history, created_timelines })
```
**Step 4: 更新所有 compress_if_needed 调用方**
所有 `compress_if_needed(history)` 改为 `compress_if_needed(history).await?.history`。在 `handle_message``/compact` 中还需要用到 `created_timelines`
**Step 5: 编译检查**
```bash
cargo check 2>&1
```
**Step 6: Commit**
```bash
git add src/agent/context_compressor.rs src/session/session.rs
git commit -m "feat(compressor): return CompressionResult with created_timelines flag"
```
---
### Task 4: Session 结构体和持久化
**Files:**
- Modify: `src/session/session.rs:52-74` (Session struct)
- Modify: `src/session/session.rs:76-121` (Session::new)
- Modify: `src/session/session.rs:298-320` (persist_session_meta)
**Step 1: Session struct 加字段**
`pub last_consolidated_at: Option<i64>` 后加:
```rust
pub last_compressed_message_at: Option<i64>,
```
**Step 2: Session::new 初始化**
`last_consolidated_at: None` 后加:
```rust
last_compressed_message_at: None,
```
**Step 3: persist_session_meta 加字段**
`last_consolidated_at: self.last_consolidated_at` 后加:
```rust
last_compressed_message_at: self.last_compressed_message_at,
```
**Step 4: 编译检查**
```bash
cargo check 2>&1
```
**Step 5: Commit**
```bash
git add src/session/session.rs
git commit -m "feat(session): add last_compressed_message_at field to Session and persist"
```
---
### Task 5: Session::from_storage() 增量恢复
**Files:**
- Modify: `src/session/session.rs:124-189` (from_storage)
**Step 1: 重写 from_storage**
替换现有实现为:
```rust
pub async fn from_storage(
id: UnifiedSessionId,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
storage: StdArc<Storage>,
memory_manager: Arc<crate::memory::MemoryManager>,
) -> Result<Self, AgentError> {
let session_meta = storage.get_session(&id.to_string()).await
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
let mut provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
provider_box.set_storage(storage.clone());
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig {
protect_first_n: 2,
..Default::default()
};
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
compressor.set_session_id(Some(id.to_string()));
// Determine recovery strategy
let mut chat_messages: Vec<ChatMessage> = Vec::new();
if let Some(after_ts) = session_meta.last_compressed_message_at {
// Load last 4 timelines to determine if there are > 3
let timelines = storage
.load_session_timelines(&id.to_string(), 4)
.await
.unwrap_or_default();
let has_more_timelines = timelines.len() > 3;
// Insert hint if more timelines exist
if has_more_timelines {
chat_messages.push(ChatMessage::user(
"[Earlier conversation summaries exist. \
Use `timeline_recall` to search if needed.]"
));
}
// Insert latest 3 timelines as context (reversed: oldest first)
for tl in timelines.iter().take(3).rev() {
chat_messages.push(ChatMessage::user(format!(
"[Previous Context]\n{}", tl.content
)));
}
// Load raw messages after compressed timestamp
let tail = storage
.load_messages_after_timestamp(&id.to_string(), after_ts)
.await
.unwrap_or_default();
let mut tail_msgs: Vec<ChatMessage> = tail.into_iter().map(|m| {
ChatMessage {
id: m.id,
role: m.role,
content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at,
tool_call_id: m.tool_call_id,
tool_name: m.tool_name,
tool_calls: m.tool_calls
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
.filter(|v| !v.is_empty()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
}
}).collect();
repair_tool_call_chains(&mut tail_msgs);
chat_messages.extend(tail_msgs);
} else {
// No prior compression — load all messages (existing behavior)
let messages = storage.load_messages(&id.to_string(), 0).await
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
chat_messages = messages.into_iter().map(|m| {
ChatMessage {
id: m.id,
role: m.role,
content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at,
tool_call_id: m.tool_call_id,
tool_name: m.tool_name,
tool_calls: m.tool_calls
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
.filter(|v| !v.is_empty()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
}
}).collect();
repair_tool_call_chains(&mut chat_messages);
}
// seq_counter from actual DB max
let max_seq = storage
.get_max_message_seq(&id.to_string())
.await
.unwrap_or(0);
let seq_counter = max_seq + 1;
let total_message_count = session_meta.message_count;
Ok(Self {
id: id.clone(),
title: session_meta.title,
created_at: session_meta.created_at,
last_active_at: session_meta.last_active_at,
message_count: session_meta.message_count,
total_message_count,
messages: chat_messages,
seq_counter,
provider_config: provider_config.clone(),
provider: provider.clone(),
tools,
compressor,
storage: Some(storage),
routing_info: session_meta.routing_info.unwrap_or_default(),
last_consolidated_at: session_meta.last_consolidated_at,
last_compressed_message_at: session_meta.last_compressed_message_at,
memory_manager,
})
}
```
**Step 2: 编译检查**
```bash
cargo check 2>&1
```
**Step 3: Commit**
```bash
git add src/session/session.rs
git commit -m "feat(session): incremental recovery from storage using compressed timeline"
```
---
### Task 6: 系统提示词加历史会话提示
**Files:**
- Modify: `src/agent/system_prompt.rs:289-304` (MemorySection)
- Modify: `src/agent/system_prompt.rs:16-23` (PromptContext)
- Modify: `src/agent/system_prompt.rs:343-358` (build_system_prompt free function)
- Modify: `src/session/session.rs:411-426` (build_system_prompt)
**Step 1: PromptContext 加 has_compressed_history 字段**
```rust
pub struct PromptContext<'a> {
pub workspace_dir: &'a Path,
pub model_name: &'a str,
pub tools: &'a ToolRegistry,
pub session_id: Option<&'a str>,
pub memory_context: Option<&'a str>,
pub has_compressed_history: bool,
}
```
**Step 2: 加 HistorySection**
在 MemorySection 后面添加:
```rust
pub struct HistorySection;
impl PromptSection for HistorySection {
fn name(&self) -> &str {
"history"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
if ctx.has_compressed_history {
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文使用 `timeline_recall` 工具搜索。".to_string()
} else {
String::new()
}
}
}
```
**Step 3: 注册到 SystemPromptBuilder::with_defaults**
`with_defaults()` 的 sections vec 中 `Box::new(MemorySection)` 后加 `Box::new(HistorySection)`
**Step 4: 更新 build_system_prompt 签名和调用**
```rust
pub fn build_system_prompt(
workspace_dir: &Path,
model_name: &str,
tools: &ToolRegistry,
session_id: Option<&str>,
memory_context: Option<&str>,
has_compressed_history: bool,
) -> String {
let ctx = PromptContext {
workspace_dir,
model_name,
tools,
session_id,
memory_context,
has_compressed_history,
};
SystemPromptBuilder::with_defaults().build(&ctx)
}
```
**Step 5: 更新 Session::build_system_prompt**
```rust
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
let base_prompt = build_system_prompt(
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
&self.tools,
Some(&self.id.to_string()),
memory_context,
self.last_compressed_message_at.is_some(),
);
if skills_prompt.trim().is_empty() {
base_prompt
} else {
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
}
}
```
**Step 6: 更新所有其他 build_system_prompt 调用方**
搜索 `build_system_prompt(` 的所有调用位置,每个都要加 `false` 参数。主要有 `agent/agent_loop.rs` 中的两个调用。
**Step 7: 编译检查**
```bash
cargo check 2>&1
```
**Step 8: Commit**
```bash
git add src/agent/system_prompt.rs src/session/session.rs src/agent/agent_loop.rs
git commit -m "feat(system-prompt): add history section for archived conversation context"
```
---
### Task 7: handle_message 和 /compact 记录压缩标记
**Files:**
- Modify: `src/session/session.rs:1339-1355` (handle_message 压缩后)
- Modify: `src/session/session.rs:1372-1376` (handle_message 溢出重试)
- Modify: `src/session/session.rs:851-872` (/compact 命令)
**Step 1: handle_message 正常流**
`compress_if_needed(history).await?` 之后line 1346改为
```rust
let result = session_guard.compressor
.compress_if_needed(history)
.await?;
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist compressed message marker");
}
}
let mut history = result.history;
```
同时删除后面line 1350-1355单独的 `persist_session_meta` 调用(现在已合入上面的逻辑)。
**Step 2: handle_message 溢出重试流**
```rust
let raw = session_guard.get_history().to_vec();
let result = session_guard.compressor.compress_if_needed(raw).await?;
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
let _ = session_guard.persist_session_meta().await;
}
let mut retry = result.history;
retry.insert(0, ChatMessage::system(system_prompt));
agent.process(retry).await?
```
**Step 3: /compact 命令**
```rust
let result = session_guard.compressor
.compress_if_needed(history)
.await?;
let compressed_count = result.history.len();
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
let _ = session_guard.persist_session_meta().await;
}
session_guard.clear_history();
for msg in result.history {
session_guard.add_message(msg, false).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
}
```
同时确认 `compress_if_needed` 的 import 正常(已在 scope 中)。
**Step 4: 编译检查**
```bash
cargo check 2>&1
```
**Step 5: Commit**
```bash
git add src/session/session.rs
git commit -m "feat(session): record last_compressed_message_at after compression"
```
---
### Task 8: 全局编译和测试
**Step 1: 全局编译**
```bash
cargo check 2>&1
```
修复所有编译错误,确保全部文件一致。
**Step 2: 运行单元测试**
```bash
cargo test --lib 2>&1
```
**Step 3: 测试通过后 commit**
```bash
git add -A
git commit -m "chore: fix remaining compilation and test issues for incremental recovery"
```
**Step 4: 运行 lint**
```bash
cargo clippy --lib 2>&1 | head -50
```
修复任何 warning。
---
### Task 9: 验证 & 提交设计文档
**Step 1: 最终验证**
```bash
cargo test --lib 2>&1
```
**Step 2: Commit 设计文档**
```bash
git add docs/plans/2026-05-10-incremental-session-recovery-design.md
git commit -m "docs: add incremental session recovery design doc"
```

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@ always: true
--- ---
# About PicoBot # About PicoBot
PicoBot 是一个基于 Rust 的个人 AI 助手支持多渠道飞书、CLI、长记忆、定时任务、Skill 系统等 PicoBot 是一个基于 Rust 的个人 AI 助手运行时,包含本地 Gateway、CLI TUI 客户端、飞书渠道、SQLite 会话持久化、长期记忆、定时任务、Skill 系统、MCP 工具接入和子 Agent 委托能力
## 目录索引 ## 目录索引
@ -13,10 +13,10 @@ PicoBot 是一个基于 Rust 的个人 AI 助手支持多渠道飞书、CL
| 文件 | 内容 | | 文件 | 内容 |
|------|------| |------|------|
| `references/config.md` | 配置字段详解providers、models、agents、gateway、memory、channels、mcp | | `references/config.md` | 配置字段详解providers、models、agents、gateway、client、channels、memory、mcp、browser |
| `references/db-schema.md` | 数据库表结构sessions、messages、memories、scheduled_jobs、llm_calls | | `references/db-schema.md` | 数据库表结构sessions、messages、memories、scheduled_jobs、llm_calls、background_tasks |
| `references/architecture.md` | 核心架构数据流、会话系统、上下文压缩、记忆系统、Skill 优先级机制 | | `references/architecture.md` | 核心架构数据流、会话系统、上下文压缩、记忆系统、Skill 优先级、MCP、子 Agent |
| `references/faq.md` | 常见问题模型切换、渠道添加、Skill 安装、历史查询、定时任务等 | | `references/faq.md` | 常见问题模型切换、渠道添加、Skill 安装、历史查询、定时任务、MCP 等 |
| `references/commands.md` | 常用命令:编译、启动网关、启动客户端、运行测试 | | `references/commands.md` | 常用命令:编译、启动网关、启动客户端、运行测试 |
| `assets/config.example.json` | config.json 完整示例 | | `assets/config.example.json` | config.json 完整示例 |

View File

@ -72,5 +72,15 @@
"timeline_retention_days": 90, "timeline_retention_days": 90,
"max_failures_before_degrade": 3 "max_failures_before_degrade": 3
}, },
"mcp": {
"servers": [],
"tool_timeout_secs": 180
},
"browser": {
"enabled": false,
"webdriver_url": "http://127.0.0.1:9515",
"headless": true,
"chrome_path": null
},
"workspace_dir": "~/.picobot/workspace" "workspace_dir": "~/.picobot/workspace"
} }

View File

@ -17,9 +17,9 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
| `channels` | 外部集成飞书、CLI仅收发消息 | | `channels` | 外部集成飞书、CLI仅收发消息 |
| `bus` | 异步消息队列,纯队列不路由 | | `bus` | 异步消息队列,纯队列不路由 |
| `session` | 会话生命周期管理、dialog 操作 | | `session` | 会话生命周期管理、dialog 操作 |
| `agent` | LLM 调用循环、工具执行、上下文压缩 | | `agent` | LLM 调用循环、工具执行、上下文压缩、媒体处理、子 Agent |
| `providers` | LLM API 客户端OpenAI 兼容、Anthropic | | `providers` | LLM API 客户端OpenAI 兼容、Anthropic |
| `tools` | Agent 工具bash、文件操作、HTTP、web、get_skill 等) | | `tools` | Agent 工具bash、文件操作、搜索、HTTP、web、browser、memory、delegate 等) |
| `skills` | Skill 加载、管理和 prompt 构建 | | `skills` | Skill 加载、管理和 prompt 构建 |
| `storage` | SQLite 持久化 | | `storage` | SQLite 持久化 |
| `scheduler` | Cron 作业调度 | | `scheduler` | Cron 作业调度 |
@ -37,6 +37,8 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
- AgentLoop 无状态,接收 dialog 事件调用 LLM、执行工具 - AgentLoop 无状态,接收 dialog 事件调用 LLM、执行工具
- Providers 是纯 HTTP 客户端,无 bus/session/channel 感知 - Providers 是纯 HTTP 客户端,无 bus/session/channel 感知
- Tools 接收原始参数,返回字符串结果 - Tools 接收原始参数,返回字符串结果
- MCP 工具在 Gateway 初始化时连接服务器、发现工具,并包装成普通 Tool 注册到 ToolRegistry
- 子 Agent 由 `delegate` 工具创建,复用 provider 配置和按需过滤后的工具集;后台任务结果通过 MessageBus 发回原会话
## 关键约束 ## 关键约束
@ -45,6 +47,7 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
- ChannelManager 持有 MessageBus 和所有 channel - ChannelManager 持有 MessageBus 和所有 channel
- OutboundDispatcher 通过 ChannelManager 路由出站消息 - OutboundDispatcher 通过 ChannelManager 路由出站消息
- Config `.env` 加载使用 `unsafe { env::set_var(...) }` - Config `.env` 加载使用 `unsafe { env::set_var(...) }`
- `browser` 工具只有在 `browser.enabled=true` 时注册,依赖 Chrome/Chromium 与 WebDriver
## 上下文压缩 ## 上下文压缩
@ -192,3 +195,48 @@ LLM 对话上下文接近 token 限制 (默认 128K × 70%) 时自动触发压
| 有压缩历史时 | `HistorySection` 提示 LLM 使用 `timeline_recall` | | 有压缩历史时 | `HistorySection` 提示 LLM 使用 `timeline_recall` |
| 压缩完成后 | 摘要自动存储为 Timeline 记忆 | | 压缩完成后 | 摘要自动存储为 Timeline 记忆 |
| 空闲时 | 可配置自动 consolidation`idle_consolidation_minutes` | | 空闲时 | 可配置自动 consolidation`idle_consolidation_minutes` |
---
## MCP 工具集成
Gateway 初始化时读取 `config.mcp.servers`
1. 按服务器配置连接 `stdio``sse``streamable-http` 传输
2. 调用 MCP `list_tools`
3. 将每个 MCP tool 包装为 `McpToolWrapper`
4. 注册到当前 session 的 `ToolRegistry`
`/mcp` 斜杠命令会显示 MCP 服务器连接状态和工具列表。
---
## 子 Agent / delegate
`delegate` 工具用于把独立任务交给子 Agent
| 模式 | 行为 |
|------|------|
| `inline` | 当前轮阻塞等待子 Agent 返回 |
| `background` | 后台运行,完成后通过原 channel/chat 通知 |
| `parallel` | 多个子 Agent 并发执行并聚合结果 |
默认工具集是只读工具:`file_read``file_search``content_search``web_fetch``http_request``calculator`。调用时可通过 `allowed_tools` 显式放开其他工具。后台任务会写入 `background_tasks` 表,默认 24 小时后清理。
---
## 当前斜杠命令
| 命令 | 说明 |
|------|------|
| `/new` | 创建新对话 |
| `/sessions` | 列出最近对话 |
| `/switch <dialog_id>` | 切换到指定对话 |
| `/rename <title>` | 重命名当前对话 |
| `/delete` | 删除当前对话 |
| `/compact` | 手动触发上下文压缩 |
| `/info` | 显示当前对话信息 |
| `/dump` | 保存当前对话为 markdown |
| `/?`, `/help` | 显示帮助 |
| `/mcp` | 显示 MCP 状态 |
| `/stop` | 停止当前任务并清空消息队列 |

View File

@ -14,8 +14,9 @@
"client": {}, // 客户端配置 "client": {}, // 客户端配置
"channels": {}, // 渠道配置 "channels": {}, // 渠道配置
"memory": {}, // 记忆系统配置 "memory": {}, // 记忆系统配置
"workspace_dir": // 工作目录,默认 ~/.picobot/workspace "workspace_dir": "", // 工作目录,默认 ~/.picobot/workspace
"mcp": {} // MCP 服务器配置 "mcp": {}, // MCP 服务器配置
"browser": {} // 可选浏览器自动化配置
} }
``` ```
@ -57,8 +58,17 @@
| `session_ttl_hours` | int | - | 会话过期小时数 | | `session_ttl_hours` | int | - | 会话过期小时数 |
| `session_db_path` | string | - | SQLite 数据库路径,默认在 workspace 下 | | `session_db_path` | string | - | SQLite 数据库路径,默认在 workspace 下 |
| `cleanup_interval_minutes` | int | - | 清理间隔 | | `cleanup_interval_minutes` | int | - | 清理间隔 |
| `max_concurrent_background_tasks` | int | 10 | delegate 后台子任务最大并发数 |
| `scheduler` | object | - | 调度器配置 | | `scheduler` | object | - | 调度器配置 |
### gateway.scheduler 字段
| 字段 | 类型 | 默认 | 说明 |
|------|------|------|------|
| `enabled` | bool | true | 是否启动调度器并注册 cron 工具 |
| `poll_interval_secs` | int | 60 | 检查到期任务的轮询间隔 |
| `max_concurrent` | int | 1 | 最大并发任务数,当前实现预留 |
## memory 字段 ## memory 字段
| 字段 | 类型 | 默认 | 说明 | | 字段 | 类型 | 默认 | 说明 |
@ -94,8 +104,21 @@ MCP 服务器单条配置:
| 字段 | 说明 | | 字段 | 说明 |
|------|------| |------|------|
| `name` | 服务器名称 | | `name` | 服务器名称 |
| `transport` | 传输方式: `Stdio`、`Sse`、`streamable-http` | | `transport` | 传输方式: `stdio`、`sse`、`streamable-http` |
| `command` | 启动命令(Stdio 模式) | | `command` | 启动命令(stdio 模式) |
| `args` | 命令参数 | | `args` | 命令参数 |
| `url` | URLSse / streamable-http 模式) | | `env` | 子进程环境变量 |
| `url` | URLsse / streamable-http 模式) |
| `headers` | HTTP 传输额外请求头 |
| `tool_timeout_secs` | 单独的超时设置 | | `tool_timeout_secs` | 单独的超时设置 |
## browser 字段
浏览器工具默认关闭,开启后注册 `browser` 工具。依赖 Chrome/Chromium 与 chromedriver/WebDriver。
| 字段 | 类型 | 默认 | 说明 |
|------|------|------|------|
| `enabled` | bool | false | 是否启用浏览器工具 |
| `webdriver_url` | string | http://127.0.0.1:9515 | WebDriver 服务地址 |
| `headless` | bool | true | 是否无头运行 |
| `chrome_path` | string | - | 自定义 Chrome/Chromium 路径 |

View File

@ -36,6 +36,28 @@
| `tool_calls` | TEXT | 工具调用参数 JSON | | `tool_calls` | TEXT | 工具调用参数 JSON |
| `source` | TEXT | 消息来源(跨会话消息时标记来源 session_id | | `source` | TEXT | 消息来源(跨会话消息时标记来源 session_id |
| `created_at` | INTEGER | 创建时间unix 秒) | | `created_at` | INTEGER | 创建时间unix 秒) |
| `reasoning_content` | TEXT | provider 返回的推理内容(如有) |
## background_tasks 表
delegate 后台子任务表。`session_id` 不使用数据库外键,因为 session 使用软删除,关联关系由应用层维护。
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | TEXT PK | 后台任务 ID |
| `session_id` | TEXT | 所属会话 |
| `channel` | TEXT | 回传渠道 |
| `chat_id` | TEXT | 回传目标对话 |
| `prompt` | TEXT | 子任务提示 |
| `allowed_tools` | TEXT | 允许工具 JSON |
| `status` | TEXT | pending / running / completed / failed / cancelled |
| `result` | TEXT | 执行结果 |
| `error` | TEXT | 错误信息 |
| `tool_calls_count` | INTEGER | 工具调用次数 |
| `iterations` | INTEGER | Agent 迭代次数 |
| `started_at` | INTEGER | 开始时间 |
| `finished_at` | INTEGER | 结束时间 |
| `created_at` | INTEGER | 创建时间 |
## memories 表 ## memories 表

View File

@ -124,9 +124,51 @@
--- ---
## file_read / file_write / file_edit / file_search — 文件操作 ## delegate — 子 Agent 委托
工作目录内的文件读写编辑和搜索。详细的参数定义见各工具的 parameters_schema。 创建子 Agent 处理独立任务。
| 参数 | 必填 | 说明 |
|------|------|------|
| `action` | 是 | `run`, `check_task`, `cancel_task`, `list_tasks` |
| `prompt` | run 必填 | 子任务描述 |
| `mode` | 否 | `inline`, `background`, `parallel`,默认 `inline` |
| `allowed_tools` | 否 | 子 Agent 可用工具列表;默认只读工具集 |
| `max_iterations` | 否 | 最大迭代次数,默认 99 |
| `timeout_secs` | 否 | 超时秒数,默认 3600 |
| `tasks` | parallel 必填 | 并行子任务数组 |
| `task_id` | 查询/取消必填 | 后台任务 ID |
默认只读工具集:`file_read``file_search``content_search``web_fetch``http_request``calculator`
---
## browser — 浏览器自动化
仅在 `browser.enabled=true` 时注册。底层使用 WebDriver/Chrome。
| action | 说明 |
|--------|------|
| `open` | 打开 URL |
| `snapshot` | 获取页面结构快照 |
| `click`, `click_at` | 点击元素或坐标 |
| `fill`, `type`, `press` | 输入文本或按键 |
| `get_text`, `get_title`, `get_url` | 读取页面信息 |
| `screenshot` | 截图,可写入文件或返回 base64 |
| `focus`, `hover`, `scroll`, `wait` | 常见交互和等待 |
| `close` | 关闭浏览器会话 |
---
## MCP 工具
如果 `config.mcp.servers` 配置了 MCP 服务器Gateway 启动时会连接服务器、发现工具,并把 MCP 工具包装后注册到 ToolRegistry。使用 `/mcp` 查看当前连接状态和工具列表。
---
## file_read / file_write / file_edit / file_search / content_search — 文件操作和搜索
工作目录内的文件读写编辑、文件名搜索和内容搜索。详细的参数定义见各工具的 parameters_schema。
## bash — 执行命令 ## bash — 执行命令

View File

@ -72,5 +72,15 @@
"timeline_retention_days": 90, "timeline_retention_days": 90,
"max_failures_before_degrade": 3 "max_failures_before_degrade": 3
}, },
"mcp": {
"servers": [],
"tool_timeout_secs": 180
},
"browser": {
"enabled": false,
"webdriver_url": "http://127.0.0.1:9515",
"headless": true,
"chrome_path": null
},
"workspace_dir": "~/.picobot/workspace" "workspace_dir": "~/.picobot/workspace"
} }

View File

@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use crate::bus::{ChatMessage, MediaRef}; use crate::bus::{ChatMessage, MediaRef};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::observability::{ use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
};
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
@ -228,6 +226,7 @@ pub struct AgentLoop {
pub struct AgentProcessResult { pub struct AgentProcessResult {
pub final_response: ChatMessage, pub final_response: ChatMessage,
pub emitted_messages: Vec<ChatMessage>, pub emitted_messages: Vec<ChatMessage>,
pub total_tokens: Option<u32>,
} }
impl AgentLoop { impl AgentLoop {
@ -255,7 +254,10 @@ impl AgentLoop {
} }
/// Create a new AgentLoop with provider created from config and given tools. /// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> { pub fn with_tools(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone(); let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone(); let workspace_dir = provider_config.workspace_dir.clone();
@ -278,7 +280,13 @@ impl AgentLoop {
} }
/// Create a new AgentLoop with an existing shared provider. /// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self { pub fn with_provider(
provider: Arc<dyn LLMProvider>,
max_iterations: usize,
model_name: String,
workspace_dir: PathBuf,
input_types: Vec<String>,
) -> Self {
Self { Self {
provider, provider,
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
@ -340,8 +348,9 @@ impl AgentLoop {
} }
/// Preemptive trim: truncate old tool results in-place when history is /// Preemptive trim: truncate old tool results in-place when history is
/// approaching the context window limit. Only trims tool messages with /// approaching the context window limit. Old results (outside of `keep_recent`
/// content > TRIM_CHARS, preserving the most recent KEEP messages. /// zone) are replaced with a short placeholder; recent results are truncated
/// to `max_chars`.
fn preemptive_trim_old_tool_results( fn preemptive_trim_old_tool_results(
&self, &self,
messages: &mut [ChatMessage], messages: &mut [ChatMessage],
@ -358,11 +367,11 @@ impl AgentLoop {
if messages[i].content.len() <= max_chars { if messages[i].content.len() <= max_chars {
continue; continue;
} }
let removed = messages[i].content.len() - max_chars; let tool_name = messages[i].tool_name.as_deref().unwrap_or("unknown");
let chars = messages[i].content.len();
messages[i].content = format!( messages[i].content = format!(
"{}...\n\n[Output truncated - {} characters removed]", "[Tool output ({}) — {} chars, omitted from context]",
&messages[i].content[..messages[i].content.ceil_char_boundary(max_chars)], tool_name, chars
removed
); );
modified += 1; modified += 1;
} }
@ -377,7 +386,12 @@ impl AgentLoop {
let content = if m.media_refs.is_empty() { let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)] vec![ContentBlock::text(&m.content)]
} else { } else {
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry) build_content_blocks(
&m.content,
&m.media_refs,
&self.input_types,
&self.media_registry,
)
}; };
Message { Message {
@ -397,14 +411,28 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either: /// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response) /// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached /// - Maximum iterations are reached
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> { pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Build and inject system prompt if not present // Build and inject system prompt if not present
let has_system = messages.first().is_some_and(|m| m.role == "system"); let has_system = messages.first().is_some_and(|m| m.role == "system");
if !has_system { if !has_system {
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false); let system_prompt = build_system_prompt(
&self.workspace_dir,
&self.model_name,
&self.tools,
None,
None,
false,
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt); tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt)); messages.insert(0, ChatMessage::system(system_prompt));
@ -413,6 +441,7 @@ impl AgentLoop {
// Track tool calls for loop detection // Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new(); let mut emitted_messages = Vec::new();
let mut accumulated_tokens: u32 = 0;
for iteration in 0..self.max_iterations { for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -424,9 +453,7 @@ impl AgentLoop {
let estimated = estimate_tokens(&messages); let estimated = estimate_tokens(&messages);
let danger = (self.context_window as f64 * 0.8) as usize; let danger = (self.context_window as f64 * 0.8) as usize;
if estimated > danger { if estimated > danger {
let trimmed = self.preemptive_trim_old_tool_results( let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
&mut messages, 2000, 4,
);
if trimmed > 0 { if trimmed > 0 {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
@ -460,11 +487,12 @@ impl AgentLoop {
}; };
// Call LLM // Call LLM
let response = (*self.provider).chat(request).await let response = (*self.provider).chat(request).await.map_err(|e| {
.map_err(|e| { tracing::error!(error = %e, "LLM request failed");
tracing::error!(error = %e, "LLM request failed"); AgentError::LlmError(e.to_string())
AgentError::LlmError(e.to_string()) })?;
})?;
accumulated_tokens += response.usage.total_tokens;
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
@ -482,12 +510,15 @@ impl AgentLoop {
return Ok(AgentProcessResult { return Ok(AgentProcessResult {
final_response: assistant_message, final_response: assistant_message,
emitted_messages, emitted_messages,
total_tokens: Some(accumulated_tokens),
}); });
} }
// Execute tool calls — log and notify immediately // Execute tool calls — log and notify immediately
{ {
let tools_info: Vec<String> = response.tool_calls.iter() let tools_info: Vec<String> = response
.tool_calls
.iter()
.map(|tc| { .map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
let s = format!("{}:{}", tc.name, args); let s = format!("{}:{}", tc.name, args);
@ -516,7 +547,9 @@ impl AgentLoop {
// Log function call with name and arguments // Log function call with name and arguments
let args_str = match &tool_call.arguments { let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
}; };
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
@ -556,7 +589,11 @@ impl AgentLoop {
// Loop continues to next iteration with updated messages // Loop continues to next iteration with updated messages
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
} }
// Max iterations reached - ask LLM for a summary based on completed work // Max iterations reached - ask LLM for a summary based on completed work
@ -565,7 +602,7 @@ impl AgentLoop {
// Add a message asking for summary // Add a message asking for summary
let summary_request = ChatMessage::user( let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \ "You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far." Please provide your best answer based on the work completed so far.",
); );
messages.push(summary_request); messages.push(summary_request);
@ -584,24 +621,32 @@ impl AgentLoop {
match (*self.provider).chat(request).await { match (*self.provider).chat(request).await {
Ok(response) => { Ok(response) => {
accumulated_tokens += response.usage.total_tokens;
let mut assistant_message = ChatMessage::assistant(response.content); let mut assistant_message = ChatMessage::assistant(response.content);
assistant_message.reasoning_content = response.reasoning_content; assistant_message.reasoning_content = response.reasoning_content;
emitted_messages.push(assistant_message.clone()); emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult { Ok(AgentProcessResult {
final_response: assistant_message, final_response: assistant_message,
emitted_messages, emitted_messages,
total_tokens: Some(accumulated_tokens),
}) })
} }
Err(e) => { Err(e) => {
// Fallback if summary call fails // Fallback if summary call fails
tracing::error!(error = %e, "Failed to get summary from LLM"); tracing::error!(error = %e, "Failed to get summary from LLM");
let final_message = ChatMessage::assistant( let final_message = ChatMessage::assistant(format!(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) "I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
); self.max_iterations
));
emitted_messages.push(final_message.clone()); emitted_messages.push(final_message.clone());
Ok(AgentProcessResult { Ok(AgentProcessResult {
final_response: final_message, final_response: final_message,
emitted_messages, emitted_messages,
total_tokens: if accumulated_tokens > 0 {
Some(accumulated_tokens)
} else {
None
},
}) })
} }
} }
@ -689,10 +734,7 @@ impl AgentLoop {
} }
// Apply duration // Apply duration
ToolExecutionOutcome { ToolExecutionOutcome { duration, ..result }
duration,
..result
}
} }
/// Internal tool execution without event tracking. /// Internal tool execution without event tracking.
@ -714,18 +756,12 @@ impl AgentLoop {
ToolExecutionOutcome::success(result.output) ToolExecutionOutcome::success(result.output)
} else { } else {
let error = result.error.unwrap_or_default(); let error = result.error.unwrap_or_default();
ToolExecutionOutcome::failure( ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
format!("Error: {}", error),
Some(error),
)
} }
} }
Err(e) => { Err(e) => {
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed"); tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
ToolExecutionOutcome::failure( ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
format!("Error: {}", e),
Some(e.to_string()),
)
} }
} }
} }
@ -813,8 +849,14 @@ mod tests {
assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); assert_eq!(
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); provider_message.tool_calls.as_ref().unwrap()[0].id,
"call_1"
);
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].name,
"calculator"
);
} }
} }

View File

@ -68,6 +68,10 @@ pub struct ContextCompressor {
memory: Arc<MemoryManager>, memory: Arc<MemoryManager>,
/// Current session ID for timeline memory writes. /// Current session ID for timeline memory writes.
session_id: Option<String>, session_id: Option<String>,
/// Message count sent in the last LLM call (used to split known/new history).
last_sent_message_count: Option<usize>,
/// Real total_tokens from the last API response.
last_api_total_tokens: Option<u32>,
} }
/// Result of context compression. /// Result of context compression.
@ -76,6 +80,15 @@ pub struct CompressionResult {
pub created_timelines: bool, pub created_timelines: bool,
} }
/// Token budget state snapshot for diagnostics.
pub struct TokenInfo {
pub context_window: usize,
pub threshold: usize,
pub estimated_tokens: usize,
pub last_api_tokens: Option<u32>,
pub cache_active: bool,
}
impl ContextCompressor { impl ContextCompressor {
/// Create a new compressor with the given provider, context window size, and memory manager. /// Create a new compressor with the given provider, context window size, and memory manager.
pub fn new( pub fn new(
@ -90,6 +103,8 @@ impl ContextCompressor {
provider, provider,
memory, memory,
session_id: None, session_id: None,
last_sent_message_count: None,
last_api_total_tokens: None,
} }
} }
@ -107,6 +122,8 @@ impl ContextCompressor {
provider, provider,
memory, memory,
session_id: None, session_id: None,
last_sent_message_count: None,
last_api_total_tokens: None,
} }
} }
@ -120,39 +137,91 @@ impl ContextCompressor {
self.context_window = window; self.context_window = window;
} }
/// Record the API's reported token usage from the last completed turn.
/// `msg_count`: number of messages sent to LLM in that call.
/// `tokens`: `total_tokens` from the API response.
pub fn set_last_api_info(&mut self, msg_count: usize, tokens: Option<u32>) {
self.last_sent_message_count = Some(msg_count);
self.last_api_total_tokens = tokens;
}
/// Invalidate the cached API token info — called after compression modifies messages.
fn invalidate_token_cache(&mut self) {
self.last_sent_message_count = None;
self.last_api_total_tokens = None;
}
/// Hybrid token estimation: API-reported tokens for known history +
/// char/4 estimate for new messages since last API call.
fn token_estimate_with_history(&self, messages: &[ChatMessage]) -> usize {
match (self.last_api_total_tokens, self.last_sent_message_count) {
(Some(known), Some(known_count)) if messages.len() > known_count => {
let delta = &messages[known_count..];
known as usize + estimate_tokens(delta)
}
(Some(known), _) => known as usize,
_ => estimate_tokens(messages),
}
}
/// Always true — memory is always available (memory system is always on). /// Always true — memory is always available (memory system is always on).
pub fn has_memory(&self) -> bool { pub fn has_memory(&self) -> bool {
true true
} }
/// Get a snapshot of the current token budget state for diagnostics.
pub fn token_info(&self, messages: &[ChatMessage]) -> TokenInfo {
TokenInfo {
context_window: self.context_window,
threshold: self.threshold(),
estimated_tokens: self.token_estimate_with_history(messages),
last_api_tokens: self.last_api_total_tokens,
cache_active: self.last_api_total_tokens.is_some(),
}
}
/// Get the compression threshold in tokens. /// Get the compression threshold in tokens.
pub fn threshold(&self) -> usize { pub fn threshold(&self) -> usize {
(self.context_window as f64 * self.threshold_ratio) as usize (self.context_window as f64 * self.threshold_ratio) as usize
} }
/// Fast-path: trim oversized tool results without LLM call. /// Fast-path: trim oversized tool results without LLM call.
/// Old tool results (outside of `protect_tail` zone) are replaced with a
/// concise placeholder; recent results are truncated to `tool_result_trim_chars`.
/// Returns the number of messages modified. /// Returns the number of messages modified.
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize { fn fast_trim_tool_results(&self, messages: &mut [ChatMessage], protect_tail: usize) -> usize {
let limit = self.config.tool_result_trim_chars; let limit = self.config.tool_result_trim_chars;
let tail_start = messages.len().saturating_sub(protect_tail);
let mut modified = 0; let mut modified = 0;
for msg in messages.iter_mut() { for (i, msg) in messages.iter_mut().enumerate() {
if msg.role == "tool" && msg.content.len() > limit { if msg.role != "tool" || msg.content.len() <= limit {
continue;
}
if i < tail_start {
let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
let chars = msg.content.len();
msg.content = format!(
"[Tool output ({}) — {} chars, omitted from context]",
tool_name, chars
);
} else {
let removed = msg.content.len() - limit; let removed = msg.content.len() - limit;
msg.content = format!( msg.content = format!(
"{}...\n\n[Output truncated - {} characters removed]", "{}...\n\n[Output truncated - {} characters removed]",
&msg.content[..msg.content.ceil_char_boundary(limit)], &msg.content[..msg.content.ceil_char_boundary(limit)],
removed removed
); );
modified += 1;
} }
modified += 1;
} }
modified modified
} }
/// Remove orphan tool results whose declaring tool_calls have been compressed away. /// Repair tool call chains after compression.
/// Scans for tool messages with no preceding assistant tool_call, and removes them. /// Phase 1: remove orphan tool results whose declaring tool_calls are missing.
/// Phase 2: strip tool_calls from assistants whose results are missing.
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) { pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new(); let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut i = 0; let mut i = 0;
@ -165,23 +234,58 @@ impl ContextCompressor {
} }
} else if messages[i].role == "tool" } else if messages[i].role == "tool"
&& let Some(ref tid) = messages[i].tool_call_id && let Some(ref tid) = messages[i].tool_call_id
&& !declared.contains(tid.as_str()) { && !declared.contains(tid.as_str())
messages.remove(i); {
continue; messages.remove(i);
} continue;
}
i += 1; i += 1;
} }
let broken: Vec<usize> = messages
.iter()
.enumerate()
.filter_map(|(idx, msg)| {
if msg.role == "assistant"
&& let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty()
{
let all_present = tcs.iter().all(|tc| {
messages.iter().any(|m| {
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
})
});
if !all_present { Some(idx) } else { None }
} else {
None
}
})
.collect();
for idx in broken {
let msg = &mut messages[idx];
let tcs = msg.tool_calls.take().unwrap_or_default();
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results are no longer available]",
msg.content,
names.join(", ")
);
}
} }
/// Main entry point - compresses history if over threshold. /// Main entry point - compresses history if over threshold.
pub async fn compress_if_needed( pub async fn compress_if_needed(
&self, &mut self,
mut history: Vec<ChatMessage>, mut history: Vec<ChatMessage>,
) -> Result<CompressionResult, AgentError> { ) -> Result<CompressionResult, AgentError> {
// Check if compression is needed // Check if compression is needed
let tokens = estimate_tokens(&history); let tokens = self.token_estimate_with_history(&history);
if tokens <= self.threshold() { if tokens <= self.threshold() {
return Ok(CompressionResult { history, created_timelines: false }); return Ok(CompressionResult {
history,
created_timelines: false,
});
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -193,8 +297,8 @@ impl ContextCompressor {
); );
// Fast trim pass first — modify history in place // Fast trim pass first — modify history in place
let trimmed = self.fast_trim_tool_results(&mut history); let trimmed = self.fast_trim_tool_results(&mut history, self.config.protect_last_n);
let tokens_after = estimate_tokens(&history); let tokens_after = self.token_estimate_with_history(&history);
if trimmed > 0 { if trimmed > 0 {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
@ -204,24 +308,24 @@ impl ContextCompressor {
); );
} }
if tokens_after <= self.threshold() { if tokens_after <= self.threshold() {
return Ok(CompressionResult { history, created_timelines: false }); self.invalidate_token_cache();
return Ok(CompressionResult {
history,
created_timelines: false,
});
} }
// LLM summarization pass // LLM summarization pass
let mut current_history = history; let mut current_history = history;
let mut created_timelines = false; let mut created_timelines = false;
for pass in 0..self.config.max_passes { for pass in 0..self.config.max_passes {
let tokens = estimate_tokens(&current_history); let tokens = self.token_estimate_with_history(&current_history);
if tokens <= self.threshold() { if tokens <= self.threshold() {
break; break;
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
pass = pass + 1,
tokens = tokens,
"Compression pass"
);
match self.compress_once(&current_history).await { match self.compress_once(&current_history).await {
Ok(Some(compressed)) => { Ok(Some(compressed)) => {
@ -241,15 +345,52 @@ impl ContextCompressor {
// Hard safety net: if still dangerously high after all passes, // Hard safety net: if still dangerously high after all passes,
// fall back to head+tail truncation so the LLM call doesn't overflow. // fall back to head+tail truncation so the LLM call doesn't overflow.
let final_tokens = estimate_tokens(&current_history); let final_tokens = self.token_estimate_with_history(&current_history);
let danger_threshold = (self.context_window as f64 * 0.9) as usize; let danger_threshold = (self.context_window as f64 * 0.9) as usize;
if final_tokens > danger_threshold if final_tokens > danger_threshold
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n && current_history.len() > self.config.protect_first_n + self.config.protect_last_n
{ {
let mut tail_start = current_history.len() - self.config.protect_last_n;
// Align tail_start backwards to preserve tool chain boundaries:
// if an assistant with tool_calls has results spanning the cut,
// include the assistant in the tail.
if tail_start > 0 && tail_start < current_history.len() {
let mut scan = tail_start.saturating_sub(1);
loop {
let m = &current_history[scan];
if m.role == "assistant" {
if let Some(tcs) = &m.tool_calls
&& !tcs.is_empty()
{
let has_post = current_history[scan + 1..]
.iter()
.filter(|r| r.role == "tool")
.any(|r| {
tcs.iter()
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
});
if has_post {
tail_start = scan;
}
}
break;
}
if scan == 0 {
break;
}
scan -= 1;
}
}
// Skip orphan tool messages at the new head-tail boundary
while tail_start < current_history.len() && current_history[tail_start].role == "tool" {
tail_start += 1;
}
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec(); let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
let tail_start = current_history.len() - self.config.protect_last_n;
let tail: Vec<_> = current_history[tail_start..].to_vec(); let tail: Vec<_> = current_history[tail_start..].to_vec();
let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n; let dropped = current_history.len() - self.config.protect_first_n - tail.len();
let mut truncated = head; let mut truncated = head;
truncated.push(ChatMessage::user(format!( truncated.push(ChatMessage::user(format!(
@ -259,6 +400,26 @@ impl ContextCompressor {
))); )));
truncated.extend(tail); truncated.extend(tail);
// Strip tool_calls from any assistant in the head whose results
// were dropped (previously in the middle section).
for msg in &mut truncated[..self.config.protect_first_n] {
if msg.role == "assistant" {
if let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty()
{
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
msg.content,
names.join(", ")
);
msg.tool_calls = None;
}
}
}
Self::repair_tool_pairs(&mut truncated);
tracing::warn!( tracing::warn!(
final_tokens = final_tokens, final_tokens = final_tokens,
danger = danger_threshold, danger = danger_threshold,
@ -269,14 +430,21 @@ impl ContextCompressor {
current_history = truncated; current_history = truncated;
} }
if created_timelines {
self.invalidate_token_cache();
}
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
final_tokens = estimate_tokens(&current_history), final_tokens = self.token_estimate_with_history(&current_history),
final_msg_count = current_history.len(), final_msg_count = current_history.len(),
"Context compression completed" "Context compression completed"
); );
Ok(CompressionResult { history: current_history, created_timelines }) Ok(CompressionResult {
history: current_history,
created_timelines,
})
} }
/// Try to extract the actual context token limit from an LLM error message. /// Try to extract the actual context token limit from an LLM error message.
@ -299,20 +467,21 @@ impl ContextCompressor {
// Look for a number in the vicinity (up to 10 chars after marker) // Look for a number in the vicinity (up to 10 chars after marker)
if let Some(num_str) = find_number_nearby(after, 50) if let Some(num_str) = find_number_nearby(after, 50)
&& let Ok(n) = num_str.parse::<usize>() && let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n) { && (1024..=10_000_000).contains(&n)
return Some(n); {
} return Some(n);
}
} }
} }
// Also try: "XXXX token context" or "XXXX limit" // Also try: "XXXX token context" or "XXXX limit"
if let Some(num_str) = find_number_nearby(&lower, lower.len()) if let Some(num_str) = find_number_nearby(&lower, lower.len())
&& let Ok(n) = num_str.parse::<usize>() && let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n) && (1024..=10_000_000).contains(&n)
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit")) && (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
{ {
return Some(n); return Some(n);
} }
None None
} }
@ -361,19 +530,26 @@ impl ContextCompressor {
// Persist compressed summary as timeline memory entry // Persist compressed summary as timeline memory entry
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", let timeline_content = format!(
ts, between.len(), summary); "[{}] Compressed {} conversation segments:\n{}",
ts,
between.len(),
summary
);
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
let mm = self.memory.clone(); let mm = self.memory.clone();
let sid = self.session_id.clone(); let sid = self.session_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = mm.store( if let Err(e) = mm
&key, .store(
&timeline_content, &key,
crate::memory::MemoryCategory::Timeline, &timeline_content,
sid.as_deref(), crate::memory::MemoryCategory::Timeline,
Some(0.3), sid.as_deref(),
).await { Some(0.3),
)
.await
{
tracing::warn!(error = %e, "Failed to store compressed context as timeline"); tracing::warn!(error = %e, "Failed to store compressed context as timeline");
} }
}); });
@ -404,10 +580,7 @@ impl ContextCompressor {
} }
/// Summarize a segment of messages using LLM. /// Summarize a segment of messages using LLM.
async fn summarize_segment( async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
&self,
messages: &[ChatMessage],
) -> Result<String, AgentError> {
if messages.is_empty() { if messages.is_empty() {
return Ok(String::new()); return Ok(String::new());
} }
@ -421,7 +594,8 @@ impl ContextCompressor {
"tool" => "Tool", "tool" => "Tool",
_ => m.role.as_str(), _ => m.role.as_str(),
}; };
let name = m.tool_name let name = m
.tool_name
.as_ref() .as_ref()
.map(|n| format!(" ({})", n)) .map(|n| format!(" ({})", n))
.unwrap_or_default(); .unwrap_or_default();
@ -466,7 +640,10 @@ Be concise, aim for {} characters or less.
); );
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], messages: vec![
Message::system("You are a helpful assistant."),
Message::user(&prompt),
],
temperature: Some(0.3), temperature: Some(0.3),
max_tokens: Some(1000), max_tokens: Some(1000),
tools: None, tools: None,
@ -538,13 +715,23 @@ mod tests {
content: "[summarized]".into(), content: "[summarized]".into(),
reasoning_content: None, reasoning_content: None,
tool_calls: vec![], tool_calls: vec![],
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
}) })
} }
fn ptype(&self) -> &str { "mock" } fn ptype(&self) -> &str {
fn name(&self) -> &str { "mock" } "mock"
fn model_id(&self) -> &str { "mock" } }
fn name(&self) -> &str {
"mock"
}
fn model_id(&self) -> &str {
"mock"
}
} }
fn mock_summarizer() -> Arc<dyn LLMProvider> { fn mock_summarizer() -> Arc<dyn LLMProvider> {
@ -556,11 +743,13 @@ mod tests {
MM.get_or_init(|| { MM.get_or_init(|| {
let rt = tokio::runtime::Runtime::new().unwrap(); let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id())); let tmp = std::env::temp_dir()
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
Arc::new(MemoryManager::new(storage, "test".into(), "test".into())) Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
}) })
}).clone() })
.clone()
} }
#[test] #[test]
@ -576,7 +765,11 @@ mod tests {
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 // "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7 // "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
// raw = 19, with 1.2x = ~23 // raw = 19, with 1.2x = ~23
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens); assert!(
tokens > 18 && tokens < 30,
"Expected ~23 tokens, got {}",
tokens
);
} }
#[test] #[test]
@ -585,14 +778,15 @@ mod tests {
tool_result_trim_chars: 50, tool_result_trim_chars: 50,
..Default::default() ..Default::default()
}; };
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager()); let compressor =
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let mut messages = vec![ let mut messages = vec![
ChatMessage::user("Hello"), ChatMessage::user("Hello"),
ChatMessage::tool("call1", "bash", &"x".repeat(200)), ChatMessage::tool("call1", "bash", &"x".repeat(200)),
]; ];
let modified = compressor.fast_trim_tool_results(&mut messages); let modified = compressor.fast_trim_tool_results(&mut messages, 2);
assert_eq!(modified, 1); assert_eq!(modified, 1);
assert!(messages[1].content.len() < 100); assert!(messages[1].content.len() < 100);
} }
@ -619,14 +813,18 @@ mod tests {
max_passes: 0, max_passes: 0,
..Default::default() ..Default::default()
}; };
let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm); let mut compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
let messages = vec![ let messages = vec![
ChatMessage::user("Hi"), ChatMessage::user("Hi"),
ChatMessage::tool("call1", "bash", &"x".repeat(3000)), ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
]; ];
let result = compressor.compress_if_needed(messages).await.unwrap().history; let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap(); let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
assert!( assert!(
@ -650,18 +848,19 @@ mod tests {
// - B2B (L275): last user message lost when it is the final history message // - B2B (L275): last user message lost when it is the final history message
// //
// context_window=200 → threshold=100. Large tool outputs force LLM summarization. // context_window=200 → threshold=100. Large tool outputs force LLM summarization.
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id())); let tmp =
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into())); let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig { let config = ContextCompressionConfig {
tool_result_trim_chars: 2000, tool_result_trim_chars: 2000,
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
protect_last_n: 2, protect_last_n: 2,
max_passes: 1, max_passes: 1,
..Default::default() ..Default::default()
}; };
let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm); let mut compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
// History: 9 messages, last message is user Q4. // History: 9 messages, last message is user Q4.
// user_indices (skip 1) = [1, 3, 6, 8] // user_indices (skip 1) = [1, 3, 6, 8]
@ -670,25 +869,43 @@ mod tests {
let big = "x".repeat(3000); let big = "x".repeat(3000);
let messages = vec![ let messages = vec![
ChatMessage::system("You are a helper."), // 0: protected ChatMessage::system("You are a helper."), // 0: protected
ChatMessage::user("Q1"), // 1: first user ChatMessage::user("Q1"), // 1: first user
ChatMessage::tool("t1", "bash", &big), // 2 ChatMessage::tool("t1", "bash", &big), // 2
ChatMessage::user("Q2"), // 3 ChatMessage::user("Q2"), // 3
ChatMessage::assistant("thinking"), // 4 ChatMessage::assistant("thinking"), // 4
ChatMessage::tool("t2", "bash", &big), // 5 ChatMessage::tool("t2", "bash", &big), // 5
ChatMessage::user("Q3"), // 6 ChatMessage::user("Q3"), // 6
ChatMessage::assistant("thinking"), // 7 ChatMessage::assistant("thinking"), // 7
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
]; ];
let result = compressor.compress_if_needed(messages).await.unwrap().history; let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
// B2A: "Q1" must appear exactly once // B2A: "Q1" must appear exactly once
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count(); let q1_count = result
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count); .iter()
.filter(|m| m.role == "user" && m.content == "Q1")
.count();
assert_eq!(
q1_count, 1,
"Q1 should appear exactly once, got {}",
q1_count
);
// B2B: "Q4" must NOT be lost // B2B: "Q4" must NOT be lost
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count(); let q4_count = result
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count); .iter()
.filter(|m| m.role == "user" && m.content == "Q4")
.count();
assert_eq!(
q4_count, 1,
"Q4 should appear exactly once (not lost), got {}",
q4_count
);
let _ = std::fs::remove_file(&tmp); let _ = std::fs::remove_file(&tmp);
} }
@ -702,16 +919,16 @@ mod tests {
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into())); let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig { let config = ContextCompressionConfig {
tool_result_trim_chars: 500, // trim reduces but not enough tool_result_trim_chars: 500, // trim reduces but not enough
protect_first_n: 1, protect_first_n: 1,
protect_last_n: 2, protect_last_n: 2,
max_passes: 0, // no LLM summarization → will exceed danger max_passes: 0, // no LLM summarization → will exceed danger
..Default::default() ..Default::default()
}; };
// context_window=100, danger_threshold=90. // context_window=100, danger_threshold=90.
// Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387. // Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
// Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90. // Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm); let mut compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
let big = "x".repeat(3000); let big = "x".repeat(3000);
let messages = vec![ let messages = vec![
@ -724,13 +941,23 @@ mod tests {
ChatMessage::tool("t3", "bash", &big), ChatMessage::tool("t3", "bash", &big),
]; ];
let result = compressor.compress_if_needed(messages).await.unwrap().history; let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages // After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len()); assert!(
result.len() < 7,
"expected truncation reduction, got {} messages",
result.len()
);
// Truncation notice should be present // Truncation notice should be present
let has_notice = result.iter().any(|m| m.content.contains("Context truncation")); let has_notice = result
.iter()
.any(|m| m.content.contains("Context truncation"));
assert!(has_notice, "hard truncation notice missing"); assert!(has_notice, "hard truncation notice missing");
let _ = std::fs::remove_file(&tmp); let _ = std::fs::remove_file(&tmp);
@ -745,9 +972,9 @@ mod tests {
let mut messages = vec![ let mut messages = vec![
ChatMessage::user("Q1"), ChatMessage::user("Q1"),
ChatMessage::user("[Context Summary]\n\nsummary of previous turn"), ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
ChatMessage::assistant("done"), // declares tc2 ChatMessage::assistant("done"), // declares tc2
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
]; ];
// Set tool_call_id on tool messages and tool_calls on assistant // Set tool_call_id on tool messages and tool_calls on assistant
messages[2].tool_call_id = Some("tc1".into()); messages[2].tool_call_id = Some("tc1".into());
@ -762,8 +989,16 @@ mod tests {
// orphan should be removed; legitimate should stay // orphan should be removed; legitimate should stay
assert_eq!(messages.len(), 4); assert_eq!(messages.len(), 4);
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into()))); assert!(
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into()))); messages
.iter()
.all(|m| m.tool_call_id != Some("tc1".into()))
);
assert!(
messages
.iter()
.any(|m| m.tool_call_id == Some("tc2".into()))
);
} }
#[test] #[test]

View File

@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
} }
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{engine::general_purpose::STANDARD, Engine as _}; use base64::{Engine as _, engine::general_purpose::STANDARD};
let mut file = std::fs::File::open(path)?; let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new(); let mut buffer = Vec::new();

View File

@ -1,8 +1,16 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor; pub mod context_compressor;
pub mod media_handler; pub mod media_handler;
pub mod sub_agent;
pub mod system_prompt; pub mod system_prompt;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
pub use context_compressor::{ContextCompressor, estimate_tokens}; pub use context_compressor::{ContextCompressor, estimate_tokens};
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder}; pub use sub_agent::{
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
TaskNotification, TaskStatus,
};
pub use system_prompt::{
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
build_system_prompt,
};

623
src/agent/sub_agent.rs Normal file
View File

@ -0,0 +1,623 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::agent::AgentError;
use crate::agent::AgentLoop;
use crate::agent::system_prompt::build_sub_agent_system_prompt;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{LLMProvider, create_provider};
use crate::skills::SkillsLoader;
use crate::tools::ToolRegistry;
tokio::task_local! {
pub(crate) static DELEGATE_CONTEXT: DelegateContext;
}
/// Read the delegate context from the current task. Returns an error if not set.
pub fn get_delegate_context() -> Result<DelegateContext, String> {
DELEGATE_CONTEXT
.try_with(|ctx| ctx.clone())
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
}
const DEFAULT_MAX_ITERATIONS: usize = 99;
const DEFAULT_TIMEOUT_SECS: u64 = 3600;
const MAX_INLINE_RESULT_CHARS: usize = 8000;
const DEFAULT_READONLY_TOOLS: &[&str] = &[
"file_read",
"file_search",
"content_search",
"web_fetch",
"http_request",
"calculator",
];
#[derive(Debug, Clone)]
pub struct SubAgentConfig {
pub prompt: String,
pub mode: ExecutionMode,
pub allowed_tools: Option<Vec<String>>,
pub max_iterations: Option<usize>,
pub timeout_secs: Option<u64>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionMode {
Inline,
Background,
Parallel,
}
#[derive(Debug, Clone)]
pub struct SubAgentResult {
pub task_id: String,
pub content: String,
pub content_truncated: bool,
pub status: TaskStatus,
pub tool_calls_count: usize,
pub iterations: usize,
pub duration_ms: u64,
}
#[derive(Debug, Clone)]
pub enum TaskStatus {
Completed,
Failed(String),
Cancelled,
TimedOut,
}
#[derive(Debug, Clone)]
pub struct TaskNotification {
pub task_id: String,
pub session_id: String,
pub channel: String,
pub chat_id: String,
pub status: TaskStatus,
pub result_summary: String,
}
#[derive(Debug, Clone)]
pub struct DelegateContext {
pub session_id: String,
pub channel: String,
pub chat_id: String,
}
#[derive(Debug)]
pub enum SubAgentError {
TooManyTasks(usize),
ProviderCreation(String),
Storage(String),
Other(String),
}
impl std::fmt::Display for SubAgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
Self::Storage(e) => write!(f, "storage error: {}", e),
Self::Other(e) => write!(f, "{}", e),
}
}
}
impl std::error::Error for SubAgentError {}
pub struct SubAgentManager {
provider_config: LLMProviderConfig,
full_tools: Arc<ToolRegistry>,
storage: Option<Arc<crate::storage::Storage>>,
active_tasks: Arc<DashMap<String, CancellationToken>>,
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
max_concurrent_background_tasks: usize,
skills_loader: Option<Arc<SkillsLoader>>,
}
impl SubAgentManager {
pub fn new(
provider_config: LLMProviderConfig,
full_tools: Arc<ToolRegistry>,
storage: Option<Arc<crate::storage::Storage>>,
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
max_concurrent_background_tasks: usize,
skills_loader: Option<Arc<SkillsLoader>>,
) -> Self {
Self {
provider_config,
full_tools,
storage,
active_tasks: Arc::new(DashMap::new()),
notify_tx,
max_concurrent_background_tasks,
skills_loader,
}
}
pub fn filter_tools(&self, allowed: &Option<Vec<String>>) -> Arc<ToolRegistry> {
let allowed_set: HashSet<&str> = match allowed {
Some(list) => list.iter().map(|s| s.as_str()).collect(),
None => DEFAULT_READONLY_TOOLS.iter().copied().collect(),
};
let filtered = ToolRegistry::new();
for (name, tool) in self.full_tools.iter() {
if allowed_set.contains(name.as_str()) && name != "delegate" {
filtered.register_raw(name, tool);
}
}
Arc::new(filtered)
}
fn get_skills_prompt(&self, tools: &ToolRegistry) -> Option<String> {
let has_get_skill = tools.iter().iter().any(|(name, _)| name == "get_skill");
if has_get_skill {
if let Some(ref loader) = self.skills_loader {
let prompt = loader.build_skills_prompt();
if !prompt.is_empty() {
return Some(prompt);
}
}
}
None
}
pub fn build_sub_agent(
&self,
config: &SubAgentConfig,
tools: Arc<ToolRegistry>,
) -> Result<AgentLoop, AgentError> {
let mut provider = create_provider(self.provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
if let Some(ref s) = self.storage {
provider.set_storage(s.clone());
}
let provider: Arc<dyn LLMProvider> = Arc::from(provider);
let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
let workspace_dir = self.provider_config.workspace_dir.clone();
let model_name = self.provider_config.model_id.clone();
let input_types = self.provider_config.input_types.clone();
let agent = AgentLoop::with_provider_and_tools(
provider,
tools,
max_iterations,
model_name,
workspace_dir,
input_types,
)
.with_context_window(self.provider_config.token_limit);
Ok(agent)
}
pub async fn run_inline(
&self,
config: SubAgentConfig,
) -> Result<SubAgentResult, SubAgentError> {
let task_id = generate_task_id();
let tools = self.filter_tools(&config.allowed_tools);
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none()
|| config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt(
&config.prompt,
&timeout_human,
&tools,
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
skills_prompt,
http_get_only,
);
let agent = self
.build_sub_agent(&config, tools)
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
let history = vec![
ChatMessage::system(system_prompt),
ChatMessage::user(&config.prompt),
];
let start = Instant::now();
let result = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
agent.process(history),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(agent_result)) => {
let (content, truncated) =
truncate_sub_agent_result(&agent_result.final_response.content);
let tool_calls_count = agent_result
.emitted_messages
.iter()
.filter(|m| m.tool_calls.is_some())
.count();
let iterations = agent_result
.emitted_messages
.iter()
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
.count();
Ok(SubAgentResult {
task_id,
content,
content_truncated: truncated,
status: TaskStatus::Completed,
tool_calls_count,
iterations,
duration_ms,
})
}
Ok(Err(e)) => Ok(SubAgentResult {
task_id,
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed(e.to_string()),
tool_calls_count: 0,
iterations: 0,
duration_ms,
}),
Err(_elapsed) => Ok(SubAgentResult {
task_id,
content: String::new(),
content_truncated: false,
status: TaskStatus::TimedOut,
tool_calls_count: 0,
iterations: 0,
duration_ms,
}),
}
}
pub async fn run_parallel(
&self,
configs: Vec<SubAgentConfig>,
) -> Result<Vec<SubAgentResult>, SubAgentError> {
let futures: Vec<_> = configs
.into_iter()
.map(|config| {
let mgr = self; // &self borrow, all tasks share the same manager
async move { mgr.run_inline(config).await }
})
.collect();
let results = futures_util::future::join_all(futures).await;
Ok(results.into_iter().collect::<Result<Vec<_>, _>>()?)
}
pub async fn run_background(
&self,
config: SubAgentConfig,
ctx: DelegateContext,
) -> Result<String, SubAgentError> {
if self.active_tasks.len() >= self.max_concurrent_background_tasks {
return Err(SubAgentError::TooManyTasks(
self.max_concurrent_background_tasks,
));
}
let task_id = generate_task_id();
let cancel_token = CancellationToken::new();
self.active_tasks
.insert(task_id.clone(), cancel_token.clone());
// Write DB: pending
if let Some(ref storage) = self.storage {
let allowed_tools_json = config
.allowed_tools
.as_ref()
.and_then(|v| serde_json::to_string(v).ok());
let record = crate::storage::BackgroundTask {
id: task_id.clone(),
session_id: ctx.session_id.clone(),
channel: ctx.channel.clone(),
chat_id: ctx.chat_id.clone(),
prompt: config.prompt.clone(),
allowed_tools: allowed_tools_json,
status: "pending".to_string(),
result: None,
error: None,
tool_calls_count: 0,
iterations: 0,
started_at: None,
finished_at: None,
created_at: chrono::Utc::now().timestamp_millis(),
};
storage
.create_background_task(&record)
.await
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
}
let tools = self.filter_tools(&config.allowed_tools);
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none()
|| config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt(
&config.prompt,
&timeout_human,
&tools,
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
skills_prompt,
http_get_only,
);
let provider_config = self.provider_config.clone();
let storage = self.storage.clone();
let notify_tx = self.notify_tx.clone();
let active_tasks = Arc::clone(&self.active_tasks);
let tid = task_id.clone();
let sess_id = ctx.session_id.clone();
let ch = ctx.channel.clone();
let cid = ctx.chat_id.clone();
let prompt = config.prompt.clone();
tokio::spawn(async move {
let started_at = chrono::Utc::now().timestamp_millis();
// Update DB: running
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid,
"running",
None,
None,
Some(started_at),
None,
)
.await;
}
let mut provider = create_provider(provider_config.clone()).ok();
if let Some(ref mut p) = provider {
if let Some(ref s) = storage {
p.set_storage(s.clone());
}
}
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
let result = match provider_result {
Some(provider) => {
let agent = AgentLoop::with_provider_and_tools(
provider,
tools,
DEFAULT_MAX_ITERATIONS,
provider_config.model_id.clone(),
provider_config.workspace_dir.clone(),
provider_config.input_types.clone(),
)
.with_context_window(provider_config.token_limit);
let history = vec![
ChatMessage::system(system_prompt),
ChatMessage::user(&prompt),
];
tokio::select! {
r = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
agent.process(history),
) => {
match r {
Ok(Ok(agent_result)) => SubAgentResult {
task_id: tid.clone(),
content: agent_result.final_response.content,
content_truncated: false,
status: TaskStatus::Completed,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
Ok(Err(e)) => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed(e.to_string()),
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
Err(_) => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::TimedOut,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
}
}
_ = cancel_token.cancelled() => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Cancelled,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
}
}
None => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed("provider creation failed".into()),
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
};
let finished_at = chrono::Utc::now().timestamp_millis();
let duration_ms = (finished_at - started_at) as u64;
let (status_str, error_val) = match &result.status {
TaskStatus::Completed => ("completed".to_string(), None),
TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())),
TaskStatus::Cancelled => ("cancelled".to_string(), None),
TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())),
};
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid,
&status_str,
Some(&result.content),
error_val.as_deref(),
Some(started_at),
Some(finished_at),
)
.await;
}
let _ = notify_tx.send(TaskNotification {
task_id: tid.clone(),
session_id: sess_id,
channel: ch,
chat_id: cid,
status: result.status,
result_summary: summarize_for_notification(&result.content, duration_ms),
});
active_tasks.remove(&tid);
});
Ok(task_id)
}
pub async fn cancel_task(&self, task_id: &str) -> Result<bool, SubAgentError> {
if let Some((_, token)) = self.active_tasks.remove(task_id) {
token.cancel();
if let Some(ref s) = self.storage {
s.update_background_task_status(
task_id,
"cancelled",
None,
None,
None,
Some(chrono::Utc::now().timestamp_millis()),
)
.await
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
}
Ok(true)
} else if let Some(ref s) = self.storage {
match s.get_background_task(task_id).await {
Ok(task) => match task.status.as_str() {
"pending" | "running" => {
tracing::warn!(task_id, "task in DB but not in active_tasks");
Ok(false)
}
_ => Ok(false),
},
Err(_) => Ok(false),
}
} else {
Ok(false)
}
}
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.get_background_task(task_id).await.ok()
} else {
None
}
}
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.list_background_tasks(session_id)
.await
.unwrap_or_default()
} else {
vec![]
}
}
pub async fn cancel_by_session(&self, session_id: &str) {
// Cancel all running tasks for a session by checking DB
if let Some(ref s) = self.storage {
if let Ok(tasks) = s.list_background_tasks(session_id).await {
for task in &tasks {
if task.status == "pending" || task.status == "running" {
let _ = self.cancel_task(&task.id).await;
}
}
}
}
}
pub fn active_task_count(&self) -> usize {
self.active_tasks.len()
}
}
fn generate_task_id() -> String {
Uuid::new_v4().to_string()[..8].to_string()
}
fn format_duration(seconds: u64) -> String {
if seconds < 60 {
format!("{}s", seconds)
} else if seconds < 3600 {
format!("{}m", seconds / 60)
} else {
format!("{}h", seconds / 3600)
}
}
fn truncate_sub_agent_result(content: &str) -> (String, bool) {
if content.len() <= MAX_INLINE_RESULT_CHARS {
(content.to_string(), false)
} else {
let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS);
(
format!(
"{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]",
&content[..truncate_at],
content.len()
),
true,
)
}
}
fn summarize_for_notification(content: &str, _duration_ms: u64) -> String {
const MAX_SUMMARY_BYTES: usize = 500;
if content.len() <= MAX_SUMMARY_BYTES {
content.to_string()
} else {
let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES);
format!("{}...", &content[..truncate_at])
}
}

View File

@ -3,11 +3,7 @@
//! This module provides a modular framework for building system prompts //! This module provides a modular framework for building system prompts
//! using the SystemPromptBuilder pattern. //! using the SystemPromptBuilder pattern.
//! //!
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic //! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation
//!
//! Configuration files loaded from ~/.picobot/:
//! - AGENTS.md — agent identity and behavior
//! - USER.md — user preferences and profile
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use std::path::Path; use std::path::Path;
@ -55,10 +51,35 @@ impl SystemPromptBuilder {
Box::new(CrossChannelSection), Box::new(CrossChannelSection),
Box::new(MemorySection), Box::new(MemorySection),
Box::new(HistorySection), Box::new(HistorySection),
Box::new(DelegationSection),
], ],
} }
} }
/// Create a builder with sub-agent specific sections.
pub fn with_sub_agent_defaults(
task: &str,
timeout: &str,
skills_prompt: Option<String>,
http_get_only: bool,
) -> Self {
let mut sections: Vec<Box<dyn PromptSection>> = vec![
Box::new(SubAgentIdentitySection {
task: task.to_string(),
timeout: timeout.to_string(),
}),
Box::new(ToolHonestySection),
Box::new(SafetySection),
Box::new(SubAgentToolsSection { http_get_only }),
Box::new(WorkspaceSection),
Box::new(DateTimeSection),
];
if let Some(sp) = skills_prompt {
sections.push(Box::new(SubAgentSkillsSection { skills_prompt: sp }));
}
Self { sections }
}
/// Add a custom section to the builder. /// Add a custom section to the builder.
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self { pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
self.sections.push(section); self.sections.push(section);
@ -175,10 +196,10 @@ impl PromptSection for UserProfileSection {
if let Some(user_config_dir) = get_user_config_dir() if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) = && let Some(content) =
load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS) load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS)
{ {
output.push_str(&content); output.push_str(&content);
return output; return output;
} }
// No USER.md found, return empty // No USER.md found, return empty
String::new() String::new()
@ -199,10 +220,10 @@ impl PromptSection for AgentProfileSection {
if let Some(user_config_dir) = get_user_config_dir() if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) = && let Some(content) =
load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS) load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS)
{ {
output.push_str(&content); output.push_str(&content);
return output; return output;
} }
String::new() String::new()
} }
@ -353,6 +374,120 @@ impl PromptSection for HistorySection {
} }
} }
/// Sub-agent delegation principles.
pub struct DelegationSection;
impl PromptSection for DelegationSection {
fn name(&self) -> &str {
"delegation"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 子 Agent 委托原则\n\n\
使 delegate Agent\n\
\n\
### \n\
- 使 mode=\"parallel\"\n\
- 使 mode=\"background\"\n\
- \n\
\n\
### \n\
- **** Agent \n\
- **** file_readfile_searchweb_fetch bashfile_writefile_edit\n\
- **** delegate Agent\n\
- **** Agent \n\
\n\
### Skill \n\
- skill allowed_tools get_skill\n\
- prompt Agent 使 get_skill \n\
- \"使用 get_skill action='get' skill_name='pdf' 加载 PDF 处理技能后完成任务\"\n\
\n\
### \n\
- prompt \n\
- prompt \"跳过 .tmp 文件\"\n\
- \n\
\n\
### \n\
- 使 mode=\"parallel\",任务定义在 tasks 数组中\n\
- \n\
- 5 \n\
\n\
### \n\
- 30s 使 mode=\"background\"\n\
- ".to_string()
}
}
// === Sub-Agent Prompt Sections ===
/// Sub-agent identity and task instructions.
pub struct SubAgentIdentitySection {
pub task: String,
pub timeout: String,
}
impl PromptSection for SubAgentIdentitySection {
fn name(&self) -> &str {
"sub_agent_identity"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
format!(
"## 子 Agent\n\n\
Agent Agent Agent\n\
\n\
## \n\n\
{}\n\
\n\
## \n\
- \n\
- 使\n\
- 使 delegate \n\
- \n\
- \n\
- {}",
self.task, self.timeout,
)
}
}
/// Sub-agent available tools description.
pub struct SubAgentToolsSection {
pub http_get_only: bool,
}
impl PromptSection for SubAgentToolsSection {
fn name(&self) -> &str {
"sub_agent_tools"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
let mut s = String::from("## 可用工具\n\n");
s.push_str(&ctx.tools.describe_for_prompt());
if self.http_get_only {
s.push_str(
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
);
}
s
}
}
/// Sub-agent skills information, injected when get_skill tool is available.
pub struct SubAgentSkillsSection {
pub skills_prompt: String,
}
impl PromptSection for SubAgentSkillsSection {
fn name(&self) -> &str {
"sub_agent_skills"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
self.skills_prompt.clone()
}
}
// === Helper Functions === // === Helper Functions ===
/// Get user config directory (~/.picobot/). /// Get user config directory (~/.picobot/).
@ -409,6 +544,28 @@ pub fn build_system_prompt(
SystemPromptBuilder::with_defaults().build(&ctx) SystemPromptBuilder::with_defaults().build(&ctx)
} }
/// Build a system prompt for a sub-agent with all relevant operational sections.
pub fn build_sub_agent_system_prompt(
task: &str,
timeout_human: &str,
tools: &ToolRegistry,
workspace_dir: &Path,
model_name: &str,
skills_prompt: Option<String>,
http_get_only: bool,
) -> String {
let ctx = PromptContext {
workspace_dir,
model_name,
tools,
session_id: None,
memory_context: None,
has_compressed_history: false,
};
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
.build(&ctx)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,8 +1,8 @@
use std::sync::Arc; use std::sync::Arc;
use crate::bus::{MessageBus, OutboundMessage}; use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::ChannelManager; use crate::channels::ChannelManager;
use crate::channels::base::{Channel, ChannelError};
/// OutboundDispatcher consumes outbound messages from the MessageBus /// OutboundDispatcher consumes outbound messages from the MessageBus
/// and dispatches them to the appropriate Channel /// and dispatches them to the appropriate Channel

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::providers::ToolCall; use crate::providers::ToolCall;
@ -23,7 +23,9 @@ pub struct ImageUrlBlock {
impl ContentBlock { impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self { pub fn text(content: impl Into<String>) -> Self {
Self::Text { text: content.into() } Self::Text {
text: content.into(),
}
} }
pub fn image_url(url: impl Into<String>) -> Self { pub fn image_url(url: impl Into<String>) -> Self {
@ -49,10 +51,10 @@ pub struct MediaRef {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MediaItem { pub struct MediaItem {
pub path: String, // Local file path pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video" pub media_type: String, // "image", "audio", "file", "video"
pub mime_type: Option<String>, pub mime_type: Option<String>,
pub original_key: Option<String>, // Feishu file_key for download pub original_key: Option<String>, // Feishu file_key for download
} }
impl MediaItem { impl MediaItem {
@ -161,7 +163,10 @@ impl ChatMessage {
} }
} }
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self { pub fn assistant_with_tool_calls(
content: impl Into<String>,
tool_calls: Vec<ToolCall>,
) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(), role: "assistant".to_string(),
@ -206,7 +211,11 @@ impl ChatMessage {
} }
} }
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self { pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(), role: "tool".to_string(),

View File

@ -2,10 +2,13 @@ pub mod dispatcher;
pub mod message; pub mod message;
pub use dispatcher::OutboundDispatcher; pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind}; pub use message::{
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
OutboundMessage, SourceKind,
};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{Mutex, mpsc};
// ============================================================================ // ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication // MessageBus - Async message queue for Channel <-> Agent communication
@ -49,7 +52,8 @@ impl MessageBus {
/// Consume an inbound message (Agent -> Bus) /// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage { pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self.inbound_rx let msg = self
.inbound_rx
.lock() .lock()
.await .await
.recv() .recv()

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::sync::{mpsc, Mutex}; use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage}; use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId}; use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
use super::base::{Channel, ChannelError}; use super::base::{Channel, ChannelError};
@ -14,6 +14,7 @@ use super::base::{Channel, ChannelError};
pub(crate) struct Client { pub(crate) struct Client {
sender: mpsc::Sender<WsOutbound>, sender: mpsc::Sender<WsOutbound>,
chat_id: String,
current_session_id: Mutex<Option<String>>, current_session_id: Mutex<Option<String>>,
} }
@ -41,23 +42,28 @@ impl CliChatChannel {
} }
/// Register a new client connection, returns (session_id, client) /// Register a new client connection, returns (session_id, client)
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) { pub(crate) async fn register_client(
// Generate connection ID (used as chat_id) - use short ID &self,
let connection_id = crate::util::short_id(); sender: mpsc::Sender<WsOutbound>,
) -> (String, Arc<Client>) {
// Each WebSocket connection gets a stable chat scope. All user input and
// dialog controls for this client stay inside that scope unless the
// protocol explicitly carries a full session id.
let chat_id = crate::util::short_id();
let client = Arc::new(Client { let client = Arc::new(Client {
sender, sender,
chat_id: chat_id.clone(),
current_session_id: Mutex::new(None), current_session_id: Mutex::new(None),
}); });
self.clients.lock().await.push(client.clone()); self.clients.lock().await.push(client.clone());
// Create initial session via control message // Create initial session via control message
let session_id = match self.create_session_via_control(&connection_id, None).await { let session_id = match self.create_session_via_control(&chat_id, None).await {
Ok(id) => id, Ok((id, _title)) => id,
Err(e) => { Err(e) => {
tracing::error!(error = %e, "Failed to create initial session"); tracing::error!(error = %e, "Failed to create initial session");
// Fall back to old format for backward compatibility UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
connection_id.clone()
} }
}; };
@ -73,21 +79,19 @@ impl CliChatChannel {
/// Handle an inbound message from a client /// Handle an inbound message from a client
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) { pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
match parse_inbound(raw_msg) { match parse_inbound(raw_msg) {
Ok(inbound) => { Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
match self.handle_ws_inbound(client.clone(), inbound).await { Ok(()) => {}
Ok(()) => {} Err(e) => {
Err(e) => { tracing::warn!(error = %e, "Failed to handle inbound message");
tracing::warn!(error = %e, "Failed to handle inbound message"); let _ = client
let _ = client .sender
.sender .send(WsOutbound::Error {
.send(WsOutbound::Error { code: "INTERNAL_ERROR".to_string(),
code: "INTERNAL_ERROR".to_string(), message: e.to_string(),
message: e.to_string(), })
}) .await;
.await;
}
} }
} },
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message"); tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = client let _ = client
@ -101,22 +105,30 @@ impl CliChatChannel {
} }
} }
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> { async fn handle_ws_inbound(
&self,
client: Arc<Client>,
inbound: WsInbound,
) -> Result<(), ChannelError> {
let bus = { let bus = {
let guard = self.bus.lock().unwrap(); let guard = self.bus.lock().unwrap();
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
}; };
let mut current_session_guard = client.current_session_id.lock().await; let mut current_session_guard = client.current_session_id.lock().await;
match inbound { match inbound {
WsInbound::UserInput { content, chat_id, .. } => { WsInbound::UserInput {
content, chat_id, ..
} => {
// All messages (including slash commands) go through the normal inbound flow // All messages (including slash commands) go through the normal inbound flow
// SessionManager handles session creation/reuse internally // SessionManager handles session creation/reuse internally
let msg = InboundMessage { let msg = InboundMessage {
channel: self.name().to_string(), channel: self.name().to_string(),
sender_id: "cli".to_string(), sender_id: "cli".to_string(),
chat_id: chat_id.unwrap_or_else(crate::util::short_id), chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
content, content,
timestamp: crate::bus::message::current_timestamp(), timestamp: crate::bus::message::current_timestamp(),
media: Vec::new(), media: Vec::new(),
@ -125,19 +137,56 @@ impl CliChatChannel {
}; };
bus.publish_inbound(msg).await?; bus.publish_inbound(msg).await?;
} }
WsInbound::ClearHistory { chat_id, session_id } => { WsInbound::ClearHistory {
let target = session_id chat_id,
.or(chat_id) session_id,
.or(current_session_guard.clone()) } => {
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let session_id = UnifiedSessionId::parse(&target) let session_id = if let Some(session_id) = session_id {
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; UnifiedSessionId::parse(&session_id).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
} else if let Some(chat_id) = chat_id {
let (current_tx, mut current_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog {
channel: "cli_chat".to_string(),
chat_id,
},
reply_tx: current_tx,
})
.await?;
match current_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog {
session_id: Some(session_id),
})) => session_id,
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
return Err(ChannelError::Other("No active session".to_string()));
}
Some(Ok(_)) => {
return Err(ChannelError::Other(
"Unexpected response type".to_string(),
));
}
Some(Err(e)) => return Err(e),
None => {
return Err(ChannelError::Other("Control channel closed".to_string()));
}
}
} else {
let target = current_session_guard
.clone()
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
UnifiedSessionId::parse(&target).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
};
let target = session_id.to_string();
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::ClearHistory { session_id }, op: SessionCommand::ClearHistory { session_id },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::HistoryCleared { .. })) => { Some(Ok(SessionEvent::HistoryCleared { .. })) => {
@ -158,24 +207,21 @@ impl CliChatChannel {
} }
} }
WsInbound::CreateSession { title } => { WsInbound::CreateSession { title } => {
// Use current session's chat_id if available, otherwise generate new one let (new_id, created_title) = self
let chat_id = current_session_guard.clone() .create_session_via_control(&client.chat_id, title.as_deref())
.unwrap_or_else(crate::util::short_id); .await?;
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
*current_session_guard = Some(new_id.clone()); *current_session_guard = Some(new_id.clone());
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionCreated { .send(WsOutbound::SessionCreated {
session_id: new_id, session_id: new_id,
title: title.unwrap_or_default(), title: created_title,
}) })
.await; .await;
} }
WsInbound::ListSessions { include_archived } => { WsInbound::ListSessions { include_archived } => {
// List dialogs for the current chat // List dialogs for the current chat
let chat_id = current_session_guard.clone() let chat_id = client.chat_id.clone();
.unwrap_or_else(|| "".to_string());
let chat_id_for_response = chat_id.clone();
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::ListDialogs { op: SessionCommand::ListDialogs {
@ -184,13 +230,18 @@ impl CliChatChannel {
include_archived, include_archived,
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => { Some(Ok(SessionEvent::DialogList {
dialogs,
current_dialog_id,
})) => {
// Convert DialogInfo to SessionSummary for backward compatibility // Convert DialogInfo to SessionSummary for backward compatibility
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| { let sessions: Vec<crate::protocol::SessionSummary> = dialogs
crate::protocol::SessionSummary { .into_iter()
.map(|d| crate::protocol::SessionSummary {
session_id: d.session_id.to_string(), session_id: d.session_id.to_string(),
title: d.title, title: d.title,
channel_name: d.session_id.channel.clone(), channel_name: d.session_id.channel.clone(),
@ -198,11 +249,14 @@ impl CliChatChannel {
message_count: d.message_count, message_count: d.message_count,
last_active_at: d.last_active_at, last_active_at: d.last_active_at,
archived_at: d.archived_at, archived_at: d.archived_at,
} })
}).collect(); .collect();
let current_session_id = current_dialog_id.map(|did| { let current_session_id = current_dialog_id.map(|did| {
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string() UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
}); });
if let Some(ref session_id) = current_session_id {
*current_session_guard = Some(session_id.clone());
}
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionList { .send(WsOutbound::SessionList {
@ -223,39 +277,35 @@ impl CliChatChannel {
} }
} }
WsInbound::LoadSession { session_id } => { WsInbound::LoadSession { session_id } => {
// LoadSession: parse the session_id and get current dialog info
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&session_id) let unified_id = UnifiedSessionId::parse(&session_id)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
return Err(ChannelError::Other(
"Session does not belong to this client".to_string(),
));
}
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog { op: SessionCommand::SwitchDialog {
channel: unified_id.channel.clone(), channel: unified_id.channel.clone(),
chat_id: unified_id.chat_id.clone(), chat_id: unified_id.chat_id.clone(),
dialog_id: unified_id.dialog_id.clone(),
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => { Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
if let Some(current_session_id) = current_session_id_opt { *current_session_guard = Some(session_id.to_string());
*current_session_guard = Some(current_session_id.to_string()); let _ = client
let _ = client .sender
.sender .send(WsOutbound::SessionLoaded {
.send(WsOutbound::SessionLoaded { session_id: session_id.to_string(),
session_id: current_session_id.to_string(), title: "Session".to_string(),
title: "Session".to_string(), // TODO: get actual title message_count: 0,
message_count: 0, // TODO: get actual count })
}) .await;
.await;
} else {
let _ = client
.sender
.send(WsOutbound::Error {
code: "NO_CURRENT_DIALOG".to_string(),
message: "No current dialog".to_string(),
})
.await;
}
} }
Some(Ok(_)) => { Some(Ok(_)) => {
// Unexpected response type // Unexpected response type
@ -275,23 +325,30 @@ impl CliChatChannel {
} }
} }
WsInbound::RenameSession { session_id, title } => { WsInbound::RenameSession { session_id, title } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { let target = session_id
ChannelError::Other("No active session".to_string()) .or(current_session_guard.clone())
})?; .ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target) let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() }, op: SessionCommand::RenameDialog {
session_id: unified_id,
title: title.clone(),
},
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => { Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title }) .send(WsOutbound::SessionRenamed {
session_id: session_id.to_string(),
title,
})
.await; .await;
} }
Some(Ok(_)) => { Some(Ok(_)) => {
@ -306,24 +363,43 @@ impl CliChatChannel {
} }
} }
WsInbound::ArchiveSession { session_id } => { WsInbound::ArchiveSession { session_id } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { let target = session_id
ChannelError::Other("No active session".to_string()) .or(current_session_guard.clone())
})?; .ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let was_current = current_session_guard.as_deref() == Some(&target);
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target) let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::ArchiveDialog { session_id: unified_id }, op: SessionCommand::ArchiveDialog {
session_id: unified_id,
},
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogArchived { session_id })) => { Some(Ok(SessionEvent::DialogArchived { session_id })) => {
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() }) .send(WsOutbound::SessionArchived {
session_id: session_id.to_string(),
})
.await; .await;
if was_current {
let (new_id, title) = self
.create_session_via_control(&client.chat_id, None)
.await?;
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title,
})
.await;
}
} }
Some(Ok(_)) => { Some(Ok(_)) => {
// Unexpected response type // Unexpected response type
@ -337,35 +413,42 @@ impl CliChatChannel {
} }
} }
WsInbound::DeleteSession { session_id } => { WsInbound::DeleteSession { session_id } => {
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| { let target = session_id
ChannelError::Other("No active session".to_string()) .or(current_session_guard.clone())
})?; .ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target) let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?; .ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::DeleteDialog { session_id: unified_id }, op: SessionCommand::DeleteDialog {
session_id: unified_id,
},
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogDeleted { session_id })) => { Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() }) .send(WsOutbound::SessionDeleted {
session_id: session_id.to_string(),
})
.await; .await;
// If deleting current session, create a new one // If deleting current session, create a new one
if current_session_guard.as_deref() == Some(&target) { if current_session_guard.as_deref() == Some(&target) {
drop(reply_rx); drop(reply_rx);
if let Ok(new_id) = self.create_session_via_control(&target, None).await { if let Ok((new_id, title)) =
self.create_session_via_control(&client.chat_id, None).await
{
*current_session_guard = Some(new_id.clone()); *current_session_guard = Some(new_id.clone());
let _ = client let _ = client
.sender .sender
.send(WsOutbound::SessionCreated { .send(WsOutbound::SessionCreated {
session_id: new_id, session_id: new_id,
title: String::new(), title,
}) })
.await; .await;
} }
@ -388,32 +471,45 @@ impl CliChatChannel {
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::GetSlashCommands { op: SessionCommand::GetSlashCommands {
channel: "cli_chat".to_string(), channel: "cli_chat".to_string(),
chat_id: "".to_string(), chat_id: client.chat_id.clone(),
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
if let Some(result) = reply_rx.recv().await { if let Some(result) = reply_rx.recv().await {
match result { match result {
Ok(SessionEvent::SlashCommandsList { commands }) => { Ok(SessionEvent::SlashCommandsList { commands }) => {
// Convert to SlashCommand to SlashCommandInfo // Convert to SlashCommand to SlashCommandInfo
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| { let command_infos: Vec<SlashCommandInfo> = commands
SlashCommandInfo { .into_iter()
.map(|cmd| SlashCommandInfo {
name: cmd.name.to_string(), name: cmd.name.to_string(),
description: cmd.description.to_string(), description: cmd.description.to_string(),
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(), aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
} })
}).collect(); .collect();
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await; let _ = client
.sender
.send(WsOutbound::SlashCommandsList {
commands: command_infos,
})
.await;
} }
Ok(SessionEvent::Error { code, message }) => { Ok(SessionEvent::Error { code, message }) => {
let _ = client.sender.send(WsOutbound::Error { code, message }).await; let _ = client
.sender
.send(WsOutbound::Error { code, message })
.await;
} }
Err(e) => { Err(e) => {
let _ = client.sender.send(WsOutbound::Error { let _ = client
code: "GET_COMMANDS_ERROR".to_string(), .sender
message: e.to_string() .send(WsOutbound::Error {
}).await; code: "GET_COMMANDS_ERROR".to_string(),
message: e.to_string(),
})
.await;
} }
_ => {} _ => {}
} }
@ -427,29 +523,34 @@ impl CliChatChannel {
} }
/// Create a session via control message and return the session_id /// Create a session via control message and return the session_id
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> { async fn create_session_via_control(
&self,
chat_id: &str,
title: Option<&str>,
) -> Result<(String, String), ChannelError> {
let bus = { let bus = {
let guard = self.bus.lock().unwrap(); let guard = self.bus.lock().unwrap();
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))? guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
}; };
let (reply_tx, mut reply_rx) = mpsc::channel(1); let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage { bus.publish_control(ControlMessage {
op: SessionCommand::CreateDialog { op: SessionCommand::CreateDialog {
channel: "cli_chat".to_string(), channel: "cli_chat".to_string(),
chat_id: connection_id.to_string(), chat_id: chat_id.to_string(),
title: title.map(String::from), title: title.map(String::from),
}, },
reply_tx, reply_tx,
}).await?; })
.await?;
match reply_rx.recv().await { match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => { Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
Ok(session_id.to_string()) Ok((session_id.to_string(), title))
}
Some(Ok(_)) => {
Err(ChannelError::Other("Unexpected response type".to_string()))
} }
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
Some(Err(e)) => Err(e), Some(Err(e)) => Err(e),
None => Err(ChannelError::Other("Control channel closed".to_string())), None => Err(ChannelError::Other("Control channel closed".to_string())),
} }
@ -479,7 +580,11 @@ impl Channel for CliChatChannel {
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let clients = self.clients.lock().await.clone(); let clients = self.clients.lock().await.clone();
for client in clients { for client in clients {
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") { if client.chat_id != msg.chat_id {
continue;
}
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
{
WsOutbound::SystemNotification { WsOutbound::SystemNotification {
content: msg.content.clone(), content: msg.content.clone(),
} }

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,10 @@ impl ChannelManager {
} }
} }
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self { pub fn with_bus(
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
bus: Arc<MessageBus>,
) -> Self {
Self { Self {
channels: Arc::new(RwLock::new(HashMap::new())), channels: Arc::new(RwLock::new(HashMap::new())),
cli_chat_channel, cli_chat_channel,
@ -39,7 +42,10 @@ impl ChannelManager {
/// Register a channel with the manager /// Register a channel with the manager
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) { pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels.write().await.insert(name.to_string(), channel); self.channels
.write()
.await
.insert(name.to_string(), channel);
} }
/// Get CLI chat channel /// Get CLI chat channel
@ -56,14 +62,19 @@ impl ChannelManager {
// Initialize Feishu channel if enabled // Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") { if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled { if feishu_config.enabled {
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir) let channel =
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
})?;
self.channels self.channels
.write() .write()
.await .await
.insert("feishu".to_string(), Arc::new(channel)); .insert("feishu".to_string(), Arc::new(channel));
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display()); tracing::info!(
"Feishu channel registered (media_dir: {}/media/feishu)",
workspace_dir.display()
);
} else { } else {
tracing::info!("Feishu channel disabled in config"); tracing::info!("Feishu channel disabled in config");
} }
@ -118,7 +129,10 @@ impl ChannelManager {
if let Some(channel) = self.get_channel(channel_name).await { if let Some(channel) = self.get_channel(channel_name).await {
channel.send(msg).await channel.send(msg).await
} else { } else {
Err(ChannelError::Other(format!("Channel not found: {}", channel_name))) Err(ChannelError::Other(format!(
"Channel not found: {}",
channel_name
)))
} }
} }
} }

View File

@ -1,11 +1,11 @@
pub mod base; pub mod base;
pub mod feishu;
pub mod cli_chat; pub mod cli_chat;
pub mod feishu;
pub mod manager; pub mod manager;
pub mod slash_command; pub mod slash_command;
pub use base::{Channel, ChannelError}; pub use base::{Channel, ChannelError};
pub use manager::ChannelManager;
pub use feishu::FeishuChannel;
pub use cli_chat::CliChatChannel; pub use cli_chat::CliChatChannel;
pub use slash_command::{parse_slash_command, command_matches}; pub use feishu::FeishuChannel;
pub use manager::ChannelManager;
pub use slash_command::{command_matches, parse_slash_command};

View File

@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
/// 检查内容是否匹配指定命令 /// 检查内容是否匹配指定命令
pub fn command_matches(content: &str, aliases: &[&str]) -> bool { pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
let trimmed = content.trim(); let trimmed = content.trim();
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) aliases
.iter()
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
} }
#[cfg(test)] #[cfg(test)]
@ -27,7 +29,10 @@ mod tests {
fn test_parse_slash_command() { fn test_parse_slash_command() {
assert_eq!(parse_slash_command("/reset"), Some(("reset", ""))); assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg"))); assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world"))); assert_eq!(
parse_slash_command("/new hello world"),
Some(("new", "hello world"))
);
assert_eq!(parse_slash_command("/??"), Some(("??", ""))); assert_eq!(parse_slash_command("/??"), Some(("??", "")));
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg"))); assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
assert_eq!(parse_slash_command("/?"), Some(("?", ""))); assert_eq!(parse_slash_command("/?"), Some(("?", "")));

View File

@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
use crossterm::{ use crossterm::{
event::{self, Event}, event::{self, Event},
execute, execute,
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
}; };
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use ratatui::{prelude::CrosstermBackend, Terminal}; use ratatui::{Terminal, prelude::CrosstermBackend};
use std::io; use std::io;
use tokio_tungstenite::{connect_async, tungstenite::Message}; use tokio_tungstenite::{connect_async, tungstenite::Message};
@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
WsOutbound::SessionCreated { session_id, .. } => { WsOutbound::SessionCreated { session_id, .. } => {
app.set_current_session(Some(session_id)); app.set_current_session(Some(session_id));
} }
WsOutbound::SessionList { sessions, current_session_id } => { WsOutbound::SessionList {
sessions,
current_session_id,
} => {
app.set_sessions(sessions); app.set_sessions(sessions);
if let Some(id) = current_session_id { if let Some(id) = current_session_id {
app.set_current_session(Some(id)); app.set_current_session(Some(id));

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::{App, MessageRole}; use crate::client::tui::app::{App, MessageRole};
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
text::Line, text::Line,
widgets::{Block, Borders, List, ListItem}, widgets::{Block, Borders, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
text::{Line, Span}, text::{Line, Span},
widgets::{Block, Borders, List, ListItem}, widgets::{Block, Borders, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,8 +1,8 @@
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
widgets::{Block, Borders, Clear, List, ListItem}, widgets::{Block, Borders, Clear, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect) { pub fn render(f: &mut Frame, area: Rect) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Style}, style::{Color, Style},
widgets::{Block, Borders, Paragraph}, widgets::{Block, Borders, Paragraph},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
widgets::{Block, Borders, List, ListItem}, widgets::{Block, Borders, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {
@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
.sessions .sessions
.iter() .iter()
.map(|session| { .map(|session| {
let is_current = app let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
.current_session_id
.as_ref() == Some(&session.session_id);
let archived = session.archived_at.is_some(); let archived = session.archived_at.is_some();
let mut content = if is_current { let mut content = if is_current {

View File

@ -1,15 +1,18 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
widgets::{Block, Borders, Paragraph}, widgets::{Block, Borders, Paragraph},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {
let (title, style) = if app.pending_quit { let (title, style) = if app.pending_quit {
let msg = if let Some(session_id) = &app.current_session_id { let msg = if let Some(session_id) = &app.current_session_id {
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id) format!(
"PicoBot | Session: {} | Press Ctrl+C again to quit",
session_id
)
} else { } else {
"PicoBot | Press Ctrl+C again to quit".to_string() "PicoBot | Press Ctrl+C again to quit".to_string()
}; };

View File

@ -1,6 +1,6 @@
use crate::client::tui::app::{App, MessageRole}; use crate::client::tui::app::{App, MessageRole};
use crate::protocol::serialize_inbound;
use crate::protocol::WsInbound; use crate::protocol::WsInbound;
use crate::protocol::serialize_inbound;
use crossterm::event::{KeyCode, KeyEvent}; use crossterm::event::{KeyCode, KeyEvent};
use futures_util::SinkExt; use futures_util::SinkExt;
@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
async fn handle_normal_input(app: &mut App, key: KeyEvent) { async fn handle_normal_input(app: &mut App, key: KeyEvent) {
// Handle Ctrl+C for quit (double press to exit) // Handle Ctrl+C for quit (double press to exit)
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL); let is_ctrl_c = key.code == KeyCode::Char('c')
&& key
.modifiers
.contains(crossterm::event::KeyModifiers::CONTROL);
if is_ctrl_c { if is_ctrl_c {
if app.handle_ctrl_c_for_quit() { if app.handle_ctrl_c_for_quit() {
return; return;
@ -63,9 +66,11 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
} }
KeyCode::Char(c) => { KeyCode::Char(c) => {
app.input_insert_char(c); app.input_insert_char(c);
// Show command menu when input starts with / // Show command menu when input starts with /
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) { if !app.show_command_menu
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
{
app.show_command_menu = true; app.show_command_menu = true;
app.selected_command_idx = 0; app.selected_command_idx = 0;
} else if app.show_command_menu && !app.input.starts_with('/') { } else if app.show_command_menu && !app.input.starts_with('/') {
@ -74,7 +79,7 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
} }
KeyCode::Backspace => { KeyCode::Backspace => {
app.input_delete_char(); app.input_delete_char();
// Hide menu if input no longer starts with / // Hide menu if input no longer starts with /
if app.show_command_menu && !app.input.starts_with('/') { if app.show_command_menu && !app.input.starts_with('/') {
app.show_command_menu = false; app.show_command_menu = false;
@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) {
sender_id: None, sender_id: None,
}; };
if let Ok(text) = serialize_inbound(&inbound) { if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await; let _ = sender
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
.await;
} }
} }
} }

View File

@ -1,8 +1,8 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use crate::client::tui::components::*; use crate::client::tui::components::*;
use ratatui::{ use ratatui::{
layout::{Constraint, Direction, Layout, Rect},
Frame, Frame,
layout::{Constraint, Direction, Layout, Rect},
}; };
pub fn render_ui(f: &mut Frame, app: &App) { pub fn render_ui(f: &mut Frame, app: &App) {

View File

@ -152,10 +152,26 @@ pub struct GatewayConfig {
pub cleanup_interval_minutes: Option<u64>, pub cleanup_interval_minutes: Option<u64>,
#[serde(default, rename = "session_db_path")] #[serde(default, rename = "session_db_path")]
pub session_db_path: Option<String>, pub session_db_path: Option<String>,
#[serde(default, rename = "max_concurrent_background_tasks")]
pub max_concurrent_background_tasks: usize,
#[serde(default)] #[serde(default)]
pub scheduler: Option<SchedulerConfig>, pub scheduler: Option<SchedulerConfig>,
} }
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
max_concurrent_background_tasks: 10,
scheduler: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig { pub struct SchedulerConfig {
/// Whether the scheduler is enabled /// Whether the scheduler is enabled
@ -209,19 +225,6 @@ fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string() "ws://127.0.0.1:19876/ws".to_string()
} }
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
scheduler: None,
}
}
}
impl Default for ClientConfig { impl Default for ClientConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
@ -270,12 +273,16 @@ impl Default for MemoryConfig {
impl MemoryConfig { impl MemoryConfig {
/// Resolve consolidation provider name, falling back to the main agent's provider. /// Resolve consolidation provider name, falling back to the main agent's provider.
pub fn resolve_consolidation_provider(&self, default: &str) -> String { pub fn resolve_consolidation_provider(&self, default: &str) -> String {
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string()) self.consolidation_provider
.clone()
.unwrap_or_else(|| default.to_string())
} }
/// Resolve consolidation model name, falling back to the main agent's model. /// Resolve consolidation model name, falling back to the main agent's model.
pub fn resolve_consolidation_model(&self, default: &str) -> String { pub fn resolve_consolidation_model(&self, default: &str) -> String {
self.consolidation_model.clone().unwrap_or_else(|| default.to_string()) self.consolidation_model
.clone()
.unwrap_or_else(|| default.to_string())
} }
} }
@ -363,10 +370,18 @@ impl Default for BrowserConfig {
} }
} }
fn default_recall_limit() -> usize { 5 } fn default_recall_limit() -> usize {
fn default_idle_consolidation_minutes() -> u64 { 10 } 5
fn default_timeline_retention_days() -> u64 { 90 } }
fn default_max_failures_before_degrade() -> usize { 3 } fn default_idle_consolidation_minutes() -> u64 {
10
}
fn default_timeline_retention_days() -> u64 {
90
}
fn default_max_failures_before_degrade() -> usize {
3
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LLMProviderConfig { pub struct LLMProviderConfig {
@ -466,7 +481,11 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError { impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path), ConfigError::ConfigNotFound(path) => write!(
f,
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
path
),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),

View File

@ -1,19 +1,19 @@
pub mod http; pub mod http;
pub mod ws; pub mod ws;
use axum::{Router, routing};
use std::sync::Arc; use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher}; use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::channels::base::{Channel, ChannelError}; use crate::channels::base::{Channel, ChannelError};
use crate::config::{Config, expand_path, ensure_workspace_dir}; use crate::channels::{ChannelManager, CliChatChannel};
use crate::config::{Config, ensure_workspace_dir, expand_path};
use crate::logging; use crate::logging;
use crate::mcp; use crate::mcp;
use crate::memory::MemoryManager; use crate::memory::MemoryManager;
use crate::session::SessionManager;
use crate::scheduler::Scheduler; use crate::scheduler::Scheduler;
use crate::session::SessionManager;
pub struct GatewayState { pub struct GatewayState {
pub config: Config, pub config: Config,
@ -32,8 +32,13 @@ impl GatewayState {
let workspace_path = ensure_workspace_dir(&workspace_path)?; let workspace_path = ensure_workspace_dir(&workspace_path)?;
// Switch current working directory to workspace // Switch current working directory to workspace
std::env::set_current_dir(&workspace_path) std::env::set_current_dir(&workspace_path).map_err(|e| {
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?; format!(
"Failed to switch to workspace directory {}: {}",
workspace_path.display(),
e
)
})?;
tracing::info!("Using workspace directory: {}", workspace_path.display()); tracing::info!("Using workspace directory: {}", workspace_path.display());
@ -52,8 +57,9 @@ impl GatewayState {
workspace_path.join("picobot.db") workspace_path.join("picobot.db")
}; };
let storage = Arc::new( let storage = Arc::new(
crate::storage::Storage::new(&db_path).await crate::storage::Storage::new(&db_path)
.map_err(|e| format!("failed to initialize session storage: {}", e))? .await
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
); );
tracing::info!("Session storage: {}", db_path.display()); tracing::info!("Session storage: {}", db_path.display());
@ -91,13 +97,16 @@ impl GatewayState {
bus.clone(), bus.clone(),
memory_manager, memory_manager,
browser_config, browser_config,
config.gateway.max_concurrent_background_tasks,
)?; )?;
let session_manager = Arc::new(session_manager); let session_manager = Arc::new(session_manager);
// Create ChannelManager and init channels // Create ChannelManager and init channels
let cli_chat_channel = Arc::new(CliChatChannel::new()); let cli_chat_channel = Arc::new(CliChatChannel::new());
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus); let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
channel_manager.init(&config, workspace_path.clone()).await channel_manager
.init(&config, workspace_path.clone())
.await
.map_err(|e| format!("Failed to init channels: {}", e))?; .map_err(|e| format!("Failed to init channels: {}", e))?;
// Register send_message tool with available channel names // Register send_message tool with available channel names
@ -106,9 +115,12 @@ impl GatewayState {
session_manager.register_outbound_tool(available_channels); session_manager.register_outbound_tool(available_channels);
// Register chat_manager tool // Register chat_manager tool
session_manager.tools().register( session_manager
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()), .tools()
); .register(crate::tools::ChatManagerTool::new(
storage.clone(),
valid_channels.clone(),
));
// Initialize MCP servers — connect and register discovered tools // Initialize MCP servers — connect and register discovered tools
if !config.mcp.servers.is_empty() { if !config.mcp.servers.is_empty() {
@ -129,24 +141,27 @@ impl GatewayState {
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default(); let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
if scheduler_config.enabled { if scheduler_config.enabled {
// Register cron tools // Register cron tools
session_manager.tools().register( session_manager
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels), .tools()
); .register(crate::tools::cron::CronAddTool::new(
session_manager.tools().register( storage.clone(),
crate::tools::cron::CronListTool::new(storage.clone()), valid_channels,
); ));
session_manager.tools().register( session_manager
crate::tools::cron::CronRemoveTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronListTool::new(storage.clone()));
session_manager.tools().register( session_manager
crate::tools::cron::CronEnableTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
session_manager.tools().register( session_manager
crate::tools::cron::CronDisableTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronEnableTool::new(storage.clone()));
session_manager.tools().register( session_manager
crate::tools::cron::CronUpdateTool::new(storage.clone()), .tools()
); .register(crate::tools::cron::CronDisableTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
tracing::info!("Cron tools registered"); tracing::info!("Cron tools registered");
} }
@ -267,71 +282,103 @@ impl GatewayState {
} }
/// Handle control messages (session management operations) /// Handle control messages (session management operations)
async fn handle_control_message( async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
session_manager: &SessionManager,
msg: ControlMessage,
) {
use crate::session::{SessionCommand::*, SessionEvent}; use crate::session::{SessionCommand::*, SessionEvent};
let reply_tx = msg.reply_tx; let reply_tx = msg.reply_tx;
let result: Result<SessionEvent, ChannelError> = match msg.op { let result: Result<SessionEvent, ChannelError> = match msg.op {
CreateDialog { channel, chat_id, title } => { CreateDialog {
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await channel,
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title }) chat_id,
.map_err(|e| ChannelError::Other(e.to_string())) title,
} } => session_manager
ListDialogs { channel, chat_id, include_archived } => { .create_dialog(&channel, &chat_id, title.as_deref())
session_manager.list_dialogs(&channel, &chat_id, include_archived).await .await
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id }) .map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string())) .map_err(|e| ChannelError::Other(e.to_string())),
} ListDialogs {
GetCurrentDialog { channel, chat_id } => { channel,
session_manager.get_current_dialog(&channel, &chat_id).await chat_id,
.map(|session_id| SessionEvent::CurrentDialog { session_id }) include_archived,
.map_err(|e| ChannelError::Other(e.to_string())) } => session_manager
} .list_dialogs(&channel, &chat_id, include_archived)
SwitchDialog { channel, chat_id, dialog_id } => { .await
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await .map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
.map(|session_id| SessionEvent::DialogSwitched { session_id }) dialogs,
.map_err(|e| ChannelError::Other(e.to_string())) current_dialog_id,
} })
RenameDialog { session_id, title } => { .map_err(|e| ChannelError::Other(e.to_string())),
session_manager.rename_dialog(&session_id, &title).await GetCurrentDialog { channel, chat_id } => session_manager
.map(|()| SessionEvent::DialogRenamed { session_id, title }) .get_current_dialog(&channel, &chat_id)
.map_err(|e| ChannelError::Other(e.to_string())) .await
} .map(|session_id| SessionEvent::CurrentDialog { session_id })
ArchiveDialog { session_id } => { .map_err(|e| ChannelError::Other(e.to_string())),
session_manager.archive_dialog(&session_id) SwitchDialog {
.map(|()| SessionEvent::DialogArchived { session_id }) channel,
.map_err(|e| ChannelError::Other(e.to_string())) chat_id,
} dialog_id,
DeleteDialog { session_id } => { } => session_manager
session_manager.delete_dialog(&session_id).await .switch_dialog(&channel, &chat_id, &dialog_id)
.map(|()| SessionEvent::DialogDeleted { session_id }) .await
.map_err(|e| ChannelError::Other(e.to_string())) .map(|session_id| SessionEvent::DialogSwitched { session_id })
} .map_err(|e| ChannelError::Other(e.to_string())),
ClearHistory { session_id } => { RenameDialog { session_id, title } => session_manager
session_manager.clear_dialog_history(&session_id) .rename_dialog(&session_id, &title)
.map(|()| SessionEvent::HistoryCleared { session_id }) .await
.map_err(|e| ChannelError::Other(e.to_string())) .map(|()| SessionEvent::DialogRenamed { session_id, title })
} .map_err(|e| ChannelError::Other(e.to_string())),
GetSlashCommands { channel: _, chat_id: _ } => { ArchiveDialog { session_id } => session_manager
.archive_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogArchived { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
DeleteDialog { session_id } => session_manager
.delete_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogDeleted { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
ClearHistory { session_id } => session_manager
.clear_dialog_history(&session_id)
.await
.map(|()| SessionEvent::HistoryCleared { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
GetSlashCommands {
channel: _,
chat_id: _,
} => {
let commands = session_manager.get_slash_commands().to_vec(); let commands = session_manager.get_slash_commands().to_vec();
Ok(SessionEvent::SlashCommandsList { commands }) Ok(SessionEvent::SlashCommandsList { commands })
} }
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => { ExecuteSlashCommand {
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref()) command,
.await args,
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg }) channel,
.map_err(|e| ChannelError::Other(e.to_string())) chat_id,
} current_session_id,
} => session_manager
.execute_slash_command(
&command,
args.as_deref(),
&channel,
&chat_id,
current_session_id.as_ref(),
)
.await
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
new_session_id: new_id,
message: msg,
})
.map_err(|e| ChannelError::Other(e.to_string())),
}; };
let _ = reply_tx.send(result).await; let _ = reply_tx.send(result).await;
} }
} }
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> { pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging // Initialize logging
logging::init_logging(); logging::init_logging();
tracing::info!("Starting PicoBot Gateway"); tracing::info!("Starting PicoBot Gateway");

View File

@ -1,12 +1,12 @@
use std::sync::Arc; use super::GatewayState;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use crate::protocol::WsOutbound;
use crate::protocol::serialize_outbound;
use axum::extract::State; use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response; use axum::response::Response;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::protocol::serialize_outbound;
use crate::protocol::WsOutbound;
use super::GatewayState;
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async move { ws.on_upgrade(|socket| async move {
@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await; let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
// Send session established message // Send session established message
let _ = sender.send(WsOutbound::SessionEstablished { let _ = sender
session_id: session_id.clone(), .send(WsOutbound::SessionEstablished {
}).await; session_id: session_id.clone(),
})
.await;
tracing::info!(session_id = %session_id, "CLI session established"); tracing::info!(session_id = %session_id, "CLI session established");
@ -37,9 +39,10 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
tokio::spawn(async move { tokio::spawn(async move {
while let Some(msg) = receiver.recv().await { while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) if let Ok(text) = serialize_outbound(&msg)
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() { && ws_sender.send(WsMessage::Text(text.into())).await.is_err()
break; {
} break;
}
} }
}); });

View File

@ -1,17 +1,17 @@
pub mod config;
pub mod providers;
pub mod bus;
pub mod agent; pub mod agent;
pub mod gateway; pub mod bus;
pub mod session;
pub mod client;
pub mod protocol;
pub mod channels; pub mod channels;
pub mod client;
pub mod config;
pub mod gateway;
pub mod logging; pub mod logging;
pub mod mcp; pub mod mcp;
pub mod memory; pub mod memory;
pub mod observability; pub mod observability;
pub mod protocol;
pub mod providers;
pub mod scheduler; pub mod scheduler;
pub mod session;
pub mod skills; pub mod skills;
pub mod storage; pub mod storage;
pub mod tools; pub mod tools;

View File

@ -1,11 +1,7 @@
use std::path::PathBuf; use std::path::PathBuf;
use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{ use tracing_subscriber::{
fmt, EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
layer::SubscriberExt,
util::SubscriberInitExt,
fmt::time::LocalTime,
EnvFilter,
}; };
/// Get the default log directory path: ~/.picobot/logs /// Get the default log directory path: ~/.picobot/logs
@ -27,20 +23,20 @@ pub fn init_logging() {
// Create log directory if it doesn't exist // Create log directory if it doesn't exist
if !log_dir.exists() if !log_dir.exists()
&& let Err(e) = std::fs::create_dir_all(&log_dir) { && let Err(e) = std::fs::create_dir_all(&log_dir)
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e); {
} eprintln!(
"Warning: Failed to create log directory {}: {}",
log_dir.display(),
e
);
}
// Create file appender with daily rotation // Create file appender with daily rotation
let file_appender = RollingFileAppender::new( let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
Rotation::DAILY,
&log_dir,
"picobot.log",
);
// Build subscriber with both console and file output // Build subscriber with both console and file output
let env_filter = EnvFilter::try_from_default_env() let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
.unwrap_or_else(|_| EnvFilter::new("info"));
let file_layer = fmt::layer() let file_layer = fmt::layer()
.with_writer(file_appender) .with_writer(file_appender)
@ -66,8 +62,7 @@ pub fn init_logging() {
/// Initialize logging without file output (console only) /// Initialize logging without file output (console only)
pub fn init_logging_console_only() { pub fn init_logging_console_only() {
let env_filter = EnvFilter::try_from_default_env() let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
.unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer() let console_layer = fmt::layer()
.with_timer(LocalTime::rfc_3339()) .with_timer(LocalTime::rfc_3339())

View File

@ -1,8 +1,9 @@
use clap::{Parser, CommandFactory}; use clap::{CommandFactory, Parser};
#[derive(Parser)] #[derive(Parser)]
#[command(name = "picobot")] #[command(name = "picobot")]
#[command(about = "A CLI chatbot", long_about = None)] #[command(about = "A CLI chatbot", long_about = None)]
#[command(version = "1.1.0")]
enum Command { enum Command {
/// Connect to gateway /// Connect to gateway
Chat { Chat {

View File

@ -92,24 +92,19 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
parts.push(text.text.clone()); parts.push(text.text.clone());
} }
RawContent::Image(image) => { RawContent::Image(image) => {
parts.push(format!( parts.push(format!("[image: {}]", image.mime_type,));
"[image: {}]",
image.mime_type,
));
} }
RawContent::Resource(resource) => { RawContent::Resource(resource) => match &resource.resource {
match &resource.resource { rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
rmcp::model::ResourceContents::TextResourceContents { text, .. } => { parts.push(format!(
parts.push(format!( "[resource text: {}]",
"[resource text: {}]", text.chars().take(200).collect::<String>(),
text.chars().take(200).collect::<String>(), ));
));
}
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
} }
} rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
},
_ => { _ => {
parts.push("[unsupported content]".to_string()); parts.push("[unsupported content]".to_string());
} }
@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
cmd.env(k, v); cmd.env(k, v);
} }
let service = () let service =
.serve( ().serve(
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?, TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
) )
.await .await
@ -261,14 +256,14 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
} else { } else {
StreamableHttpClientTransport::from_config( StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(url.to_string()) StreamableHttpClientTransportConfig::with_uri(url.to_string())
.custom_headers(headers_map) .custom_headers(headers_map),
) )
}; };
let service = () let service =
.serve(transport) ().serve(transport)
.await .await
.context("failed to connect to HTTP/SSE MCP server")?; .context("failed to connect to HTTP/SSE MCP server")?;
let peer = service.peer().clone(); let peer = service.peer().clone();

View File

@ -102,7 +102,11 @@ mod tests {
let dir = tempdir().unwrap(); let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db"); let db_path = dir.path().join("test.db");
let storage = Arc::new(Storage::new(&db_path).await.unwrap()); let storage = Arc::new(Storage::new(&db_path).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into())); let mm = Arc::new(MemoryManager::new(
storage,
"default".into(),
"default".into(),
));
(mm, dir) (mm, dir)
} }
@ -131,15 +135,9 @@ mod tests {
async fn test_upsert_overwrites() { async fn test_upsert_overwrites() {
let (mm, _dir) = setup_memory_manager().await; let (mm, _dir) = setup_memory_manager().await;
mm.store( mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
"dup_key", .await
"original", .unwrap();
MemoryCategory::Knowledge,
None,
None,
)
.await
.unwrap();
mm.store( mm.store(
"dup_key", "dup_key",
"updated", "updated",
@ -247,7 +245,12 @@ mod tests {
// Recall scoped to session A — should get only tl_a // Recall scoped to session A — should get only tl_a
let scoped = mm let scoped = mm
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a")) .recall(
"summary",
10,
Some(MemoryCategory::Timeline),
Some("chan:chat:dialog_a"),
)
.await .await
.unwrap(); .unwrap();
assert_eq!(scoped.len(), 1); assert_eq!(scoped.len(), 1);

View File

@ -20,10 +20,7 @@ pub enum ObserverEvent {
success: bool, success: bool,
}, },
/// Emitted when the agent starts processing. /// Emitted when the agent starts processing.
AgentStart { AgentStart { provider: String, model: String },
provider: String,
model: String,
},
/// Emitted when the agent finishes processing. /// Emitted when the agent finishes processing.
AgentEnd { AgentEnd {
provider: String, provider: String,
@ -94,7 +91,11 @@ impl ToolExecutionOutcome {
} }
/// Create a failed outcome with duration. /// Create a failed outcome with duration.
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self { pub fn failure_with_duration(
output: String,
error_reason: Option<String>,
duration: Duration,
) -> Self {
Self { Self {
output, output,
success: false, success: false,

View File

@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage; use crate::storage::Storage;
use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300; const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> { fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b { blocks
ContentBlock::Text { text } => { .iter()
serde_json::json!({ "type": "text", "text": text }) .map(|b| match b {
} ContentBlock::Text { text } => {
ContentBlock::ImageUrl { image_url } => { serde_json::json!({ "type": "text", "text": text })
convert_image_url_to_anthropic(&image_url.url) }
} ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
}).collect() })
.collect()
} }
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value { fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider {
}; };
let content = if let Some(ref tc_id) = m.tool_call_id { let content = if let Some(ref tc_id) = m.tool_call_id {
// Tool result: wrap as tool_result content block // Tool result: wrap as tool_result content block
let output = m.content.iter() let output = m
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None }) .content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(""); .join("");
vec![serde_json::json!({ vec![serde_json::json!({
@ -244,19 +250,18 @@ impl LLMProvider for AnthropicProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request"); tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await let resp = req_builder.json(&body).send().await.inspect_err(|e| {
.inspect_err(|e| { let is_timeout = e.is_timeout();
let is_timeout = e.is_timeout(); tracing::error!(
tracing::error!( provider = %self.name,
provider = %self.name, model = %self.model_id,
model = %self.model_id, url = %url,
url = %url, timeout = is_timeout,
timeout = is_timeout, error = %e,
error = %e, elapsed_ms = %start.elapsed().as_millis(),
elapsed_ms = %start.elapsed().as_millis(), "LLM API request failed"
"LLM API request failed" );
); })?;
})?;
let status = resp.status(); let status = resp.status();
let body_text = resp.text().await?; let body_text = resp.text().await?;
@ -281,32 +286,38 @@ impl LLMProvider for AnthropicProvider {
"LLM API returned error" "LLM API returned error"
); );
if let Some(ref storage) = self.storage { if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call( let _ = storage
&self.name, &self.model_id, &req_body_str, .append_llm_call(
Some(&body_text), Some(&error_msg), &self.name,
start.elapsed().as_millis() as u64, &self.model_id,
).await; &req_body_str,
Some(&body_text),
Some(&error_msg),
start.elapsed().as_millis() as u64,
)
.await;
} }
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
} }
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
.map_err(|e| { let err_msg = format!("decode error: {} | body: {}", e, &body_text);
let err_msg = format!("decode error: {} | body: {}", e, &body_text); if let Some(ref storage) = self.storage {
if let Some(ref storage) = self.storage { let name = self.name.clone();
let name = self.name.clone(); let model = self.model_id.clone();
let model = self.model_id.clone(); let req = req_body_str.clone();
let req = req_body_str.clone(); let resp_body = body_text.clone();
let resp_body = body_text.clone(); let dur = start.elapsed().as_millis() as u64;
let dur = start.elapsed().as_millis() as u64; let err = err_msg.clone();
let err = err_msg.clone(); let s = storage.clone();
let s = storage.clone(); tokio::spawn(async move {
tokio::spawn(async move { let _ = s
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await; .append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
}); .await;
} });
err_msg }
})?; err_msg
})?;
let mut content = String::new(); let mut content = String::new();
let mut reasoning = None; let mut reasoning = None;
@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider {
reasoning_content: reasoning, reasoning_content: reasoning,
tool_calls, tool_calls,
usage: Usage { usage: Usage {
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0), prompt_tokens: anthropic_resp
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), .usage
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), .as_ref()
.map(|u| u.input_tokens)
.unwrap_or(0),
completion_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.output_tokens)
.unwrap_or(0),
total_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.input_tokens + u.output_tokens)
.unwrap_or(0),
}, },
}; };
if let Some(ref storage) = self.storage { if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call( let _ = storage
&self.name, .append_llm_call(
&self.model_id, &self.name,
&req_body_str, &self.model_id,
Some(&body_text), &req_body_str,
None, Some(&body_text),
start.elapsed().as_millis() as u64, None,
).await; start.elapsed().as_millis() as u64,
)
.await;
} }
Ok(response) Ok(response)

View File

@ -1,12 +1,15 @@
pub mod traits;
pub mod openai;
pub mod anthropic; pub mod anthropic;
pub mod openai;
pub mod traits;
pub use self::openai::OpenAIProvider;
pub use self::anthropic::AnthropicProvider; pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage}; pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
ToolFunction, Usage,
};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> { pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() { match config.provider_type.as_str() {

View File

@ -1,29 +1,35 @@
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::Client; use reqwest::Client;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{Value, json};
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage; use crate::storage::Storage;
use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300; const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 if blocks.len() == 1
&& let ContentBlock::Text { text } = &blocks[0] { && let ContentBlock::Text { text } = &blocks[0]
return Value::String(text.clone()); {
} return Value::String(text.clone());
Value::Array(blocks.iter().map(|b| match b { }
ContentBlock::Text { text } => json!({ "type": "text", "text": text }), Value::Array(
ContentBlock::ImageUrl { image_url } => { blocks
json!({ "type": "image_url", "image_url": { "url": image_url.url } }) .iter()
} .map(|b| match b {
}).collect()) ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
})
.collect(),
)
} }
pub struct OpenAIProvider { pub struct OpenAIProvider {
@ -201,10 +207,14 @@ impl LLMProvider for OpenAIProvider {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() { for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) { && let Some(url_str) = item
let prefix: String = url_str.chars().take(20).collect(); .get("image_url")
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); .and_then(|u| u.get("url"))
} .and_then(|v| v.as_str())
{
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
} }
} }
} }
@ -224,19 +234,18 @@ impl LLMProvider for OpenAIProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request"); tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await let resp = req_builder.json(&body).send().await.inspect_err(|e| {
.inspect_err(|e| { let is_timeout = e.is_timeout();
let is_timeout = e.is_timeout(); tracing::error!(
tracing::error!( provider = %self.name,
provider = %self.name, model = %self.model_id,
model = %self.model_id, url = %url,
url = %url, timeout = is_timeout,
timeout = is_timeout, error = %e,
error = %e, elapsed_ms = %start.elapsed().as_millis(),
elapsed_ms = %start.elapsed().as_millis(), "LLM API request failed"
"LLM API request failed" );
); })?;
})?;
let status = resp.status(); let status = resp.status();
let text = resp.text().await?; let text = resp.text().await?;
@ -253,37 +262,48 @@ impl LLMProvider for OpenAIProvider {
"LLM API returned error" "LLM API returned error"
); );
if let Some(ref storage) = self.storage if let Some(ref storage) = self.storage
&& let Err(e) = storage.append_llm_call( && let Err(e) = storage
&self.name, &self.model_id, &req_body_str, .append_llm_call(
Some(&text), Some(&error), &self.name,
start.elapsed().as_millis() as u64, &self.model_id,
).await { &req_body_str,
tracing::warn!("failed to persist LLM call: {}", e); Some(&text),
} Some(&error),
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
return Err(error.into()); return Err(error.into());
} }
let openai_resp: OpenAIResponse = serde_json::from_str(&text) let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
.map_err(|e| { let err_msg = format!("decode error: {} | body: {}", e, &text);
let err_msg = format!("decode error: {} | body: {}", e, &text); if let Some(ref storage) = self.storage {
if let Some(ref storage) = self.storage { let name = self.name.clone();
let name = self.name.clone(); let model = self.model_id.clone();
let model = self.model_id.clone(); let req = req_body_str.clone();
let req = req_body_str.clone(); let resp = text.clone();
let resp = text.clone(); let dur = start.elapsed().as_millis() as u64;
let dur = start.elapsed().as_millis() as u64; let err = err_msg.clone();
let err = err_msg.clone(); let s = storage.clone();
let s = storage.clone(); tokio::spawn(async move {
tokio::spawn(async move { if let Err(e) = s
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await { .append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
tracing::warn!("failed to persist LLM call (decode error): {}", e); .await
} {
}); tracing::warn!("failed to persist LLM call (decode error): {}", e);
} }
err_msg });
})?; }
err_msg
})?;
let first_choice = openai_resp.choices.into_iter().next() let first_choice = openai_resp
.choices
.into_iter()
.next()
.ok_or("no choices in response")?; .ok_or("no choices in response")?;
let content = first_choice let content = first_choice
@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider {
.map(|tc| ToolCall { .map(|tc| ToolCall {
id: tc.id.clone(), id: tc.id.clone(),
name: tc.function.name.clone(), name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), arguments: serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Null),
}) })
.collect(); .collect();
@ -318,13 +339,19 @@ impl LLMProvider for OpenAIProvider {
}; };
if let Some(ref storage) = self.storage if let Some(ref storage) = self.storage
&& let Err(e) = storage.append_llm_call( && let Err(e) = storage
&self.name, &self.model_id, &req_body_str, .append_llm_call(
Some(&text), None, &self.name,
start.elapsed().as_millis() as u64, &self.model_id,
).await { &req_body_str,
tracing::warn!("failed to persist LLM call: {}", e); Some(&text),
} None,
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
Ok(response) Ok(response)
} }
@ -386,6 +413,9 @@ mod tests {
assert_eq!(tool_calls[0]["id"], "call_1"); assert_eq!(tool_calls[0]["id"], "call_1");
assert_eq!(tool_calls[0]["type"], "function"); assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "calculator"); assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); assert_eq!(
tool_calls[0]["function"]["arguments"],
"{\"expression\":\"1+1\"}"
);
} }
} }

View File

@ -1,6 +1,6 @@
use crate::bus::message::ContentBlock;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message { pub struct Message {
@ -61,7 +61,11 @@ impl Message {
} }
} }
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self { pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self { Self {
role: "tool".to_string(), role: "tool".to_string(),
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],

View File

@ -5,11 +5,11 @@ use std::time::Instant;
use tokio::time; use tokio::time;
use crate::config::SchedulerConfig; use crate::config::SchedulerConfig;
use crate::session::session::HandleResult;
use crate::session::SessionManager; use crate::session::SessionManager;
use crate::session::session::HandleResult;
use crate::storage::JobRun;
use crate::storage::ScheduledJob; use crate::storage::ScheduledJob;
use crate::storage::Storage; use crate::storage::Storage;
use crate::storage::JobRun;
pub use types::Schedule; pub use types::Schedule;
@ -89,7 +89,11 @@ impl Scheduler {
let now = now_ms(); let now = now_ms();
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await { let due = match self
.storage
.due_scheduled_jobs(now, self.config.max_concurrent)
.await
{
Ok(jobs) => jobs, Ok(jobs) => jobs,
Err(e) => { Err(e) => {
tracing::error!("scheduler: failed to query due jobs: {}", e); tracing::error!("scheduler: failed to query due jobs: {}", e);
@ -107,7 +111,11 @@ impl Scheduler {
let start = Instant::now(); let start = Instant::now();
let started_at = now_ms(); let started_at = now_ms();
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await { if let Err(e) = self
.storage
.touch_scheduled_job_last_run(&job.id, started_at)
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e); tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
continue; continue;
} }
@ -135,7 +143,10 @@ impl Scheduler {
match result { match result {
Ok(HandleResult::AgentResponse(output)) => { Ok(HandleResult::AgentResponse(output)) => {
let output_truncated = if output.len() > 8000 { let output_truncated = if output.len() > 8000 {
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)]) format!(
"{}...[truncated]",
&output[..output.ceil_char_boundary(8000)]
)
} else { } else {
output.clone() output.clone()
}; };
@ -155,7 +166,11 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e); tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
} }
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await { if let Err(e) = self
.storage
.set_scheduled_job_last_status(&job.id, "ok", None)
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e); tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
} }
@ -199,9 +214,11 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2); tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
} }
if let Err(e2) = self.storage.set_scheduled_job_last_status( if let Err(e2) = self
&job.id, "error", Some(&error_str), .storage
).await { .set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2); tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
} }
@ -231,17 +248,23 @@ impl Scheduler {
self.storage.remove_scheduled_job(&job.id).await?; self.storage.remove_scheduled_job(&job.id).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run"); tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
} else { } else {
self.storage.set_scheduled_job_enabled(&job.id, false).await?; self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run"); tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
} }
} }
Schedule::Every { .. } | Schedule::Cron { .. } => { Schedule::Every { .. } | Schedule::Cron { .. } => {
if let Some(next) = next_run_for_schedule(&job.schedule, now) { if let Some(next) = next_run_for_schedule(&job.schedule, now) {
self.storage.set_scheduled_job_next_run(&job.id, next).await?; self.storage
.set_scheduled_job_next_run(&job.id, next)
.await?;
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled"); tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
} else { } else {
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job"); tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
self.storage.set_scheduled_job_enabled(&job.id, false).await?; self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
} }
} }
} }

View File

@ -22,32 +22,20 @@ pub enum SessionCommand {
dialog_id: String, dialog_id: String,
}, },
/// Get the current dialog for a chat /// Get the current dialog for a chat
GetCurrentDialog { GetCurrentDialog { channel: String, chat_id: String },
channel: String,
chat_id: String,
},
/// Rename a dialog /// Rename a dialog
RenameDialog { RenameDialog {
session_id: UnifiedSessionId, session_id: UnifiedSessionId,
title: String, title: String,
}, },
/// Archive a dialog /// Archive a dialog
ArchiveDialog { ArchiveDialog { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Delete a dialog /// Delete a dialog
DeleteDialog { DeleteDialog { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Clear dialog history /// Clear dialog history
ClearHistory { ClearHistory { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Get list of available slash commands /// Get list of available slash commands
GetSlashCommands { GetSlashCommands { channel: String, chat_id: String },
channel: String,
chat_id: String,
},
/// Execute a slash command /// Execute a slash command
ExecuteSlashCommand { ExecuteSlashCommand {
command: String, command: String,
@ -60,7 +48,11 @@ pub enum SessionCommand {
impl SessionCommand { impl SessionCommand {
/// Create a CreateDialog command /// Create a CreateDialog command
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self { pub fn create_dialog(
channel: impl Into<String>,
chat_id: impl Into<String>,
title: Option<String>,
) -> Self {
Self::CreateDialog { Self::CreateDialog {
channel: channel.into(), channel: channel.into(),
chat_id: chat_id.into(), chat_id: chat_id.into(),
@ -69,7 +61,11 @@ impl SessionCommand {
} }
/// Create a ListDialogs command /// Create a ListDialogs command
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self { pub fn list_dialogs(
channel: impl Into<String>,
chat_id: impl Into<String>,
include_archived: bool,
) -> Self {
Self::ListDialogs { Self::ListDialogs {
channel: channel.into(), channel: channel.into(),
chat_id: chat_id.into(), chat_id: chat_id.into(),

View File

@ -1,5 +1,5 @@
use super::session_id::UnifiedSessionId;
use super::session::SlashCommand; use super::session::SlashCommand;
use super::session_id::UnifiedSessionId;
/// Dialog information returned by SessionManager /// Dialog information returned by SessionManager
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -30,30 +30,20 @@ pub enum SessionEvent {
session_id: Option<UnifiedSessionId>, session_id: Option<UnifiedSessionId>,
}, },
/// Dialog switched successfully /// Dialog switched successfully
DialogSwitched { DialogSwitched { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Dialog renamed /// Dialog renamed
DialogRenamed { DialogRenamed {
session_id: UnifiedSessionId, session_id: UnifiedSessionId,
title: String, title: String,
}, },
/// Dialog archived /// Dialog archived
DialogArchived { DialogArchived { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Dialog deleted /// Dialog deleted
DialogDeleted { DialogDeleted { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Dialog history cleared /// Dialog history cleared
HistoryCleared { HistoryCleared { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// List of available slash commands /// List of available slash commands
SlashCommandsList { SlashCommandsList { commands: Vec<SlashCommand> },
commands: Vec<SlashCommand>,
},
/// Slash command executed successfully /// Slash command executed successfully
SlashCommandExecuted { SlashCommandExecuted {
new_session_id: Option<UnifiedSessionId>, new_session_id: Option<UnifiedSessionId>,
@ -70,8 +60,5 @@ pub enum SessionEvent {
message_count: usize, message_count: usize,
}, },
/// Error occurred /// Error occurred
Error { Error { code: String, message: String },
code: String,
message: String,
},
} }

View File

@ -1,11 +1,11 @@
pub mod error;
pub mod commands; pub mod commands;
pub mod error;
pub mod events; pub mod events;
pub mod session; pub mod session;
pub mod session_id; pub mod session_id;
pub use error::SessionError;
pub use commands::SessionCommand; pub use commands::SessionCommand;
pub use events::{SessionEvent, DialogInfo}; pub use error::SessionError;
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS}; pub use events::{DialogInfo, SessionEvent};
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
pub use session_id::UnifiedSessionId; pub use session_id::UnifiedSessionId;

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,6 @@
/// ///
/// For simple cases where only one dialog exists per chat: /// For simple cases where only one dialog exists per chat:
/// - `dialog_id` defaults to `"default"` /// - `dialog_id` defaults to `"default"`
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub const DEFAULT_DIALOG_ID: &str = "default"; pub const DEFAULT_DIALOG_ID: &str = "default";
@ -22,7 +21,11 @@ pub struct UnifiedSessionId {
impl UnifiedSessionId { impl UnifiedSessionId {
/// Create a new UnifiedSessionId /// Create a new UnifiedSessionId
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self { pub fn new(
channel: impl Into<String>,
chat_id: impl Into<String>,
dialog_id: impl Into<String>,
) -> Self {
Self { Self {
channel: channel.into(), channel: channel.into(),
chat_id: chat_id.into(), chat_id: chat_id.into(),

View File

@ -1,6 +1,6 @@
use std::path::Path; use std::path::Path;
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS}; use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
pub fn install_builtin_skills(target_dir: &Path) { pub fn install_builtin_skills(target_dir: &Path) {
for skill in EMBEDDED_SKILLS { for skill in EMBEDDED_SKILLS {
@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) {
} }
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> { fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
let decompressed = zstd::decode_all(skill.data) let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
.map_err(|e| format!("zstd decode: {}", e))?;
let mut archive = tar::Archive::new(decompressed.as_slice()); let mut archive = tar::Archive::new(decompressed.as_slice());
archive archive

View File

@ -120,7 +120,11 @@ impl SkillsLoader {
let count = loaded.len(); let count = loaded.len();
let mut replaced = 0usize; let mut replaced = 0usize;
for skill in loaded { for skill in loaded {
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) { if let Some(existing) = state
.loaded_skills
.iter_mut()
.find(|s| s.name == skill.name)
{
*existing = skill; *existing = skill;
replaced += 1; replaced += 1;
} else { } else {
@ -138,33 +142,42 @@ impl SkillsLoader {
// Load from workspace skills dir (highest priority) — replace same-name skills // Load from workspace skills dir (highest priority) — replace same-name skills
if let Some(ref ws_dir) = self.workspace_skills_dir if let Some(ref ws_dir) = self.workspace_skills_dir
&& ws_dir.exists() { && ws_dir.exists()
let loaded = self.load_skills_from_dir(ws_dir); {
let count = loaded.len(); let loaded = self.load_skills_from_dir(ws_dir);
let mut replaced = 0usize; let count = loaded.len();
for skill in loaded { let mut replaced = 0usize;
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) { for skill in loaded {
*existing = skill; if let Some(existing) = state
replaced += 1; .loaded_skills
} else { .iter_mut()
state.loaded_skills.push(skill); .find(|s| s.name == skill.name)
} {
*existing = skill;
replaced += 1;
} else {
state.loaded_skills.push(skill);
} }
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
} }
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
}
state.last_load_time = SystemTime::now(); state.last_load_time = SystemTime::now();
if state.loaded_skills.is_empty() { if state.loaded_skills.is_empty() {
tracing::debug!("No skills found in any skills directory"); tracing::debug!("No skills found in any skills directory");
} else { } else {
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len()); tracing::info!(
count = state.loaded_skills.len(),
"Loaded {} skills total",
state.loaded_skills.len()
);
} }
} }
@ -215,18 +228,20 @@ impl SkillsLoader {
let mut max_mtime = None; let mut max_mtime = None;
if let Ok(metadata) = std::fs::metadata(dir) if let Ok(metadata) = std::fs::metadata(dir)
&& let Ok(mtime) = metadata.modified() { && let Ok(mtime) = metadata.modified()
max_mtime = Some(mtime); {
} max_mtime = Some(mtime);
}
if let Ok(entries) = std::fs::read_dir(dir) { if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() { for entry in entries.flatten() {
let path = entry.path(); let path = entry.path();
if let Ok(metadata) = std::fs::metadata(&path) if let Ok(metadata) = std::fs::metadata(&path)
&& let Ok(mtime) = metadata.modified() && let Ok(mtime) = metadata.modified()
&& max_mtime.is_none_or(|current| mtime > current) { && max_mtime.is_none_or(|current| mtime > current)
max_mtime = Some(mtime); {
} max_mtime = Some(mtime);
}
} }
} }
@ -244,7 +259,12 @@ impl SkillsLoader {
pub fn get_always_skills(&self) -> Vec<Skill> { pub fn get_always_skills(&self) -> Vec<Skill> {
self.reload_if_changed(); self.reload_if_changed();
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
state.loaded_skills.iter().filter(|s| s.always).cloned().collect() state
.loaded_skills
.iter()
.filter(|s| s.always)
.cloned()
.collect()
} }
/// Get a specific skill by name (checks for changes first) /// Get a specific skill by name (checks for changes first)
@ -258,7 +278,8 @@ impl SkillsLoader {
pub fn list_skills(&self) -> Vec<(String, String)> { pub fn list_skills(&self) -> Vec<(String, String)> {
self.reload_if_changed(); self.reload_if_changed();
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
state.loaded_skills state
.loaded_skills
.iter() .iter()
.map(|s| (s.name.clone(), s.description.clone())) .map(|s| (s.name.clone(), s.description.clone()))
.collect() .collect()
@ -279,15 +300,21 @@ impl SkillsLoader {
prompt.push_str("### 目录说明\n\n"); prompt.push_str("### 目录说明\n\n");
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill\n"); prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill\n");
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n"); prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n"); prompt.push_str(
prompt.push_str("安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n"); "- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n",
);
prompt.push_str(
"安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n",
);
// Always skills summary // Always skills summary
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect(); let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
if !always_skills.is_empty() { if !always_skills.is_empty() {
prompt.push_str("### 常用技能\n\n"); prompt.push_str("### 常用技能\n\n");
for skill in &always_skills { for skill in &always_skills {
let path_str = skill.path.as_ref() let path_str = skill
.path
.as_ref()
.map(|p| p.to_string_lossy().to_string()) .map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string()); .unwrap_or_else(|| "".to_string());
prompt.push_str(&format!( prompt.push_str(&format!(
@ -300,8 +327,12 @@ impl SkillsLoader {
// Usage instructions // Usage instructions
prompt.push_str("### 使用方法\n\n"); prompt.push_str("### 使用方法\n\n");
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n"); prompt.push_str(
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n"); "- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
);
prompt.push_str(
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
);
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n"); prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
// Always skills full content // Always skills full content
@ -338,25 +369,23 @@ impl SkillsLoader {
} }
match std::fs::read_to_string(&skill_file) { match std::fs::read_to_string(&skill_file) {
Ok(content) => { Ok(content) => match self.parse_skill(&path, &content) {
match self.parse_skill(&path, &content) { Some(skill) => {
Some(skill) => { tracing::debug!(
tracing::debug!( skill = %skill.name,
skill = %skill.name, path = %skill_file.display(),
path = %skill_file.display(), always = skill.always,
always = skill.always, "Loaded skill"
"Loaded skill" );
); skills.push(skill);
skills.push(skill);
}
None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
} }
} None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
},
Err(e) => { Err(e) => {
tracing::warn!( tracing::warn!(
path = %skill_file.display(), path = %skill_file.display(),
@ -447,7 +476,6 @@ impl Default for SkillsLoader {
} }
} }
/// Extract first non-empty, non-heading line as description /// Extract first non-empty, non-heading line as description
fn extract_description(content: &str) -> String { fn extract_description(content: &str) -> String {
content content

View File

@ -0,0 +1,19 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackgroundTask {
pub id: String,
pub session_id: String,
pub channel: String,
pub chat_id: String,
pub prompt: String,
pub allowed_tools: Option<String>,
pub status: String,
pub result: Option<String>,
pub error: Option<String>,
pub tool_calls_count: i64,
pub iterations: i64,
pub started_at: Option<i64>,
pub finished_at: Option<i64>,
pub created_at: i64,
}

View File

@ -241,12 +241,11 @@ impl super::Storage {
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64); let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
let cutoff_str = cutoff.to_rfc3339(); let cutoff_str = cutoff.to_rfc3339();
let result = sqlx::query( let result =
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?", sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
) .bind(&cutoff_str)
.bind(&cutoff_str) .execute(self.pool())
.execute(self.pool()) .await?;
.await?;
Ok(result.rows_affected()) Ok(result.rows_affected())
} }
@ -276,9 +275,7 @@ impl super::Storage {
} }
} }
fn parse_memory_rows( fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
rows: &[sqlx::sqlite::SqliteRow],
) -> Result<Vec<MemoryEntry>, StorageError> {
rows.iter() rows.iter()
.map(|row| { .map(|row| {
Ok(MemoryEntry { Ok(MemoryEntry {

View File

@ -1,15 +1,17 @@
pub mod background_task;
pub mod error; pub mod error;
pub mod memory; pub mod memory;
pub mod message; pub mod message;
pub mod scheduler; pub mod scheduler;
pub mod session; pub mod session;
pub use background_task::BackgroundTask;
pub use error::StorageError; pub use error::StorageError;
pub use scheduler::{JobRun, ScheduledJob}; pub use scheduler::{JobRun, ScheduledJob};
use sqlx::{Pool, Row, Sqlite, SqlitePool}; use sqlx::{Pool, Row, Sqlite, SqlitePool};
use tokio::time::{sleep, Duration};
use std::path::Path; use std::path::Path;
use tokio::time::{Duration, sleep};
pub struct Storage { pub struct Storage {
pub(crate) pool: Pool<Sqlite>, pub(crate) pool: Pool<Sqlite>,
@ -40,6 +42,7 @@ impl Storage {
last_active_at INTEGER NOT NULL, last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0, message_count INTEGER DEFAULT 0,
routing_info TEXT, routing_info TEXT,
archived_at INTEGER,
deleted_at INTEGER, deleted_at INTEGER,
last_consolidated_at INTEGER, last_consolidated_at INTEGER,
last_compressed_message_at INTEGER, last_compressed_message_at INTEGER,
@ -90,20 +93,58 @@ impl Storage {
.await?; .await?;
// Migration: add source column if upgrading from older schema // Migration: add source column if upgrading from older schema
sqlx::query( sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
r#"ALTER TABLE messages ADD COLUMN source TEXT"#, .execute(&self.pool)
) .await
.execute(&self.pool) .ok();
.await
.ok();
// Migration: add reasoning_content column if upgrading from older schema // Migration: add reasoning_content column if upgrading from older schema
sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
.execute(&self.pool)
.await
.ok();
// Background tasks table — for async sub-agent tasks.
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
// Session and task association is maintained at the application level.
sqlx::query( sqlx::query(
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#, r#"
CREATE TABLE IF NOT EXISTS background_tasks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
prompt TEXT NOT NULL,
allowed_tools TEXT,
status TEXT NOT NULL DEFAULT 'pending',
result TEXT,
error TEXT,
tool_calls_count INTEGER DEFAULT 0,
iterations INTEGER DEFAULT 0,
started_at INTEGER,
finished_at INTEGER,
created_at INTEGER NOT NULL
)
"#,
) )
.execute(&self.pool) .execute(&self.pool)
.await .await?;
.ok();
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query( sqlx::query(
r#" r#"
@ -172,11 +213,19 @@ impl Storage {
.await?; .await?;
// Rebuild FTS5 index for any existing records // Rebuild FTS5 index for any existing records
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
.execute(&self.pool)
.await?;
// Migration: add last_consolidated_at column if not exists
sqlx::query( sqlx::query(
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')", r#"
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
"#,
) )
.execute(&self.pool) .execute(&self.pool)
.await?; .await
.ok();
// Migration: add last_consolidated_at column if not exists // Migration: add last_consolidated_at column if not exists
sqlx::query( sqlx::query(
@ -216,7 +265,10 @@ impl Storage {
.await?; .await?;
if let Err(e) = Self::init_scheduler_schema(&self.pool).await { if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e); tracing::warn!(
"Failed to init scheduler schema (tables may already exist): {}",
e
);
} }
Ok(()) Ok(())
@ -330,16 +382,20 @@ impl Storage {
&self.pool &self.pool
} }
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> { pub async fn upsert_session(
&self,
meta: &crate::storage::session::SessionMeta,
) -> Result<(), StorageError> {
sqlx::query( sqlx::query(
r#" r#"
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at) INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET ON CONFLICT(id) DO UPDATE SET
title = excluded.title, title = excluded.title,
last_active_at = excluded.last_active_at, last_active_at = excluded.last_active_at,
message_count = excluded.message_count, message_count = excluded.message_count,
routing_info = excluded.routing_info, routing_info = excluded.routing_info,
archived_at = excluded.archived_at,
deleted_at = excluded.deleted_at, deleted_at = excluded.deleted_at,
last_consolidated_at = excluded.last_consolidated_at, last_consolidated_at = excluded.last_consolidated_at,
last_compressed_message_at = excluded.last_compressed_message_at last_compressed_message_at = excluded.last_compressed_message_at
@ -354,6 +410,7 @@ impl Storage {
.bind(meta.last_active_at) .bind(meta.last_active_at)
.bind(meta.message_count) .bind(meta.message_count)
.bind(&meta.routing_info) .bind(&meta.routing_info)
.bind(meta.archived_at)
.bind(meta.deleted_at) .bind(meta.deleted_at)
.bind(meta.last_consolidated_at) .bind(meta.last_consolidated_at)
.bind(meta.last_compressed_message_at) .bind(meta.last_compressed_message_at)
@ -363,10 +420,13 @@ impl Storage {
Ok(()) Ok(())
} }
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> { pub async fn get_session(
&self,
id: &str,
) -> Result<crate::storage::session::SessionMeta, StorageError> {
let row = sqlx::query( let row = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions WHERE id = ? AND deleted_at IS NULL FROM sessions WHERE id = ? AND deleted_at IS NULL
"#, "#,
) )
@ -385,6 +445,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -396,18 +457,21 @@ impl Storage {
channel: &str, channel: &str,
chat_id: &str, chat_id: &str,
limit: i64, limit: i64,
include_archived: bool,
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> { ) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
let rows = sqlx::query( let rows = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
AND (? OR archived_at IS NULL)
ORDER BY last_active_at DESC ORDER BY last_active_at DESC
LIMIT ? LIMIT ?
"#, "#,
) )
.bind(channel) .bind(channel)
.bind(chat_id) .bind(chat_id)
.bind(include_archived)
.bind(limit) .bind(limit)
.fetch_all(self.pool()) .fetch_all(self.pool())
.await?; .await?;
@ -424,6 +488,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -454,13 +519,22 @@ impl Storage {
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> { pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis(); let now = chrono::Utc::now().timestamp_millis();
sqlx::query( sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#, .bind(now)
) .bind(id)
.bind(now) .execute(self.pool())
.bind(id) .await?;
.execute(self.pool())
.await?; Ok(())
}
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(()) Ok(())
} }
@ -472,9 +546,9 @@ impl Storage {
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> { ) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
let row = sqlx::query( let row = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
ORDER BY last_active_at DESC ORDER BY last_active_at DESC
LIMIT 1 LIMIT 1
"#, "#,
@ -495,6 +569,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -503,7 +578,11 @@ impl Storage {
} }
} }
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> { pub async fn append_message(
&self,
session_id: &str,
msg: &crate::storage::message::MessageMeta,
) -> Result<i64, StorageError> {
sqlx::query( sqlx::query(
r#" r#"
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at) INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
@ -630,16 +709,15 @@ impl Storage {
offset: i64, offset: i64,
limit: i64, limit: i64,
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> { ) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
let count_row = sqlx::query( let count_row =
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL", sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
) .fetch_one(self.pool())
.fetch_one(self.pool()) .await?;
.await?;
let total: i64 = count_row.get("total"); let total: i64 = count_row.get("total");
let rows = sqlx::query( let rows = sqlx::query(
r#" r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions FROM sessions
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
ORDER BY last_active_at DESC ORDER BY last_active_at DESC
@ -663,6 +741,7 @@ impl Storage {
last_active_at: row.get("last_active_at"), last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"), message_count: row.get("message_count"),
routing_info: row.get("routing_info"), routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"), deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"), last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"), last_compressed_message_at: row.get("last_compressed_message_at"),
@ -728,7 +807,10 @@ impl Storage {
where_extra.push_str(" AND created_at > ?"); where_extra.push_str(" AND created_at > ?");
} }
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra); let count_sql = format!(
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
where_extra
);
let select_sql = format!( let select_sql = format!(
r#" r#"
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
@ -816,6 +898,148 @@ impl Storage {
} }
unreachable!() unreachable!()
} }
// ── Background Task CRUD ──
pub async fn create_background_task(
&self,
task: &crate::storage::background_task::BackgroundTask,
) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&task.id)
.bind(&task.session_id)
.bind(&task.channel)
.bind(&task.chat_id)
.bind(&task.prompt)
.bind(&task.allowed_tools)
.bind(&task.status)
.bind(&task.result)
.bind(&task.error)
.bind(task.tool_calls_count)
.bind(task.iterations)
.bind(task.started_at)
.bind(task.finished_at)
.bind(task.created_at)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn update_background_task_status(
&self,
id: &str,
status: &str,
result: Option<&str>,
error: Option<&str>,
started_at: Option<i64>,
finished_at: Option<i64>,
) -> Result<(), StorageError> {
sqlx::query(
r#"
UPDATE background_tasks
SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error),
started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at)
WHERE id = ?
"#,
)
.bind(status)
.bind(result)
.bind(error)
.bind(started_at)
.bind(finished_at)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn get_background_task(
&self,
id: &str,
) -> Result<crate::storage::background_task::BackgroundTask, StorageError> {
let row = sqlx::query(
r#"
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
tool_calls_count, iterations, started_at, finished_at, created_at
FROM background_tasks WHERE id = ?
"#,
)
.bind(id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
Ok(crate::storage::background_task::BackgroundTask {
id: row.get("id"),
session_id: row.get("session_id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
prompt: row.get("prompt"),
allowed_tools: row.get("allowed_tools"),
status: row.get("status"),
result: row.get("result"),
error: row.get("error"),
tool_calls_count: row.get("tool_calls_count"),
iterations: row.get("iterations"),
started_at: row.get("started_at"),
finished_at: row.get("finished_at"),
created_at: row.get("created_at"),
})
}
pub async fn list_background_tasks(
&self,
session_id: &str,
) -> Result<Vec<crate::storage::background_task::BackgroundTask>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
tool_calls_count, iterations, started_at, finished_at, created_at
FROM background_tasks
WHERE session_id = ?
ORDER BY created_at DESC
"#,
)
.bind(session_id)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| crate::storage::background_task::BackgroundTask {
id: row.get("id"),
session_id: row.get("session_id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
prompt: row.get("prompt"),
allowed_tools: row.get("allowed_tools"),
status: row.get("status"),
result: row.get("result"),
error: row.get("error"),
tool_calls_count: row.get("tool_calls_count"),
iterations: row.get("iterations"),
started_at: row.get("started_at"),
finished_at: row.get("finished_at"),
created_at: row.get("created_at"),
})
.collect())
}
pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result<usize, StorageError> {
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms;
let result = sqlx::query(
"DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?",
)
.bind(cutoff)
.execute(self.pool())
.await?;
Ok(result.rows_affected() as usize)
}
} }
#[cfg(test)] #[cfg(test)]
@ -844,6 +1068,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: Some(r#"{"type":"cli"}"#.to_string()), routing_info: Some(r#"{"type":"cli"}"#.to_string()),
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -880,14 +1105,18 @@ mod tests {
last_active_at: i as i64 * 1000, last_active_at: i as i64 * 1000,
message_count: i, message_count: i,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
}; };
storage.upsert_session(&meta).await.unwrap(); storage.upsert_session(&meta).await.unwrap();
} }
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap(); let sessions = storage
.list_sessions("cli_chat", "sid123", 10, false)
.await
.unwrap();
assert_eq!(sessions.len(), 5); assert_eq!(sessions.len(), 5);
// 按 last_active_at DESC 排序 // 按 last_active_at DESC 排序
assert_eq!(sessions[0].dialog_id, "dialog4"); assert_eq!(sessions[0].dialog_id, "dialog4");
@ -907,6 +1136,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -934,6 +1164,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -955,7 +1186,10 @@ mod tests {
created_at: 1000, created_at: 1000,
}; };
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap(); let seq = storage
.append_message(&session_meta.id, &msg)
.await
.unwrap();
assert_eq!(seq, 1); assert_eq!(seq, 1);
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap(); let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
@ -977,6 +1211,7 @@ mod tests {
last_active_at: 1000, last_active_at: 1000,
message_count: 0, message_count: 0,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,

View File

@ -165,7 +165,11 @@ impl crate::storage::Storage {
} }
/// Update next_run_at and last_run_at for a job. /// Update next_run_at and last_run_at for a job.
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> { pub async fn set_scheduled_job_next_run(
&self,
id: &str,
next_run_at: i64,
) -> anyhow::Result<()> {
let now = now_ms(); let now = now_ms();
sqlx::query( sqlx::query(
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?", "UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
@ -331,7 +335,9 @@ mod tests {
async fn setup_storage() -> Storage { async fn setup_storage() -> Storage {
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
let storage = Storage { pool }; let storage = Storage { pool };
Storage::init_scheduler_schema(storage.pool()).await.unwrap(); Storage::init_scheduler_schema(storage.pool())
.await
.unwrap();
storage storage
} }
@ -450,7 +456,10 @@ mod tests {
updated_at: t, updated_at: t,
}; };
storage.add_scheduled_job(&job).await.unwrap(); storage.add_scheduled_job(&job).await.unwrap();
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap(); storage
.set_scheduled_job_enabled("job-toggle", false)
.await
.unwrap();
let got = storage.get_scheduled_job("job-toggle").await.unwrap(); let got = storage.get_scheduled_job("job-toggle").await.unwrap();
assert!(!got.enabled); assert!(!got.enabled);
} }
@ -461,31 +470,55 @@ mod tests {
let t = now(); let t = now();
let jobs = vec![ let jobs = vec![
ScheduledJob { ScheduledJob {
id: "due".into(), name: "due".into(), id: "due".into(),
schedule: Schedule::At { at: t }, prompt: "1".into(), name: "due".into(),
channel: "cli_chat".into(), chat_id: "c".into(), schedule: Schedule::At { at: t },
model: None, enabled: true, delete_after_run: false, prompt: "1".into(),
next_run_at: t - 1000, last_run_at: None, channel: "cli_chat".into(),
last_status: None, last_error: None, chat_id: "c".into(),
created_at: t, updated_at: t, model: None,
enabled: true,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}, },
ScheduledJob { ScheduledJob {
id: "future".into(), name: "future".into(), id: "future".into(),
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(), name: "future".into(),
channel: "cli_chat".into(), chat_id: "c".into(), schedule: Schedule::At { at: t + 99999999 },
model: None, enabled: true, delete_after_run: false, prompt: "2".into(),
next_run_at: t + 99999999, last_run_at: None, channel: "cli_chat".into(),
last_status: None, last_error: None, chat_id: "c".into(),
created_at: t, updated_at: t, model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 99999999,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}, },
ScheduledJob { ScheduledJob {
id: "disabled-due".into(), name: "disabled due".into(), id: "disabled-due".into(),
schedule: Schedule::At { at: t }, prompt: "3".into(), name: "disabled due".into(),
channel: "cli_chat".into(), chat_id: "c".into(), schedule: Schedule::At { at: t },
model: None, enabled: false, delete_after_run: false, prompt: "3".into(),
next_run_at: t - 1000, last_run_at: None, channel: "cli_chat".into(),
last_status: None, last_error: None, chat_id: "c".into(),
created_at: t, updated_at: t, model: None,
enabled: false,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}, },
]; ];
for j in &jobs { for j in &jobs {
@ -501,24 +534,39 @@ mod tests {
let storage = setup_storage().await; let storage = setup_storage().await;
let t = now(); let t = now();
let job = ScheduledJob { let job = ScheduledJob {
id: "job-run".into(), name: "run test".into(), id: "job-run".into(),
name: "run test".into(),
schedule: Schedule::Every { every_ms: 1000 }, schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(), prompt: "hi".into(),
model: None, enabled: true, delete_after_run: false, channel: "cli_chat".into(),
next_run_at: t, last_run_at: None, chat_id: "c".into(),
last_status: None, last_error: None, model: None,
created_at: t, updated_at: t, enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}; };
storage.add_scheduled_job(&job).await.unwrap(); storage.add_scheduled_job(&job).await.unwrap();
let run = super::JobRun { let run = super::JobRun {
id: 0, job_id: "job-run".into(), id: 0,
started_at: t, finished_at: t + 500, job_id: "job-run".into(),
status: "ok".into(), output: Some("hello".into()), started_at: t,
error: None, duration_ms: 500, finished_at: t + 500,
status: "ok".into(),
output: Some("hello".into()),
error: None,
duration_ms: 500,
}; };
storage.record_scheduled_job_run(&run).await.unwrap(); storage.record_scheduled_job_run(&run).await.unwrap();
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap(); let runs = storage
.list_scheduled_job_runs("job-run", 10)
.await
.unwrap();
assert_eq!(runs.len(), 1); assert_eq!(runs.len(), 1);
assert_eq!(runs[0].status, "ok"); assert_eq!(runs[0].status, "ok");
assert_eq!(runs[0].output.as_deref(), Some("hello")); assert_eq!(runs[0].output.as_deref(), Some("hello"));
@ -529,22 +577,34 @@ mod tests {
let storage = setup_storage().await; let storage = setup_storage().await;
let t = now(); let t = now();
let job = ScheduledJob { let job = ScheduledJob {
id: "job-update".into(), name: "old name".into(), id: "job-update".into(),
name: "old name".into(),
schedule: Schedule::Every { every_ms: 1000 }, schedule: Schedule::Every { every_ms: 1000 },
prompt: "old prompt".into(), channel: "feishu".into(), prompt: "old prompt".into(),
chat_id: "oc_1".into(), model: None, channel: "feishu".into(),
enabled: true, delete_after_run: false, chat_id: "oc_1".into(),
next_run_at: t, last_run_at: None, model: None,
last_status: None, last_error: None, enabled: true,
created_at: t, updated_at: t, delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}; };
storage.add_scheduled_job(&job).await.unwrap(); storage.add_scheduled_job(&job).await.unwrap();
storage.update_scheduled_job( storage
"job-update", .update_scheduled_job(
Some("new prompt".into()), "job-update",
Some(Schedule::Every { every_ms: 60000 }), Some("new prompt".into()),
None, None, None, Some(Schedule::Every { every_ms: 60000 }),
).await.unwrap(); None,
None,
None,
)
.await
.unwrap();
let got = storage.get_scheduled_job("job-update").await.unwrap(); let got = storage.get_scheduled_job("job-update").await.unwrap();
assert_eq!(got.prompt, "new prompt"); assert_eq!(got.prompt, "new prompt");
} }

View File

@ -11,6 +11,7 @@ pub struct SessionMeta {
pub last_active_at: i64, pub last_active_at: i64,
pub message_count: i64, pub message_count: i64,
pub routing_info: Option<String>, pub routing_info: Option<String>,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>, pub deleted_at: Option<i64>,
pub last_consolidated_at: Option<i64>, pub last_consolidated_at: Option<i64>,
pub last_compressed_message_at: Option<i64>, pub last_compressed_message_at: Option<i64>,

View File

@ -167,10 +167,7 @@ impl Tool for BashTool {
Err(_) => Ok(ToolResult { Err(_) => Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(format!( error: Some(format!("Command timed out after {} seconds", timeout_secs)),
"Command timed out after {} seconds",
timeout_secs
)),
}), }),
} }
} }
@ -249,10 +246,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_pwd_command() { async fn test_pwd_command() {
let tool = BashTool::new(); let tool = BashTool::new();
let result = tool let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
.execute(json!({ "command": "pwd" }))
.await
.unwrap();
assert!(result.success); assert!(result.success);
} }
@ -260,7 +254,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_ls_command() { async fn test_ls_command() {
let tool = BashTool::new(); let tool = BashTool::new();
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap(); let result = tool
.execute(json!({ "command": "ls -la /tmp" }))
.await
.unwrap();
assert!(result.success); assert!(result.success);
} }

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use base64::Engine; use base64::Engine;
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT}; use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
use fantoccini::key::Key; use fantoccini::key::Key;
use fantoccini::{Client, ClientBuilder, Locator}; use fantoccini::{Client, ClientBuilder, Locator};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -63,7 +63,9 @@ impl BrowserTool {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum BrowserAction { pub enum BrowserAction {
Open { url: String }, Open {
url: String,
},
Snapshot { Snapshot {
#[serde(default)] #[serde(default)]
interactive_only: bool, interactive_only: bool,
@ -72,10 +74,20 @@ pub enum BrowserAction {
#[serde(default)] #[serde(default)]
depth: Option<i64>, depth: Option<i64>,
}, },
Click { selector: String }, Click {
Fill { selector: String, value: String }, selector: String,
Type { selector: Option<String>, text: String }, },
GetText { selector: String }, Fill {
selector: String,
value: String,
},
Type {
selector: Option<String>,
text: String,
},
GetText {
selector: String,
},
GetTitle, GetTitle,
GetUrl, GetUrl,
Screenshot { Screenshot {
@ -84,7 +96,9 @@ pub enum BrowserAction {
#[serde(default)] #[serde(default)]
return_base64: bool, return_base64: bool,
}, },
Focus { selector: String }, Focus {
selector: String,
},
Wait { Wait {
#[serde(default)] #[serde(default)]
selector: Option<String>, selector: Option<String>,
@ -93,9 +107,16 @@ pub enum BrowserAction {
#[serde(default)] #[serde(default)]
text: Option<String>, text: Option<String>,
}, },
Press { key: String }, Press {
Hover { selector: String }, key: String,
ClickAt { x: u32, y: u32 }, },
Hover {
selector: String,
},
ClickAt {
x: u32,
y: u32,
},
Scroll { Scroll {
direction: String, direction: String,
#[serde(default)] #[serde(default)]
@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.get("interactive_only") .get("interactive_only")
.and_then(Value::as_bool) .and_then(Value::as_bool)
.unwrap_or(true), .unwrap_or(true),
compact: args compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
.get("compact") depth: args.get("depth").and_then(|v| v.as_i64()),
.and_then(Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(|v| v.as_i64()),
}), }),
"click" => { "click" => {
let selector = args let selector = args
@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(String::from), .map(String::from),
ms: args.get("ms").and_then(|v| v.as_u64()), ms: args.get("ms").and_then(|v| v.as_u64()),
text: args text: args.get("text").and_then(|v| v.as_str()).map(String::from),
.get("text")
.and_then(|v| v.as_str())
.map(String::from),
}), }),
"press" => { "press" => {
let key = args let key = args
@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
let x = args let x = args
.get("x") .get("x")
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32; .ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
as u32;
let y = args let y = args
.get("y") .get("y")
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32; .ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
as u32;
Ok(BrowserAction::ClickAt { x, y }) Ok(BrowserAction::ClickAt { x, y })
} }
other => anyhow::bail!("Unsupported browser action: {}", other), other => anyhow::bail!("Unsupported browser action: {}", other),
@ -488,7 +503,11 @@ impl BrowserState {
} }
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
} }
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed"); tracing::debug!(
action = "fill",
output_len = value.len(),
"Browser action completed"
);
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!("Filled {} with {}", selector, value), output: format!("Filled {} with {}", selector, value),
@ -573,7 +592,10 @@ impl BrowserState {
error: None, error: None,
}) })
} }
BrowserAction::Screenshot { path, return_base64 } => { BrowserAction::Screenshot {
path,
return_base64,
} => {
let client = self.active_client()?; let client = self.active_client()?;
let png = client.screenshot().await?; let png = client.screenshot().await?;
let save_path = path.unwrap_or_else(|| { let save_path = path.unwrap_or_else(|| {
@ -588,14 +610,25 @@ impl BrowserState {
tokio::fs::write(&save_path, &png).await?; tokio::fs::write(&save_path, &png).await?;
if return_base64 { if return_base64 {
let b64 = base64::engine::general_purpose::STANDARD.encode(&png); let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed"); tracing::debug!(
action = "screenshot",
output_len = b64.len(),
"Browser action completed"
);
return Ok(ToolResult { return Ok(ToolResult {
success: true, success: true,
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64), output: format!(
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
save_path, b64
),
error: None, error: None,
}); });
} }
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed"); tracing::debug!(
action = "screenshot",
output_len = save_path.len(),
"Browser action completed"
);
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!("Screenshot saved to {}", save_path), output: format!("Screenshot saved to {}", save_path),
@ -611,18 +644,18 @@ impl BrowserState {
vec![serde_json::to_value(el)?], vec![serde_json::to_value(el)?],
) )
.await?; .await?;
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed"); tracing::debug!(
action = "focus",
output_len = selector.len(),
"Browser action completed"
);
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!("Focused {}", selector), output: format!("Focused {}", selector),
error: None, error: None,
}) })
} }
BrowserAction::Wait { BrowserAction::Wait { selector, ms, text } => {
selector,
ms,
text,
} => {
if let Some(sel) = selector { if let Some(sel) = selector {
let client = self.active_client()?; let client = self.active_client()?;
wait_for_selector(client, &sel).await?; wait_for_selector(client, &sel).await?;
@ -719,9 +752,21 @@ impl BrowserState {
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or(""); let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or(""); let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or(""); let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") }; let id_str = if id.is_empty() {
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") }; String::new()
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") }; } else {
format!("#{id}")
};
let type_str = if el_type.is_empty() {
String::new()
} else {
format!("[type={el_type}]")
};
let text_str = if text.is_empty() {
String::new()
} else {
format!(" ({text})")
};
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}") format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
} }
None => format!("Clicked at ({}, {})", x, y), None => format!("Clicked at ({}, {})", x, y),
@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String {
} }
fn xpath_contains_text(text: &str) -> String { fn xpath_contains_text(text: &str) -> String {
format!( format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
"//*[contains(normalize-space(.), {})]",
xpath_literal(text)
)
} }
fn xpath_literal(input: &str) -> String { fn xpath_literal(input: &str) -> String {
@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String {
"pagedown" => Key::PageDown.to_string(), "pagedown" => Key::PageDown.to_string(),
"space" => " ".to_string(), "space" => " ".to_string(),
other => { other => {
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other); tracing::warn!(
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
other
);
other.to_string() other.to_string()
} }
} }

View File

@ -659,10 +659,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_evaluate_missing_expression() { async fn test_evaluate_missing_expression() {
let tool = CalculatorTool::new(); let tool = CalculatorTool::new();
let result = tool let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
.execute(json!({"function": "evaluate"}))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
} }

View File

@ -126,7 +126,10 @@ impl ChatManagerTool {
let start_num = offset + 1; let start_num = offset + 1;
let end_num = offset + sessions.len() as i64; let end_num = offset + sessions.len() as i64;
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num); let mut output = format!(
"全部会话 (共 {} 个,第 {}-{} 个):\n",
total, start_num, end_num
);
for s in &sessions { for s in &sessions {
let ago = format_duration_ago(now_ms - s.last_active_at); let ago = format_duration_ago(now_ms - s.last_active_at);
@ -300,9 +303,10 @@ mod tests {
last_active_at: now - i * 3600_000, last_active_at: now - i * 3600_000,
message_count: i * 5, message_count: i * 5,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
}; };
storage.upsert_session(&meta).await.unwrap(); storage.upsert_session(&meta).await.unwrap();
} }
@ -335,6 +339,7 @@ mod tests {
last_active_at: now, last_active_at: now,
message_count: 3, message_count: 3,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -346,7 +351,11 @@ mod tests {
id: format!("msg{}", i), id: format!("msg{}", i),
session_id: session_id.to_string(), session_id: session_id.to_string(),
seq: i as i64 + 1, seq: i as i64 + 1,
role: if i == 0 { "user".to_string() } else { "assistant".to_string() }, role: if i == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
content: format!("消息内容 {}", i), content: format!("消息内容 {}", i),
reasoning_content: None, reasoning_content: None,
media_refs: None, media_refs: None,
@ -392,6 +401,7 @@ mod tests {
last_active_at: now, last_active_at: now,
message_count: 5, message_count: 5,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -403,7 +413,11 @@ mod tests {
id: format!("msg{}", i), id: format!("msg{}", i),
session_id: session_id.to_string(), session_id: session_id.to_string(),
seq: i as i64 + 1, seq: i as i64 + 1,
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() }, role: if i % 2 == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
content: format!("消息内容 {}", i), content: format!("消息内容 {}", i),
reasoning_content: None, reasoning_content: None,
media_refs: None, media_refs: None,
@ -447,6 +461,7 @@ mod tests {
last_active_at: now, last_active_at: now,
message_count: 5, message_count: 5,
routing_info: None, routing_info: None,
archived_at: None,
deleted_at: None, deleted_at: None,
last_consolidated_at: None, last_consolidated_at: None,
last_compressed_message_at: None, last_compressed_message_at: None,
@ -492,10 +507,7 @@ mod tests {
let (storage, _dir) = create_test_storage().await; let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec![]); let tool = ChatManagerTool::new(storage, vec![]);
let result = tool let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
.execute(json!({ "action": "unknown" }))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action")); assert!(result.error.unwrap().contains("Unknown action"));
} }

View File

@ -31,10 +31,7 @@ impl ContentSearchTool {
for (i, line) in lines.iter().enumerate() { for (i, line) in lines.iter().enumerate() {
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS { if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
let omitted = lines.len() - i; let omitted = lines.len() - i;
output.push_str(&format!( output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
"\n... ({} matches omitted) ...",
omitted
));
break; break;
} }
if !output.is_empty() { if !output.is_empty() {
@ -113,18 +110,40 @@ impl Tool for ContentSearchTool {
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str())); let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str()); let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false); let case_sensitive = args
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize; .get("case_sensitive")
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize; .and_then(|v| v.as_bool())
.unwrap_or(false);
let context_lines = args
.get("context_lines")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await; let result = self
.run_search(
pattern,
&dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await;
match result { match result {
Ok(lines) => { Ok(lines) => {
let count = lines.len(); let count = lines.len();
let mut output = self.truncate_output(&lines); let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 条匹配", count)); output.push_str(&format!("\n\n---\n{} 条匹配", count));
Ok(ToolResult { success: true, output, error: None }) Ok(ToolResult {
success: true,
output,
error: None,
})
} }
Err(e) => Ok(ToolResult { Err(e) => Ok(ToolResult {
success: false, success: false,
@ -146,22 +165,52 @@ impl ContentSearchTool {
max_results: usize, max_results: usize,
) -> anyhow::Result<Vec<String>> { ) -> anyhow::Result<Vec<String>> {
if which::which("rg").is_ok() { if which::which("rg").is_ok() {
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await { match self
.search_with_rg(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
Ok(lines) => return Ok(lines), Ok(lines) => return Ok(lines),
Err(e) => tracing::warn!("rg failed: {}, falling back", e), Err(e) => tracing::warn!("rg failed: {}, falling back", e),
} }
} }
if which::which("grep").is_ok() { if which::which("grep").is_ok() {
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await { match self
.search_with_grep(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
Ok(lines) if !lines.is_empty() => return Ok(lines), Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}, Ok(_) => {}
Err(e) => tracing::warn!("grep failed: {}, falling back", e), Err(e) => tracing::warn!("grep failed: {}, falling back", e),
} }
} }
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."); tracing::warn!(
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await "No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
);
self.search_with_rust(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
} }
async fn search_with_rg( async fn search_with_rg(
@ -176,8 +225,10 @@ impl ContentSearchTool {
let mut cmd = Command::new("rg"); let mut cmd = Command::new("rg");
cmd.arg("-n") cmd.arg("-n")
.arg("--no-heading") .arg("--no-heading")
.arg("--color").arg("never") .arg("--color")
.arg("--max-count").arg(max_results.to_string()) .arg("never")
.arg("--max-count")
.arg(max_results.to_string())
.arg(pattern) .arg(pattern)
.arg(dir) .arg(dir)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
@ -193,12 +244,9 @@ impl ContentSearchTool {
cmd.arg("--glob").arg(fp); cmd.arg("--glob").arg(fp);
} }
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() && output.status.code() != Some(1) { if !output.status.success() && output.status.code() != Some(1) {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
@ -206,7 +254,8 @@ impl ContentSearchTool {
} }
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines() let lines: Vec<String> = text
.lines()
.take(max_results) .take(max_results)
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@ -242,15 +291,13 @@ impl ContentSearchTool {
cmd.arg("--include").arg(fp); cmd.arg("--include").arg(fp);
} }
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines() let lines: Vec<String> = text
.lines()
.take(max_results) .take(max_results)
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@ -280,7 +327,9 @@ impl ContentSearchTool {
if case_sensitive { if case_sensitive {
regex::Regex::new(&re_str) regex::Regex::new(&re_str)
} else { } else {
regex::RegexBuilder::new(&re_str).case_insensitive(true).build() regex::RegexBuilder::new(&re_str)
.case_insensitive(true)
.build()
} }
}); });
@ -291,7 +340,14 @@ impl ContentSearchTool {
}; };
let mut results = Vec::new(); let mut results = Vec::new();
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?; grep_dir(
Path::new(dir),
Path::new(dir),
&re,
file_re.as_ref(),
&mut results,
max_results,
)?;
Ok(results) Ok(results)
} }
@ -350,16 +406,19 @@ fn grep_dir(
if path.is_dir() { if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str()) if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.') && name.len() > 1 { && name.starts_with('.')
continue; && name.len() > 1
} {
continue;
}
grep_dir(base, &path, re, file_re, results, max)?; grep_dir(base, &path, re, file_re, results, max)?;
} else if path.is_file() { } else if path.is_file() {
if let Some(file_re) = file_re if let Some(file_re) = file_re
&& let Some(name) = rel.file_name().and_then(|n| n.to_str()) && let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& !file_re.is_match(name) { && !file_re.is_match(name)
continue; {
} continue;
}
if let Ok(content) = std::fs::read_to_string(&path) { if let Ok(content) = std::fs::read_to_string(&path) {
for (line_num, line) in content.lines().enumerate() { for (line_num, line) in content.lines().enumerate() {
@ -391,8 +450,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_content_search_rust_fallback() { async fn test_content_search_rust_fallback() {
let dir = TempDir::new().unwrap(); let dir = TempDir::new().unwrap();
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap(); fs::write(
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap(); dir.path().join("main.rs"),
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
)
.unwrap();
fs::write(
dir.path().join("lib.rs"),
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
)
.unwrap();
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap(); fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
let tool = ContentSearchTool::new(); let tool = ContentSearchTool::new();

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::{json, Value}; use serde_json::{Value, json};
use uuid::Uuid; use uuid::Uuid;
use crate::scheduler::{next_run_for_schedule, Schedule}; use crate::scheduler::{Schedule, next_run_for_schedule};
use crate::storage::{ScheduledJob, Storage}; use crate::storage::{ScheduledJob, Storage};
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
@ -229,10 +229,7 @@ impl Tool for CronListTool {
} }
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> { async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let filter = args let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("all");
let jobs = self.storage.list_scheduled_jobs().await?; let jobs = self.storage.list_scheduled_jobs().await?;
let filtered: Vec<&ScheduledJob> = match filter { let filtered: Vec<&ScheduledJob> = match filter {
@ -397,7 +394,9 @@ impl Tool for CronEnableTool {
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?; .map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
let next = next_run_for_schedule(&job.schedule, now_ms()); let next = next_run_for_schedule(&job.schedule, now_ms());
self.storage.set_scheduled_job_enabled(&job_id, true).await?; self.storage
.set_scheduled_job_enabled(&job_id, true)
.await?;
if let Some(n) = next { if let Some(n) = next {
self.storage.set_scheduled_job_next_run(&job_id, n).await?; self.storage.set_scheduled_job_next_run(&job_id, n).await?;
} }
@ -464,7 +463,9 @@ impl Tool for CronDisableTool {
.get_scheduled_job(&job_id) .get_scheduled_job(&job_id)
.await .await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?; .map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
self.storage.set_scheduled_job_enabled(&job_id, false).await?; self.storage
.set_scheduled_job_enabled(&job_id, false)
.await?;
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
@ -580,7 +581,9 @@ impl Tool for CronUpdateTool {
if args.get("schedule").is_some() { if args.get("schedule").is_some() {
let job = self.storage.get_scheduled_job(&job_id).await?; let job = self.storage.get_scheduled_job(&job_id).await?;
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) { if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
self.storage.set_scheduled_job_next_run(&job_id, next).await?; self.storage
.set_scheduled_job_next_run(&job_id, next)
.await?;
} }
} }
@ -765,9 +768,7 @@ mod tests {
let job = ScheduledJob { let job = ScheduledJob {
id: "job-update-tool".into(), id: "job-update-tool".into(),
name: "old".into(), name: "old".into(),
schedule: Schedule::Every { schedule: Schedule::Every { every_ms: 3600000 },
every_ms: 3600000,
},
prompt: "old prompt".into(), prompt: "old prompt".into(),
channel: "feishu".into(), channel: "feishu".into(),
chat_id: "oc_1".into(), chat_id: "oc_1".into(),

390
src/tools/delegate.rs Normal file
View File

@ -0,0 +1,390 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus};
use crate::tools::traits::{Tool, ToolResult};
pub struct DelegateTool {
sub_agent_manager: Arc<SubAgentManager>,
}
impl DelegateTool {
pub fn new(sub_agent_manager: Arc<SubAgentManager>) -> Self {
Self { sub_agent_manager }
}
}
#[async_trait]
impl Tool for DelegateTool {
fn name(&self) -> &str {
"delegate"
}
fn description(&self) -> &str {
"子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\
inline ()background ()\
parallel ( Agent )\
"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["run", "check_task", "cancel_task", "list_tasks"],
"description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务"
},
"prompt": {
"type": "string",
"description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件。action=run 时必填"
},
"mode": {
"type": "string",
"enum": ["inline", "background", "parallel"],
"description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline"
},
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator"
},
"max_iterations": {
"type": "integer",
"description": "最大迭代次数,默认 99"
},
"timeout_secs": {
"type": "integer",
"description": "超时秒数,默认 36001小时"
},
"tasks": {
"type": "array",
"description": "并行模式下的多个子任务(仅 mode=parallel 时使用)",
"items": {
"type": "object",
"properties": {
"prompt": { "type": "string", "description": "子任务描述" },
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "该子任务的工具列表"
}
},
"required": ["prompt"]
}
},
"task_id": {
"type": "string",
"description": "后台任务IDaction=check_task/cancel_task 时必填)"
}
},
"required": ["action"]
})
}
fn read_only(&self) -> bool {
false
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args["action"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
match action {
"run" => self.handle_run(&args).await,
"check_task" => self.handle_check_task(&args).await,
"cancel_task" => self.handle_cancel_task(&args).await,
"list_tasks" => self.handle_list_tasks(&args).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
action
)),
}),
}
}
}
impl DelegateTool {
fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result<SubAgentConfig> {
let prompt = args["prompt"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
.to_string();
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
let timeout_secs = args["timeout_secs"].as_u64();
Ok(SubAgentConfig {
prompt,
mode: ExecutionMode::Inline,
allowed_tools,
max_iterations,
timeout_secs,
})
}
async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let mode_str = args["mode"].as_str().unwrap_or("inline");
let mode = match mode_str {
"inline" => ExecutionMode::Inline,
"background" => ExecutionMode::Background,
"parallel" => ExecutionMode::Parallel,
_ => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"unknown mode: {}. Supported: inline, background, parallel",
mode_str
)),
});
}
};
match mode {
ExecutionMode::Inline => {
let config = self.parse_config_from_args(args)?;
let result = self
.sub_agent_manager
.run_inline(config)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
match result.status {
TaskStatus::Completed => Ok(ToolResult {
success: true,
output: result.content,
error: None,
}),
TaskStatus::Failed(err) => Ok(ToolResult {
success: false,
output: result.content,
error: Some(err),
}),
TaskStatus::TimedOut => Ok(ToolResult {
success: false,
output: result.content,
error: Some("sub-agent timed out".into()),
}),
TaskStatus::Cancelled => Ok(ToolResult {
success: false,
output: result.content,
error: Some("sub-agent cancelled".into()),
}),
}
}
ExecutionMode::Background => {
let config = self.parse_config_from_args(args)?;
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
anyhow::anyhow!("delegate context not available: not in an agent worker")
})?;
let task_id = self
.sub_agent_manager
.run_background(config, ctx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
Ok(ToolResult {
success: true,
output: format!("后台任务已启动。\ntask_id: {}", task_id),
error: None,
})
}
ExecutionMode::Parallel => {
let tasks = args["tasks"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?;
let mut configs = Vec::new();
for task in tasks {
let prompt = task["prompt"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
.to_string();
let allowed_tools: Option<Vec<String>> =
task["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
configs.push(SubAgentConfig {
prompt,
mode: ExecutionMode::Inline,
allowed_tools,
max_iterations: args["max_iterations"].as_u64().map(|v| v as usize),
timeout_secs: args["timeout_secs"].as_u64(),
});
}
let has_args_allowed = args["allowed_tools"].as_array().is_some();
for c in &mut configs {
if c.allowed_tools.is_none() && has_args_allowed {
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
}
}
let results = self
.sub_agent_manager
.run_parallel(configs)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
let mut output = String::new();
for (i, r) in results.iter().enumerate() {
let status_icon = match r.status {
TaskStatus::Completed => "",
TaskStatus::Failed(_) => "",
TaskStatus::TimedOut => "⏱️ 超时",
TaskStatus::Cancelled => "🚫 已取消",
};
output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon));
if !r.content.is_empty() {
output.push_str(&r.content);
output.push_str("\n\n");
}
if let TaskStatus::Failed(ref err) = r.status {
output.push_str(&format!("错误: {}\n\n", err));
}
}
let all_success = results
.iter()
.all(|r| matches!(r.status, TaskStatus::Completed));
Ok(ToolResult {
success: all_success,
output: output.trim().to_string(),
error: None,
})
}
}
}
async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let task_id = args["task_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
match self.sub_agent_manager.check_task(task_id).await {
Some(task) => {
let status_icon = match task.status.as_str() {
"completed" => "✅ 已完成",
"failed" => "❌ 失败",
"cancelled" => "🚫 已取消",
"running" => "🔄 运行中",
"pending" => "⏳ 等待中",
_ => task.status.as_str(),
};
let mut output = format!(
"任务 ID: {}\n状态: {}\n任务: {}",
task.id, status_icon, task.prompt
);
if let Some(ref result) = task.result {
output.push_str(&format!("\n\n结果:\n{}", result));
}
if let Some(ref error) = task.error {
output.push_str(&format!("\n错误: {}", error));
}
if let Some(started) = task.started_at {
if let Some(finished) = task.finished_at {
let duration = (finished - started) as f64 / 1000.0;
output.push_str(&format!("\n耗时: {:.1}s", duration));
}
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
None => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("task not found: {}", task_id)),
}),
}
}
async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let task_id = args["task_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
match self.sub_agent_manager.cancel_task(task_id).await {
Ok(true) => Ok(ToolResult {
success: true,
output: format!("后台任务 {} 已取消", task_id),
error: None,
}),
Ok(false) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)),
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("取消失败: {}", e)),
}),
}
}
async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let ctx = crate::agent::sub_agent::get_delegate_context()
.map_err(|_| anyhow::anyhow!("delegate context not available"))?;
let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await;
if tasks.is_empty() {
return Ok(ToolResult {
success: true,
output: "没有后台任务".to_string(),
error: None,
});
}
let mut output = String::from("后台任务列表:\n\n");
for task in &tasks {
let status_icon = match task.status.as_str() {
"completed" => "",
"failed" => "",
"cancelled" => "🚫",
"running" => "🔄",
"pending" => "",
_ => "",
};
output.push_str(&format!(
"{} {} - {} - {} (created: {})\n",
status_icon,
&task.id[..std::cmp::min(8, task.id.len())],
task.prompt.chars().take(60).collect::<String>(),
task.status,
task.created_at,
));
}
output.push_str(&format!("\n{} 个任务", tasks.len()));
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}

View File

@ -243,8 +243,8 @@ impl Tool for FileEditTool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::NamedTempFile;
use std::io::Write; use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test] #[tokio::test]
async fn test_edit_simple() { async fn test_edit_simple() {

View File

@ -181,10 +181,7 @@ impl Tool for FileReadTool {
} }
result = lines[..end_idx].join("\n"); result = lines[..end_idx].join("\n");
let truncated = original_len - result.len(); let truncated = original_len - result.len();
result.push_str(&format!( result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
"\n\n... ({} chars truncated) ...",
truncated
));
} }
if end < total { if end < total {
@ -196,10 +193,7 @@ impl Tool for FileReadTool {
end + 1 end + 1
)); ));
} else { } else {
result.push_str(&format!( result.push_str(&format!("\n\n(End of file — {} lines total)", total));
"\n\n(End of file — {} lines total)",
total
));
} }
if let Some(label) = encoding_label { if let Some(label) = encoding_label {
@ -214,7 +208,7 @@ impl Tool for FileReadTool {
} }
None => { None => {
// Truly binary file — base64 encode // Truly binary file — base64 encode
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{Engine, engine::general_purpose::STANDARD};
let encoded = STANDARD.encode(&bytes); let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved) let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream() .first_or_octet_stream()
@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::NamedTempFile;
use std::io::Write; use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test] #[tokio::test]
async fn test_read_simple_file() { async fn test_read_simple_file() {
@ -338,10 +332,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_is_directory() { async fn test_is_directory() {
let tool = FileReadTool::new(); let tool = FileReadTool::new();
let result = tool let result = tool.execute(json!({ "path": "." })).await.unwrap();
.execute(json!({ "path": "." }))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file")); assert!(result.error.unwrap().contains("Not a file"));

View File

@ -101,17 +101,29 @@ impl Tool for FileSearchTool {
}; };
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str())); let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true); let case_sensitive = args
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize; .get("case_sensitive")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await; let result = self
.run_search(pattern, &dir, case_sensitive, max_results)
.await;
match result { match result {
Ok(lines) => { Ok(lines) => {
let count = lines.len(); let count = lines.len();
let mut output = self.truncate_output(&lines); let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 个文件", count)); output.push_str(&format!("\n\n---\n{} 个文件", count));
Ok(ToolResult { success: true, output, error: None }) Ok(ToolResult {
success: true,
output,
error: None,
})
} }
Err(e) => Ok(ToolResult { Err(e) => Ok(ToolResult {
success: false, success: false,
@ -139,9 +151,12 @@ impl FileSearchTool {
}; };
if !fd_cmd.is_empty() { if !fd_cmd.is_empty() {
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await { match self
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
.await
{
Ok(lines) if !lines.is_empty() => return Ok(lines), Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}, Ok(_) => {}
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e), Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
} }
} }
@ -149,13 +164,14 @@ impl FileSearchTool {
if which::which("find").is_ok() { if which::which("find").is_ok() {
match self.search_with_find(pattern, dir, max_results).await { match self.search_with_find(pattern, dir, max_results).await {
Ok(lines) if !lines.is_empty() => return Ok(lines), Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}, Ok(_) => {}
Err(e) => tracing::warn!("find failed: {}, falling back", e), Err(e) => tracing::warn!("find failed: {}, falling back", e),
} }
} }
tracing::warn!("No fd/find available, using built-in file search (slower)"); tracing::warn!("No fd/find available, using built-in file search (slower)");
self.search_with_rust(pattern, dir, case_sensitive, max_results).await self.search_with_rust(pattern, dir, case_sensitive, max_results)
.await
} }
async fn search_with_fd( async fn search_with_fd(
@ -167,11 +183,15 @@ impl FileSearchTool {
fd_cmd: &str, fd_cmd: &str,
) -> anyhow::Result<Vec<String>> { ) -> anyhow::Result<Vec<String>> {
let mut cmd = Command::new(fd_cmd); let mut cmd = Command::new(fd_cmd);
cmd.arg("--search-path").arg(dir) cmd.arg("--search-path")
.arg("--glob").arg(pattern) .arg(dir)
.arg("--color").arg("never") .arg("--glob")
.arg(pattern)
.arg("--color")
.arg("never")
.arg("--strip-cwd-prefix") .arg("--strip-cwd-prefix")
.arg("--max-results").arg(max_results.to_string()) .arg("--max-results")
.arg(max_results.to_string())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()); .stderr(Stdio::piped());
@ -179,12 +199,9 @@ impl FileSearchTool {
cmd.arg("--ignore-case"); cmd.arg("--ignore-case");
} }
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() { if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
@ -192,7 +209,8 @@ impl FileSearchTool {
} }
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines() let lines: Vec<String> = text
.lines()
.filter(|l| !l.is_empty()) .filter(|l| !l.is_empty())
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@ -215,15 +233,13 @@ impl FileSearchTool {
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::null()); .stderr(Stdio::null());
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let mut lines: Vec<String> = text.lines() let mut lines: Vec<String> = text
.lines()
.filter(|l| !l.is_empty()) .filter(|l| !l.is_empty())
.map(|l| { .map(|l| {
let p = Path::new(l); let p = Path::new(l);
@ -254,7 +270,13 @@ impl FileSearchTool {
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?; .map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
let mut results = Vec::new(); let mut results = Vec::new();
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?; walk_dir(
Path::new(dir),
Path::new(dir),
&re,
&mut results,
max_results,
)?;
Ok(results) Ok(results)
} }
} }
@ -311,15 +333,18 @@ fn walk_dir(
if path.is_dir() { if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str()) if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.') && name.len() > 1 { && name.starts_with('.')
continue; && name.len() > 1
} {
continue;
}
walk_dir(base, &path, re, results, max)?; walk_dir(base, &path, re, results, max)?;
} else if path.is_file() { } else if path.is_file() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str()) if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& re.is_match(name) { && re.is_match(name)
results.push(rel.to_string_lossy().to_string()); {
} results.push(rel.to_string_lossy().to_string());
}
if results.len() >= max { if results.len() >= max {
return Ok(()); return Ok(());
} }

View File

@ -90,13 +90,14 @@ impl Tool for FileWriteTool {
// Create parent directories if needed // Create parent directories if needed
if let Some(parent) = resolved.parent() if let Some(parent) = resolved.parent()
&& !parent.exists() && !parent.exists()
&& let Err(e) = std::fs::create_dir_all(parent) { && let Err(e) = std::fs::create_dir_all(parent)
return Ok(ToolResult { {
success: false, return Ok(ToolResult {
output: String::new(), success: false,
error: Some(format!("Failed to create parent directory: {}", e)), output: String::new(),
}); error: Some(format!("Failed to create parent directory: {}", e)),
} });
}
match std::fs::write(&resolved, content) { match std::fs::write(&resolved, content) {
Ok(_) => Ok(ToolResult { Ok(_) => Ok(ToolResult {
@ -168,10 +169,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_write_missing_path() { async fn test_write_missing_path() {
let tool = FileWriteTool::new(); let tool = FileWriteTool::new();
let result = tool let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("path")); assert!(result.error.unwrap().contains("path"));

View File

@ -129,7 +129,9 @@ impl GetSkillTool {
let mut output = format!("可用 skill (共 {} 个):\n", skills.len()); let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
for s in &skills { for s in &skills {
let always_mark = if s.always { " [常驻]" } else { "" }; let always_mark = if s.always { " [常驻]" } else { "" };
let path_str = s.path.as_ref() let path_str = s
.path
.as_ref()
.map(|p| p.to_string_lossy().to_string()) .map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string()); .unwrap_or_else(|| "".to_string());
output.push_str(&format!( output.push_str(&format!(
@ -148,10 +150,10 @@ impl GetSkillTool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::tempdir;
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use std::path::PathBuf; use std::path::PathBuf;
use tempfile::tempdir;
#[tokio::test] #[tokio::test]
async fn test_get_existing_skill() { async fn test_get_existing_skill() {

View File

@ -50,10 +50,7 @@ impl HttpRequestTool {
} }
if !host_matches_allowlist(&host, &self.allowed_domains) { if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!( return Err(format!("Host '{}' is not in allowed_domains", host));
"Host '{}' is not in allowed_domains",
host
));
} }
Ok(url.to_string()) Ok(url.to_string())
@ -80,11 +77,10 @@ impl HttpRequestTool {
for (key, value) in obj { for (key, value) in obj {
if let Some(str_val) = value.as_str() if let Some(str_val) = value.as_str()
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) && let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
&& let Ok(val) = && let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
reqwest::header::HeaderValue::from_str(str_val) {
{ header_map.insert(name, val);
header_map.insert(name, val); }
}
} }
} }
@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
allowed_domains.iter().any(|domain| { allowed_domains.iter().any(|domain| {
host == domain host == domain
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.')) || host
.strip_suffix(domain)
.is_some_and(|prefix| prefix.ends_with('.'))
}) })
} }
@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool {
} }
// Check .local TLD // Check .local TLD
if host.rsplit('.').next().is_some_and(|label| label == "local") { if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true; return true;
} }
@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast() || v4.is_broadcast()
|| v4.is_multicast() || v4.is_multicast()
} }
std::net::IpAddr::V6(v6) => { std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
} }
} }
@ -278,10 +278,7 @@ impl Tool for HttpRequestTool {
} }
}; };
let method_str = args let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({})); let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str()); let body = args.get("body").and_then(|v| v.as_str());

View File

@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool {
.and_then(|v| v.as_i64()) .and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis()); .unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory self.memory
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None) .recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Knowledge),
None,
)
.await? .await?
} else { } else {
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await? self.memory
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
.await?
}; };
if entries.is_empty() { if entries.is_empty() {
@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool {
let formatted = entries let formatted = entries
.iter() .iter()
.map(|e| { .map(|e| {
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
format!( format!(
"- {} [{}]{} [importance: {:.1}]: {}", "- {} [{}]{} [importance: {:.1}]: {}",
e.key, e.key,
@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool {
.and_then(|v| v.as_i64()) .and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis()); .unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory self.memory
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id) .recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Timeline),
session_id,
)
.await? .await?
} else { } else {
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await? self.memory
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
.await?
}; };
if entries.is_empty() { if entries.is_empty() {
@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool {
let formatted = entries let formatted = entries
.iter() .iter()
.map(|e| { .map(|e| {
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
format!( format!(
"- {} [{}]{} [importance: {:.1}]: {}", "- {} [{}]{} [importance: {:.1}]: {}",
e.key, e.key,

View File

@ -4,6 +4,7 @@ pub mod calculator;
pub mod chat_manager; pub mod chat_manager;
pub mod content_search; pub mod content_search;
pub mod cron; pub mod cron;
pub mod delegate;
pub mod file_edit; pub mod file_edit;
pub mod file_read; pub mod file_read;
pub mod file_search; pub mod file_search;
@ -23,6 +24,7 @@ pub use browser::BrowserTool;
pub use calculator::CalculatorTool; pub use calculator::CalculatorTool;
pub use chat_manager::ChatManagerTool; pub use chat_manager::ChatManagerTool;
pub use content_search::ContentSearchTool; pub use content_search::ContentSearchTool;
pub use delegate::DelegateTool;
pub use file_edit::FileEditTool; pub use file_edit::FileEditTool;
pub use file_read::FileReadTool; pub use file_read::FileReadTool;
pub use file_search::FileSearchTool; pub use file_search::FileSearchTool;
@ -35,10 +37,11 @@ pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use traits::{OutboundMessenger, Tool, ToolResult};
pub use web_fetch::WebFetchTool; pub use web_fetch::WebFetchTool;
use std::sync::Arc; use crate::agent::SubAgentManager;
use crate::config::BrowserConfig; use crate::config::BrowserConfig;
use crate::memory::MemoryManager; use crate::memory::MemoryManager;
use crate::skills::SkillsLoader; use crate::skills::SkillsLoader;
use std::sync::Arc;
/// Create the base tool registry (without send_message). /// Create the base tool registry (without send_message).
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()` /// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
@ -46,6 +49,7 @@ use crate::skills::SkillsLoader;
pub fn create_default_tools( pub fn create_default_tools(
skills_loader: Arc<SkillsLoader>, skills_loader: Arc<SkillsLoader>,
memory: Arc<MemoryManager>, memory: Arc<MemoryManager>,
sub_agent_manager: Option<Arc<SubAgentManager>>,
browser_config: Option<&BrowserConfig>, browser_config: Option<&BrowserConfig>,
) -> ToolRegistry { ) -> ToolRegistry {
let registry = ToolRegistry::new(); let registry = ToolRegistry::new();
@ -76,5 +80,9 @@ pub fn create_default_tools(
} }
} }
if let Some(mgr) = sub_agent_manager {
registry.register(DelegateTool::new(mgr));
}
registry registry
} }

608
src/tools/pty.rs Normal file
View File

@ -0,0 +1,608 @@
use std::collections::{HashMap, VecDeque};
use std::io::Write;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
const MAX_OUTPUT_LINES: usize = 50_000;
const MAX_CHARS_PER_LINE: usize = 2_000;
const MAX_SESSIONS: usize = 10;
fn guard_command(command: &str) -> Option<String> {
let deny_patterns: &[&str] = &[
r"\brm\s+-[rf]{1,2}\b",
r"\bdel\s+/[fq]\b",
r"\brmdir\s+/s\b",
r":\(\)\s*\{.*\};\s*:",
];
let lower = command.to_lowercase();
for pattern in deny_patterns {
if regex::Regex::new(pattern)
.ok()
.map(|re| re.is_match(&lower))
.unwrap_or(false)
{
return Some(format!(
"Command blocked by safety guard (dangerous pattern: {})",
pattern
));
}
}
None
}
fn unescape_control_chars(s: &str) -> String {
let mut result = String::with_capacity(s.len());
let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
if c == '\\' {
match chars.next() {
Some('n') => result.push('\n'),
Some('r') => result.push('\r'),
Some('t') => result.push('\t'),
Some('x') => {
let hex: String = chars.by_ref().take(2).collect();
if hex.len() == 2 {
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
result.push(byte as char);
} else {
result.push_str(&format!("\\x{}", hex));
}
} else {
result.push_str(&format!("\\x{}", hex));
}
}
Some('\\') => result.push('\\'),
Some(other) => {
result.push('\\');
result.push(other);
}
None => result.push('\\'),
}
} else {
result.push(c);
}
}
result
}
fn truncate_command(cmd: &str, max_len: usize) -> String {
let cmd = cmd.trim();
let first_arg = cmd.split_whitespace().next().unwrap_or(cmd);
if first_arg.len() > max_len {
format!("{}...", &first_arg[..max_len])
} else {
first_arg.to_string()
}
}
fn status_str(status: &SessionStatus) -> &str {
match status {
SessionStatus::Running => "running",
SessionStatus::Exited(_) => "exited",
SessionStatus::Killed => "killed",
}
}
#[derive(Debug, Clone, PartialEq)]
enum SessionStatus {
Running,
Exited(i32),
Killed,
}
struct PtySession {
#[allow(dead_code)]
id: String,
#[allow(dead_code)]
command: String,
#[allow(dead_code)]
started_at: Instant,
status: SessionStatus,
child: Arc<Mutex<Option<Box<dyn portable_pty::Child + Send + Sync>>>>,
writer: Arc<Mutex<Option<Box<dyn Write + Send>>>>,
output_buffer: VecDeque<String>,
output_total_lines: usize,
}
impl PtySession {
fn new(
id: String,
command: String,
child: Box<dyn portable_pty::Child + Send + Sync>,
writer: Box<dyn Write + Send>,
) -> Self {
Self {
id,
command,
started_at: Instant::now(),
status: SessionStatus::Running,
child: Arc::new(Mutex::new(Some(child))),
writer: Arc::new(Mutex::new(Some(writer))),
output_buffer: VecDeque::new(),
output_total_lines: 0,
}
}
fn push_line(&mut self, line: String) {
let line = if line.len() > MAX_CHARS_PER_LINE {
format!("{}...<truncated>", &line[..MAX_CHARS_PER_LINE])
} else {
line
};
self.output_total_lines += 1;
if self.output_buffer.len() >= MAX_OUTPUT_LINES {
self.output_buffer.pop_front();
}
self.output_buffer.push_back(line);
}
}
pub struct PtyManager {
sessions: Mutex<HashMap<String, Arc<Mutex<PtySession>>>>,
}
impl PtyManager {
pub fn new() -> Self {
Self {
sessions: Mutex::new(HashMap::new()),
}
}
pub fn cleanup_all(&self) {
let sessions: Vec<Arc<Mutex<PtySession>>> = {
let mut guard = self.sessions.lock().unwrap();
guard.drain().map(|(_, s)| s).collect()
};
for session in sessions {
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
}
*child_guard = None;
guard.status = SessionStatus::Killed;
}
}
fn spawn(&self, command: &str) -> Result<String, String> {
if let Some(reason) = guard_command(command) {
return Err(reason);
}
let mut sessions = self.sessions.lock().unwrap();
if sessions.len() >= MAX_SESSIONS {
return Err(format!(
"Max sessions ({}) reached, kill some sessions first",
MAX_SESSIONS
));
}
let session_id = format!("pty_{}", crate::util::short_id());
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
let pty_system = portable_pty::native_pty_system();
let pty_pair = pty_system
.openpty(portable_pty::PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| format!("Failed to open PTY: {}", e))?;
let mut cmd = portable_pty::CommandBuilder::new("bash");
cmd.args(&["-c", command]);
cmd.cwd(cwd);
let child = pty_pair
.slave
.spawn_command(cmd)
.map_err(|e| format!("Failed to spawn: {}", e))?;
let pid = child.process_id().unwrap_or(0);
let writer = pty_pair
.master
.take_writer()
.map_err(|e| format!("Failed to take writer: {}", e))?;
let reader = pty_pair
.master
.try_clone_reader()
.map_err(|e| format!("Failed to clone reader: {}", e))?;
let session_id_clone = session_id.clone();
let session = PtySession::new(session_id_clone, command.to_string(), child, writer);
let session = Arc::new(Mutex::new(session));
sessions.insert(session_id.clone(), session.clone());
let session_for_reader = session.clone();
let child_for_reader = session_for_reader.lock().unwrap().child.clone();
tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; 4096];
let mut partial = String::new();
loop {
use std::io::Read;
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let s = String::from_utf8_lossy(&buf[..n]);
partial.push_str(&s);
let lines: Vec<&str> = partial.split('\n').collect();
if lines.len() <= 1 {
continue;
}
let complete = lines.len() - 1;
let mut guard = session_for_reader.lock().unwrap();
for line in lines[..complete].iter() {
guard.push_line(line.to_string());
}
partial = lines[complete].to_string();
}
Err(_) => break,
}
}
let mut guard = session_for_reader.lock().unwrap();
if !partial.is_empty() {
guard.push_line(partial);
}
let exit_code = {
let mut cg = child_for_reader.lock().unwrap();
if let Some(ref mut c) = *cg {
c.wait().map(|s| s.exit_code() as i32).ok()
} else {
None
}
};
guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1));
});
Ok(format!("session_id: {}, pid: {}", session_id, pid))
}
fn write(&self, session_id: &str, data: &str) -> Result<String, String> {
let unescaped = unescape_control_chars(data);
let byte_count = unescaped.len();
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
if guard.status != SessionStatus::Running {
return Err("Session is not running".to_string());
}
let writer = guard.writer.clone();
drop(guard);
drop(sessions);
let mut writer_guard = writer.lock().unwrap();
match *writer_guard {
Some(ref mut w) => {
w.write_all(unescaped.as_bytes())
.map_err(|e| format!("Write error: {}", e))?;
w.flush().map_err(|e| format!("Flush error: {}", e))?;
}
None => return Err("Writer not available".to_string()),
}
Ok(format!("OK, wrote {} bytes", byte_count))
}
fn read(
&self,
session_id: &str,
offset: usize,
limit: usize,
) -> Result<String, String> {
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let guard = session.lock().unwrap();
let total = guard.output_total_lines;
let buffer_len = guard.output_buffer.len();
let start = 0_usize.max(offset);
let skip_old = total.saturating_sub(buffer_len);
let view_start = start.saturating_sub(skip_old);
let lines: Vec<String> = guard
.output_buffer
.iter()
.skip(view_start)
.take(limit)
.enumerate()
.map(|(i, line)| format!("{}: {}", start + i, line))
.collect();
let displayed = start + lines.len();
let has_more = displayed < total;
let mut output = format!(
"# Lines {}-{} (共 {} 行{})\n",
start,
displayed.saturating_sub(1),
total,
if has_more { ",还有更多" } else { "" }
);
output.push_str(&lines.join("\n"));
if has_more {
output.push_str(&format!(
"\n[还有 {} 行未显示,用 offset={} 继续读取]",
total.saturating_sub(displayed),
displayed
));
}
Ok(output)
}
fn kill(&self, session_id: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let _ = child.wait();
}
*child_guard = None;
guard.status = SessionStatus::Killed;
drop(child_guard);
drop(guard);
sessions.remove(session_id);
Ok(format!("Session {} killed", session_id))
}
fn list(&self) -> String {
let sessions = self.sessions.lock().unwrap();
if sessions.is_empty() {
return "No active PTY sessions".to_string();
}
let mut lines: Vec<String> = sessions
.iter()
.map(|(id, session)| {
let guard = session.lock().unwrap();
let age = Instant::now().duration_since(guard.started_at);
let age_str = if age.as_secs() < 60 {
format!("{}s ago", age.as_secs())
} else {
format!("{}m ago", age.as_secs() / 60)
};
format!(
"{:<14} {:<12} {:<10} {:<6} lines {}",
id,
truncate_command(&guard.command, 10),
status_str(&guard.status),
guard.output_total_lines,
age_str,
)
})
.collect();
lines.sort();
lines.join("\n")
}
}
impl Drop for PtyManager {
fn drop(&mut self) {
self.cleanup_all();
}
}
pub struct PtyTool {
pty_manager: Arc<PtyManager>,
}
impl PtyTool {
pub fn new(pty_manager: Arc<PtyManager>) -> Self {
Self { pty_manager }
}
}
#[async_trait]
impl Tool for PtyTool {
fn name(&self) -> &str {
"pty"
}
fn description(&self) -> &str {
"管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["spawn", "write", "read", "kill", "list"],
"description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话"
},
"session_id": {
"type": "string",
"description": "会话ID (write/read/kill 需要)"
},
"command": {
"type": "string",
"description": "要执行的命令 (spawn 需要)"
},
"data": {
"type": "string",
"description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)"
},
"offset": {
"type": "integer",
"description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)",
"minimum": 0
},
"limit": {
"type": "integer",
"description": "读取的最大行数 (read 可选,默认 500)",
"minimum": 1
}
},
"required": ["action"]
})
}
fn exclusive(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = match args.get("action").and_then(|v| v.as_str()) {
Some(a) => a,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: action".to_string()),
});
}
};
match action {
"spawn" => {
let command = match args.get("command").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: command".to_string()),
});
}
};
match self.pty_manager.spawn(command) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"write" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
let data = match args.get("data").and_then(|v| v.as_str()) {
Some(d) => d,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: data".to_string()),
});
}
};
match self.pty_manager.write(session_id, data) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"read" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
let offset = args
.get("offset")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(500) as usize;
match self.pty_manager.read(session_id, offset, limit) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"kill" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
match self.pty_manager.kill(session_id) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"list" => {
let output = self.pty_manager.list();
Ok(ToolResult {
success: true,
output,
error: None,
})
}
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {}", action)),
}),
}
}
}

View File

@ -17,7 +17,15 @@ impl ToolRegistry {
} }
pub fn register<T: ToolTrait + 'static>(&self, tool: T) { pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool)); self.tools
.lock()
.unwrap()
.insert(tool.name().to_string(), Arc::new(tool));
}
/// Register an existing Arc-wrapped tool by name
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
self.tools.lock().unwrap().insert(name, tool);
} }
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> { pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
@ -62,6 +70,17 @@ impl ToolRegistry {
.map(|(k, v)| (k.clone(), v.clone())) .map(|(k, v)| (k.clone(), v.clone()))
.collect() .collect()
} }
/// 生成工具列表描述,用于子 Agent 系统提示词
pub fn describe_for_prompt(&self) -> String {
let mut entries: Vec<String> = self
.iter()
.into_iter()
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
.collect();
entries.sort();
entries.join("\n")
}
} }
impl Default for ToolRegistry { impl Default for ToolRegistry {

View File

@ -115,9 +115,11 @@ impl SchemaCleanr {
} }
if let Some(Value::String(t)) = obj.get("type") if let Some(Value::String(t)) = obj.get("type")
&& t == "object" && !obj.contains_key("properties") { && t == "object"
tracing::warn!("Object schema without 'properties' field may cause issues"); && !obj.contains_key("properties")
} {
tracing::warn!("Object schema without 'properties' field may cause issues");
}
Ok(()) Ok(())
} }
@ -173,9 +175,10 @@ impl SchemaCleanr {
// Handle anyOf/oneOf simplification // Handle anyOf/oneOf simplification
if (obj.contains_key("anyOf") || obj.contains_key("oneOf")) if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) { && let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
return simplified; {
} return simplified;
}
// Build cleaned object // Build cleaned object
let mut cleaned = Map::new(); let mut cleaned = Map::new();
@ -243,12 +246,13 @@ impl SchemaCleanr {
} }
if let Some(def_name) = Self::parse_local_ref(ref_value) if let Some(def_name) = Self::parse_local_ref(ref_value)
&& let Some(definition) = defs.get(def_name.as_str()) { && let Some(definition) = defs.get(def_name.as_str())
ref_stack.insert(ref_value.to_string()); {
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack); ref_stack.insert(ref_value.to_string());
ref_stack.remove(ref_value); let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
return Self::preserve_meta(obj, cleaned); ref_stack.remove(ref_value);
} return Self::preserve_meta(obj, cleaned);
}
tracing::warn!("Cannot resolve $ref: {}", ref_value); tracing::warn!("Cannot resolve $ref: {}", ref_value);
Self::preserve_meta(obj, Value::Object(Map::new())) Self::preserve_meta(obj, Value::Object(Map::new()))
@ -340,13 +344,16 @@ impl SchemaCleanr {
return true; return true;
} }
if let Some(Value::Array(arr)) = obj.get("enum") if let Some(Value::Array(arr)) = obj.get("enum")
&& arr.len() == 1 && matches!(arr[0], Value::Null) { && arr.len() == 1
return true; && matches!(arr[0], Value::Null)
} {
return true;
}
if let Some(Value::String(t)) = obj.get("type") if let Some(Value::String(t)) = obj.get("type")
&& t == "null" { && t == "null"
return true; {
} return true;
}
} }
false false
} }
@ -403,7 +410,10 @@ impl SchemaCleanr {
match non_null.len() { match non_null.len() {
0 => Value::String("null".to_string()), 0 => Value::String("null".to_string()),
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())), 1 => non_null
.into_iter()
.next()
.unwrap_or(Value::String("null".to_string())),
_ => Value::Array(non_null), _ => Value::Array(non_null),
} }
} else { } else {

View File

@ -1,5 +1,5 @@
use std::sync::Arc;
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use mime_guess::mime; use mime_guess::mime;
@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
match parts.len() { match parts.len() {
2 => { 2 => {
if parts[0].is_empty() || parts[1].is_empty() { if parts[0].is_empty() || parts[1].is_empty() {
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw)) Err(format!(
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
raw
))
} else { } else {
Ok((parts[0], parts[1], None)) Ok((parts[0], parts[1], None))
} }
} }
3 => { 3 => {
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() { if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw)) Err(format!(
"Invalid target_chat_id format '{}': all three parts must not be empty",
raw
))
} else { } else {
Ok((parts[0], parts[1], Some(parts[2]))) Ok((parts[0], parts[1], Some(parts[2])))
} }
@ -98,8 +104,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
.ok_or_else(|| anyhow::anyhow!("missing content"))?; .ok_or_else(|| anyhow::anyhow!("missing content"))?;
// 1. Parse target_chat_id // 1. Parse target_chat_id
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id) let (channel, chat_id, dialog_id) =
.map_err(|e| anyhow::anyhow!(e))?; parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
// 2. Validate channel // 2. Validate channel
if !self.available_channels.contains(channel) { if !self.available_channels.contains(channel) {
@ -109,7 +115,11 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
error: Some(format!( error: Some(format!(
"Channel '{}' is not available. Available channels: {}", "Channel '{}' is not available. Available channels: {}",
channel, channel,
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ") self.available_channels
.iter()
.cloned()
.collect::<Vec<_>>()
.join(", ")
)), )),
}); });
} }
@ -129,7 +139,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
let media = parse_files_arg(&args); let media = parse_files_arg(&args);
// 4. Send via messenger // 4. Send via messenger
match self.messenger match self
.messenger
.send_message(channel, chat_id, dialog_id, content, source, media) .send_message(channel, chat_id, dialog_id, content, source, media)
.await .await
{ {

View File

@ -1,5 +1,5 @@
use async_trait::async_trait;
use crate::bus::{MediaItem, MessageSource}; use crate::bus::{MediaItem, MessageSource};
use async_trait::async_trait;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ToolResult { pub struct ToolResult {

View File

@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
return true; return true;
} }
if host.rsplit('.').next().is_some_and(|label| label == "local") { if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true; return true;
} }
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
std::net::IpAddr::V4(v4) => { std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
} }
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
}; };
} }

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
use picobot::config::{Config, LLMProviderConfig}; use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
use std::collections::HashMap;
fn load_config() -> Option<LLMProviderConfig> { fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?; dotenv::from_filename("tests/test.env").ok()?;
@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_simple_completion() { async fn test_openai_simple_completion() {
let config = load_config() let config = load_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
let response = provider.chat(create_request("Say 'ok'")).await.unwrap(); let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
@ -57,8 +56,7 @@ async fn test_openai_simple_completion() {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_conversation() { async fn test_openai_conversation() {
let config = load_config() let config = load_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
@ -82,7 +80,9 @@ async fn test_openai_conversation() {
async fn test_config_load() { async fn test_config_load() {
// Test that config.json can be loaded and provider config created // Test that config.json can be loaded and provider config created
let config = Config::load("config.json").expect("Failed to load config.json"); let config = Config::load("config.json").expect("Failed to load config.json");
let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); let provider_config = config
.get_provider_config("default")
.expect("Failed to get provider config");
assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.provider_type, "openai");
assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.name, "aliyun");

View File

@ -1,5 +1,5 @@
use picobot::providers::{ChatCompletionRequest, Message};
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
/// Test that message with special characters is properly escaped /// Test that message with special characters is properly escaped
#[test] #[test]
@ -19,7 +19,9 @@ fn test_message_special_characters() {
#[test] #[test]
fn test_multiline_system_prompt() { fn test_multiline_system_prompt() {
let messages = vec![ let messages = vec![
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"), Message::system(
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
),
Message::user("Hi"), Message::user("Hi"),
]; ];
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
#[test] #[test]
fn test_chat_request_serialization() { fn test_chat_request_serialization() {
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![ messages: vec![Message::system("You are helpful"), Message::user("Hello")],
Message::system("You are helpful"),
Message::user("Hello"),
],
temperature: Some(0.7), temperature: Some(0.7),
max_tokens: Some(100), max_tokens: Some(100),
tools: None, tools: None,

View File

@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
/// Verify that next_run_for_schedule produces valid future timestamps. /// Verify that next_run_for_schedule produces valid future timestamps.
#[test] #[test]
fn test_next_run_always_future() { fn test_next_run_always_future() {
use picobot::scheduler::{next_run_for_schedule, Schedule}; use picobot::scheduler::{Schedule, next_run_for_schedule};
let now = 1700000000000_i64; let now = 1700000000000_i64;
@ -56,6 +56,10 @@ fn test_next_run_always_future() {
for s in &schedules { for s in &schedules {
let next = next_run_for_schedule(s, now); let next = next_run_for_schedule(s, now);
assert!(next.is_some(), "expected next run for {:?}", s); assert!(next.is_some(), "expected next run for {:?}", s);
assert!(next.unwrap() > now, "next run should be after now for {:?}", s); assert!(
next.unwrap() > now,
"next run should be after now for {:?}",
s
);
} }
} }

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use picobot::config::LLMProviderConfig; use picobot::config::LLMProviderConfig;
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
use std::collections::HashMap;
fn load_openai_config() -> Option<LLMProviderConfig> { fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?; dotenv::from_filename("tests/test.env").ok()?;
@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_tool_call() { async fn test_openai_tool_call() {
let config = load_openai_config() let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
@ -68,7 +67,11 @@ async fn test_openai_tool_call() {
let response = provider.chat(request).await.unwrap(); let response = provider.chat(request).await.unwrap();
// Should have tool calls // Should have tool calls
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content); assert!(
!response.tool_calls.is_empty(),
"Expected tool call, got: {}",
response.content
);
let tool_call = &response.tool_calls[0]; let tool_call = &response.tool_calls[0];
assert_eq!(tool_call.name, "get_weather"); assert_eq!(tool_call.name, "get_weather");
@ -78,8 +81,7 @@ async fn test_openai_tool_call() {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_tool_call_with_manual_execution() { async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config() let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() {
}; };
let response1 = provider.chat(request1).await.unwrap(); let response1 = provider.chat(request1).await.unwrap();
let tool_call = response1.tool_calls.first() let tool_call = response1.tool_calls.first().expect("Expected tool call");
.expect("Expected tool call");
assert_eq!(tool_call.name, "get_weather"); assert_eq!(tool_call.name, "get_weather");
// Second request with tool result // Second request with tool result
@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_no_tool_when_not_provided() { async fn test_openai_no_tool_when_not_provided() {
let config = load_openai_config() let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");