Compare commits

..

No commits in common. "8f4ee79d8d6f880824df1770c001a6554fe9d719" and "a77c02682605b6a901f821694293d18867868534" have entirely different histories.

97 changed files with 7937 additions and 5926 deletions

View File

@ -1,28 +0,0 @@
# Git
.git
.gitignore
# Build artifacts
target/
!target/release/picobot
# IDE
.vscode/
.idea/
*.swp
*.swo
# Docs and references
docs/
reference/
# Test files
tests/
# Misc
*.md
*.txt
.opencode/
CLAUDE.md
AGENTS.md
ARCHITECTURE_REVIEW.md

2
.gitignore vendored
View File

@ -1,8 +1,6 @@
/target
docker_build/
reference/**
.env
*.env
Cargo.lock
.worktrees/
design

128
ARCHITECTURE_REVIEW.md Normal file
View File

@ -0,0 +1,128 @@
# 架构审查报告
> 生成时间: 2026-04-26
> 更新时间: 2026-04-26
## 审查摘要
本报告识别了当前代码库中的架构不合理、冗余和无效代码的问题。
---
## 问题清单
### 已修复
#### ✅ #1 OutboundDispatcher 重复维护 Channel 注册表
**修复方案**: `OutboundDispatcher` 现在从 `ChannelManager` 获取 channels而不是自己维护一份注册表。
**修改文件**:
- `src/bus/dispatcher.rs` - 移除 `channels` 字段,改用 `ChannelManager`
- `src/channels/manager.rs` - 添加 `register_channel` 方法
- `src/gateway/mod.rs` - 简化 dispatcher 初始化
---
#### ✅ #2 CliChatChannel 持有独立的 SessionStore
**修复方案**: `CliChatChannel``SessionStore` 通过依赖注入从 `ChannelManager` 获取,而不是独立持有。
**修改文件**:
- `src/channels/cli_chat.rs` - 添加 `set_store()` 方法
- `src/channels/manager.rs` - 添加 `cli_chat_channel` 字段
- `src/gateway/mod.rs` - 重构 channel 初始化流程
---
#### ✅ #3 MessageBus 被创建两次引用
**修复方案**: 移除 `GatewayState.bus` 字段,直接使用 `channel_manager.bus()`
**修改文件**:
- `src/gateway/mod.rs` - 移除冗余的 `bus` 字段
---
#### ✅ #4 GatewayState 同时持有 channel_manager 和 cli_chat_channel
**修复方案**: `cli_chat_channel` 只通过 `ChannelManager` 管理,`GatewayState` 不再单独持有。
**修改文件**:
- `src/gateway/mod.rs` - 移除 `cli_chat_channel` 字段,添加 `cli_chat_channel()` getter 方法
---
### 高优先级(待修复)
#### ❌ Session 每次重建都创建新的 LLM Provider
**文件**: `src/gateway/session.rs:349-361`
**问题**: 每当 session TTL 过期默认4小时就会销毁并重建 session同时创建新的 LLM provider 连接。
**建议**: Provider 应该池化复用,不随 session 销毁而重建。
---
#### ❌ CliChatChannel::send 广播给所有客户端
**文件**: `src/channels/cli_chat.rs:279-289`
**问题**: `OutboundMessage``chat_id` 字段用于路由,但实现广播给所有客户端,而不是只发给对应 chat_id 的客户端。
**建议**: 根据 `chat_id` 过滤客户端,只发送给对应的客户端。
---
### 中优先级(待修复)
#### ❌ default_tools() 每次调用创建新 ToolRegistry
**文件**: `src/gateway/session.rs:212-227`
**建议**: 如果工具列表是只读的,直接 clone Arc如果需要修改需要澄清设计意图。
---
### 低优先级(待修复)
#### ❌ FeishuChannel::new 接收未使用的 provider_config
**文件**: `src/channels/feishu.rs:175-178`
---
#### ❌ OutboundDispatcher::send_with_retry 永不执行的 unreachable
**文件**: `src/bus/dispatcher.rs:81`
---
#### ❌ Channel trait 的 `is_running` 使用 std::sync::Mutex
**文件**: `src/channels/base.rs:38` vs `src/channels/cli_chat.rs:265-267`
---
#### ❌ LoopDetector 硬编码在 AgentLoop 中
**文件**: `src/agent/agent_loop.rs:88-172`
---
#### ❌ InboundMessage 和 OutboundMessage 结构重复
**文件**: `src/bus/message.rs`
---
## 问题统计
| 状态 | 优先级 | 数量 |
|------|--------|------|
| ✅ 已修复 | - | 4 |
| ❌ 待修复 | 高 | 2 |
| ❌ 待修复 | 中 | 1 |
| ❌ 待修复 | 低 | 5 |
| **总计** | - | **12** |

View File

@ -1,6 +1,6 @@
[package]
name = "picobot"
version = "1.1.0"
version = "0.1.0"
edition = "2024"
[dependencies]
@ -12,8 +12,6 @@ serde_json = "1.0"
async-trait = "0.1"
thiserror = "2.0.18"
tokio = { version = "1.52", features = ["full"] }
tokio-util = { version = "0.7", features = ["rt"] }
dashmap = "6.1"
uuid = { version = "1.23", features = ["v4"] }
axum = { version = "0.8", features = ["ws"] }
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
@ -51,7 +49,6 @@ encoding_rs = "0.8"
zstd = "0.13"
tar = "0.4"
fantoccini = { version = "0.22", default-features = false, features = ["rustls-tls"] }
portable-pty = "0.9"
[build-dependencies]
zstd = "0.13"

View File

@ -1,110 +0,0 @@
# =============================================================================
# PicoBot Docker Image
# =============================================================================
# Build binary on host:
# cargo build --release
#
# Build image:
# docker build -t picobot .
#
# Run gateway: docker run -d -v ~/.picobot:/app/.picobot -p 19876:19876 picobot gateway
# Run chat: docker run -it -v ~/.picobot:/app/.picobot picobot chat
# =============================================================================
FROM debian:trixie-slim
LABEL org.opencontainers.image.title="PicoBot"
LABEL org.opencontainers.image.description="AI agent gateway and chat client"
LABEL org.opencontainers.image.source="https://github.com/your-repo/picobot"
# Avoid interactive prompts
ENV DEBIAN_FRONTEND=noninteractive
# Configure domestic mirrors for pip, uv, npm (China)
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
ENV UV_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
# Install base tools, Python, and uv in one layer to reduce duplication
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
tini \
curl \
gnupg \
git \
jq \
tree \
zip \
unzip \
sqlite3 \
openssh-client \
sshpass \
dnsutils \
poppler-utils \
fonts-wqy-zenhei \
fonts-wqy-microhei \
python3 \
python3-pip \
python3-venv \
&& rm -rf /var/lib/apt/lists/* \
&& pip3 install --no-cache-dir --break-system-packages uv
# Install Node.js and npx
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& npm config set registry https://registry.npmmirror.com \
&& npm cache clean --force \
&& rm -rf /var/lib/apt/lists/*
# Install himalaya (CLI email client) from local file
COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
&& chmod +x /usr/local/bin/himalaya \
&& rm -f /tmp/himalaya.tgz
# Install fd (alternative to find)
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \
tar -xz --strip-components=1 -C /usr/local/bin \
&& chmod +x /usr/local/bin/fd
# Install ripgrep (rg)
RUN curl -fsSL https://github.com/BurntSushi/ripgrep/releases/download/14.1.0/ripgrep-14.1.0-x86_64-unknown-linux-musl.tar.gz | \
tar -xz --strip-components=1 -C /usr/local/bin \
&& chmod +x /usr/local/bin/rg
# Install Chromium and chromedriver for browser automation
# Debian's chromium package is real (not a snap shim like Ubuntu 24.04)
RUN apt-get update && apt-get install -y --no-install-recommends \
chromium \
chromium-driver \
&& ln -sf /usr/bin/chromium /usr/local/bin/chrome \
&& ln -sf /usr/bin/chromedriver /usr/local/bin/chromedriver \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -m -s /bin/bash app
WORKDIR /app
# Copy pre-built binary from host
COPY target/release/picobot /app/picobot
# Copy config template
COPY resources/templates/config.example.json /app/config.json.example
# Create required directories
RUN mkdir -p /app/.picobot/workspace /app/.picobot/media /app/.picobot/tmp && \
chown -R app:app /app
USER app
ENV HOME=/app
# Environment variables for Chromium in containers
ENV CHROME_BIN=/usr/bin/chromium
ENV TMPDIR=/app/.picobot/tmp
ENTRYPOINT ["/app/picobot"]
CMD ["gateway"]
EXPOSE 19876
ENV RUST_LOG=info

537
README.md
View File

@ -1,102 +1,143 @@
# PicoBot
PicoBot is a Rust-based personal AI assistant runtime. It runs a local gateway, connects chat channels such as the terminal TUI and Feishu/Lark, persists sessions in SQLite, and gives the agent a tool system for files, shell commands, web access, memory, scheduling, skills, MCP tools, and delegated sub-agents.
A multi-channel AI agent framework with a WebSocket gateway and TUI client, supporting OpenAI-compatible and Anthropic LLM providers, tool calling, session persistence, and cron-based scheduling.
## What It Does
## System Architecture
- Runs as a gateway server on `127.0.0.1:19876` by default.
- Provides a Ratatui terminal client over WebSocket.
- Supports Feishu/Lark messages, reactions, file upload/download, and media references.
- Calls OpenAI-compatible providers and Anthropic Messages API providers.
- Persists conversations, messages, memories, scheduled jobs, LLM call metadata, and background sub-agent tasks in SQLite.
- Loads skills from workspace, user, and shared skill directories, with built-in skills installed on first use.
- Compresses long contexts and stores timeline summaries for later recall.
- Can register tools discovered from configured MCP servers.
```mermaid
graph TB
subgraph Clients
TUI["🖥️ CLI Chat (TUI)"]
FS["📱 Feishu/Lark"]
end
## Architecture
subgraph Gateway["Gateway Server (127.0.0.1:19876)"]
HTTP["HTTP Endpoints<br/>GET /health<br/>GET /ws (WebSocket upgrade)"]
WS["WebSocket Handler"]
CD["ChannelManager"]
SP["SessionManager"]
AL["AgentLoop"]
end
```text
Channel -> MessageBus -> SessionManager -> AgentLoop -> LLM Provider
| |
| v
| Tools
v
SQLite
subgraph Bus["MessageBus"]
IB["Inbound Channel"]
OB["Outbound Channel"]
CC["Control Channel"]
end
Control messages -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
subgraph Storage
SQLite[("SQLite<br/>picobot.db")]
end
subgraph AI["AI Providers"]
OpenAI["OpenAI / DashScope"]
Anthropic["Anthropic Claude"]
end
TUI <-->|WebSocket| WS
FS <-->|Webhook| HTTP
CD -->|InboundMessage| IB
IB -->|DialogEvent| SP
CC -->|ControlMessage| SP
SP <--> AL
AL -->|API Call| OpenAI
AL -->|API Call| Anthropic
AL -->|Tool Call| Tools
SP -->|OutboundMessage| OB
OB --> CD
SP --> SQLite
Tools --> SQLite
subgraph Tools
Bash["Bash"]
FileIO["File Read/Write/Edit"]
Web["HTTP Request / Web Fetch"]
Calc["Calculator"]
Skill["Get Skill"]
Msg["Send Message"]
Cron["Cron Jobs"]
end
```
The main runtime boundary is:
### Core Data Flow
- `channels` only receive and send external messages.
- `bus` is an async queue, not a router.
- `session` owns dialog lifecycle, persistence, memory recall, prompt assembly, compression, and task cancellation.
- `agent` runs the stateless LLM/tool loop.
- `providers` are HTTP clients for model APIs.
- `tools` execute agent actions and return string results.
- `storage` owns SQLite schema and CRUD.
- `scheduler` polls due jobs and feeds prompts back into sessions.
```mermaid
sequenceDiagram
participant Channel as Channel<br/>(CLI/Feishu)
participant Bus as MessageBus
participant SM as SessionManager
participant AL as AgentLoop
participant LLM as LLM Provider
participant Tool as Tools
Channel->>Bus: InboundMessage (user input)
Bus->>SM: DialogEvent
SM->>SM: Load/Resolve Session
SM->>AL: Process (session state)
AL->>LLM: ChatCompletionRequest
LLM-->>AL: response / tool_calls
alt Tool Calls
AL->>Tool: execute tool
Tool-->>AL: result
AL->>LLM: continue with tool result
end
AL-->>SM: AgentProcessResult (text + token count)
SM->>SM: Persist to SQLite
SM->>Bus: OutboundMessage
Bus->>Channel: response to user
```
## Features
### Channels
### Multi-Channel Support
- **CLI Chat Client** — Full TUI with session management, Markdown rendering, slash commands
- **Feishu (Lark)** — Webhook-based integration with typing indicators and media support
- `cli_chat`: terminal TUI client connected through `/ws`.
- `feishu`: Feishu/Lark channel with configurable allow list, media directory, and reaction emoji.
### Multi-Provider LLM
- OpenAI-compatible API (GPT-4, DashScope, Volcengine, etc.)
- Anthropic Messages API (Claude)
- Cross-provider JSON Schema normalization for tool calling compatibility
### LLM Providers
### Session Management
- Multi-session conversations per channel/chat
- Create, switch, rename, archive, delete dialogs via slash commands or WebSocket
- SQLite-persisted session history with automatic TTL-based cleanup
- Context compression for long conversations approaching token limits
- OpenAI-compatible chat completions, including DashScope, Volcengine, and similar APIs.
- Anthropic Messages API.
- Model-specific `input_type` metadata for text/image capability checks.
- JSON Schema cleanup for cross-provider tool compatibility.
### Tool System
| Tool | Description |
|------|-------------|
| `bash` | Execute shell commands in workspace |
| `file_read` | Read file contents |
| `file_write` | Create/overwrite files |
| `file_edit` | Precise string substitution in files |
| `http_request` | Make HTTP API requests |
| `web_fetch` | Fetch and parse web pages |
| `calculator` | Evaluate mathematical expressions |
| `get_skill` | Load agent skills from local skill files |
| `send_message` | Send messages to other channels |
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs |
### Sessions And Memory
### Scheduling
- Cron-based recurring jobs with optional timezone support
- One-shot (`at`) and interval (`every`) schedules
- Jobs trigger agent processing via specified channel/chat
- Session IDs use `<channel>:<chat_id>:<dialog_id>`.
- Each channel/chat can have multiple dialogs.
- Dialog operations include create, list, switch, rename, delete, compact, dump, info, and stop.
- Session history is persisted to SQLite and can be incrementally restored after compression.
- Knowledge memories are recalled into the system prompt each turn.
- Timeline memories are produced by context compression and can be searched later.
### Skills System
- Load Markdown skill files from `~/.picobot/skills` and `~/.agents/skills`
- Skills inject specialized system prompts for specific tasks
- Automatic hot-reload on file changes
### Tools
Base tools registered for the agent:
| Tool | Purpose |
|------|---------|
| `calculator` | Math expressions and statistics |
| `file_read` / `file_write` / `file_edit` | Workspace file operations |
| `file_search` / `content_search` | File and content search |
| `bash` | Run shell commands in the workspace |
| `http_request` | HTTP API requests |
| `web_fetch` | Fetch and extract web page text |
| `get_skill` | List or load local skills |
| `memory_store` / `memory_recall` / `timeline_recall` / `memory_forget` | Long-term memory operations |
| `delegate` | Run inline, background, or parallel sub-agents |
| `send_message` | Send outbound messages to configured channels |
| `chat_manager` | Inspect sessions, channels, and stored messages |
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs when scheduler is enabled |
| `browser` | Optional WebDriver browser automation when enabled |
| MCP tools | Dynamically registered from configured MCP servers |
### Skills
Skills are directories containing `SKILL.md`. Load priority is:
1. `{workspace}/skills`
2. `~/.picobot/skills`
3. `~/.agents/skills`
Same-name skills in higher-priority locations override lower-priority ones. Built-in skills from `resources/skills` are embedded into the binary and installed into `~/.picobot/skills` if missing.
### Observability
- Observer pattern for agent and tool telemetry
- Events: `AgentStart`, `AgentEnd`, `ToolCallStart`, `ToolCall`
- Structured JSON logging with file rotation
## Quick Start
### Prerequisites
- Rust toolchain with edition 2024 support.
- A configured LLM provider API key.
- Rust nightly (edition 2024) — use `rustup` to install
### Build
@ -106,186 +147,276 @@ cargo build
### Configure
PicoBot loads `~/.picobot/config.json` first, then falls back to `./config.json`. On gateway startup, a template is released to `~/.picobot/config.example.json` if it does not exist. The source template is [resources/templates/config.example.json](/home/xiaoxixi/code/PicoBot/resources/templates/config.example.json).
Minimal example:
1. Create `config.json` (or `~/.picobot/config.json`):
```json
{
"providers": {
"openai": {
"type": "openai",
"base_url": "https://api.openai.com/v1",
"api_key": "<OPENAI_API_KEY>",
"extra_headers": {}
"providers": {
"openai": {
"type": "openai",
"base_url": "https://api.openai.com/v1",
"api_key": "<OPENAI_API_KEY>"
}
},
"models": {
"gpt-4o": {
"model_id": "gpt-4o",
"temperature": 0.7,
"max_tokens": 4096
}
},
"agents": {
"default": {
"provider": "openai",
"model": "gpt-4o",
"max_tool_iterations": 99,
"token_limit": 128000
}
}
},
"models": {
"gpt-4o": {
"model_id": "gpt-4o",
"temperature": 0.7,
"max_tokens": 4096,
"input_type": ["text", "image"]
}
},
"agents": {
"default": {
"provider": "openai",
"model": "gpt-4o",
"max_tool_iterations": 99,
"token_limit": 128000
}
},
"workspace_dir": "~/.picobot/workspace"
}
```
The `.env` file in the current directory is parsed by PicoBot itself. Values like `<OPENAI_API_KEY>` in JSON are replaced from the process environment after `.env` is loaded.
2. Set API keys via `.env` file (one `KEY=VALUE` per line):
```env
OPENAI_API_KEY=sk-xxxxx
```
### Run
**Start gateway server:**
```bash
cargo run -- gateway
```
The gateway switches the process working directory to `workspace_dir` and stores `picobot.db` there by default.
Binds `127.0.0.1:19876` by default. Override with `--host` and `--port`.
In another terminal:
**Connect CLI client:**
```bash
cargo run -- chat
```
The client connects to `ws://127.0.0.1:19876/ws` by default. Override with `--gateway-url`.
Connects to `ws://127.0.0.1:19876/ws`. Override with `--gateway-url`.
## Configuration
## Configuration Reference
Top-level config fields:
Config load order: `~/.picobot/config.json``./config.json` (fallback).
| Field | Purpose |
|-------|---------|
| `providers` | Named LLM provider configs |
| `models` | Named model configs |
| `agents` | Agent-to-provider/model binding |
| `gateway` | Bind address, session DB path, cleanup, scheduler, background task limits |
| `client` | Default WebSocket URL for the TUI client |
| `channels` | Channel configs, currently Feishu/Lark |
| `memory` | Recall and consolidation settings |
| `mcp` | MCP server configs |
| `browser` | Optional WebDriver browser tool config |
| `workspace_dir` | Workspace used for file tools, shell commands, DB default, and workspace skills |
### Full Config Structure
Important defaults:
```mermaid
graph LR
Config["config.json"]
Config --> Providers["providers<br/>ProviderConfig{}"]
Config --> Models["models<br/>ModelConfig{}"]
Config --> Agents["agents<br/>AgentConfig{}"]
Config --> Gateway["gateway<br/>GatewayConfig"]
Config --> Client["client<br/>ClientConfig"]
Config --> Channels["channels<br/>ChannelConfig{}"]
Config --> Workspace["workspace_dir"]
| Key | Default |
|-----|---------|
| `gateway.host` | `127.0.0.1` |
| `gateway.port` | `19876` |
| `gateway.max_concurrent_background_tasks` | `10` |
| `gateway.scheduler.enabled` | `true` if `scheduler` is omitted and defaulted |
| `client.gateway_url` | `ws://127.0.0.1:19876/ws` |
| `memory.recall_limit` | `5` |
| `memory.timeline_retention_days` | `90` |
| `mcp.tool_timeout_secs` | `180` |
| `browser.enabled` | `false` |
Providers --> PT["type (openai / anthropic)<br/>base_url<br/>api_key<br/>extra_headers"]
Models --> MT["model_id<br/>temperature<br/>max_tokens"]
Agents --> AT["provider (ref)<br/>model (ref)<br/>max_tool_iterations<br/>token_limit"]
Gateway --> GT["host / port<br/>session_db_path<br/>scheduler"]
Channels --> CT["feishu: app_id, app_secret<br/>allow_from, agent, media_dir"]
```
MCP servers support `stdio`, `sse`, and `streamable-http` transports. Browser automation requires a compatible Chrome/Chromium and chromedriver/WebDriver endpoint.
### Environment Variables
The `.env` file in the working directory is loaded manually (not via dotenv crate). Placeholders in `config.json` written as `<VAR_NAME>` are substituted at load time.
### Gateway Config
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `host` | string | `127.0.0.1` | Bind address |
| `port` | u16 | `19876` | Listen port |
| `session_db_path` | string | workspace `picobot.db` | SQLite database path |
| `scheduler.enabled` | bool | `false` | Enable cron scheduler |
### Agent Config
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `provider` | string | — | Provider name (key in `providers`) |
| `model` | string | — | Model name (key in `models`) |
| `max_tool_iterations` | number | `99` | Max tool call iterations per turn |
| `token_limit` | number | `128000` | Context window token limit |
## Slash Commands
Available from CLI chat and channel text messages:
Available in CLI chat and Feishu:
| Command | Description |
|---------|-------------|
| `/new` | Create a new dialog |
| `/sessions` | List recent dialogs |
| `/switch <dialog_id>` | Switch dialog |
| `/rename <title>` | Rename current dialog |
| `/delete` | Delete current dialog |
| `/compact` | Manually trigger context compression |
| `/info` | Show current dialog information |
| `/dump` | Save current dialog as Markdown |
| `/?`, `/help` | Show help |
| `/mcp` | Show MCP server and tool status |
| `/stop` | Stop active tasks and clear queued messages |
| Command | Alias | Description |
|---------|-------|-------------|
| `/new` | `/刷新` | Create a new dialog |
| `/list` | `/对话列表` | List all dialogs |
| `/switch <id>` | — | Switch to a dialog |
| `/rename <title>` | — | Rename current dialog |
| `/archive` | — | Archive current dialog |
| `/delete` | — | Delete current dialog |
| `/clear` | `/清空` | Clear current dialog history |
## WebSocket API
## WebSocket Protocol
The gateway exposes:
The gateway exposes a WebSocket endpoint at `/ws`. Messages use typed JSON with a `type` discriminator field.
### Client → Server (WsInbound)
| Type | Fields |
|------|--------|
| `user_input` | `content`, `channel?`, `chat_id?`, `sender_id?` |
| `create_session` | `title?` |
| `list_sessions` | `include_archived` |
| `load_session` | `session_id` |
| `rename_session` | `session_id?`, `title` |
| `archive_session` | `session_id?` |
| `delete_session` | `session_id?` |
| `clear_history` | `chat_id?`, `session_id?` |
| `get_slash_commands` | — |
| `ping` | — |
### Server → Client (WsOutbound)
| Type | Fields |
|------|--------|
| `assistant_response` | `session_id`, `response`, `tokens_used?`, `tool_calls?` |
| `session_list` | `sessions[]` |
| `session_loaded` | `session_id`, `messages[]` |
| `session_created` | `session_id`, `title` |
| `session_renamed` | `session_id`, `title` |
| `session_archived` | `session_id` |
| `session_deleted` | `session_id` |
| `slash_commands` | `commands[]` |
| `error` | `message` |
| `pong` | — |
## HTTP Endpoints
| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/health` | Returns service health and version |
| `GET` | `/health` | Health check — returns `{"status":"ok","version":"x.y.z"}` |
| `GET` | `/ws` | WebSocket upgrade for chat clients |
Inbound WebSocket message types:
| Type | Main fields |
|------|-------------|
| `user_input` | `content`, optional `channel`, `chat_id`, `sender_id` |
| `clear_history` | optional `chat_id`, `session_id` |
| `create_session` | optional `title` |
| `list_sessions` | `include_archived` |
| `load_session` | `session_id` |
| `rename_session` | optional `session_id`, `title` |
| `archive_session` | optional `session_id` |
| `delete_session` | optional `session_id` |
| `get_slash_commands` | none |
| `ping` | none |
Outbound WebSocket message types include `assistant_response`, `error`, `session_established`, `session_created`, `session_list`, `session_loaded`, `session_renamed`, `session_archived`, `session_deleted`, `history_cleared`, `slash_commands_list`, `pong`, `command_executed`, and `system_notification`.
## Testing
```bash
# Unit tests
# Unit tests (no external dependencies)
cargo test --lib
# Integration tests require real API keys in tests/test.env
# Integration tests (require API keys)
cp tests/test.env.example tests/test.env
# Fill in your API keys in tests/test.env
cargo test --test test_integration -- --ignored
cargo test --test test_tool_calling -- --ignored
cargo test --test test_request_format -- --ignored
# Run all tests
cargo test -- --ignored
```
Integration tests are ignored by default because they make real provider calls.
Integration tests are `#[ignore]` by default because they make real API calls.
## Project Layout
## Project Structure
```text
src/
agent/ LLM loop, context compression, system prompts, media handling, sub-agents
bus/ Inbound, outbound, and control message queues
channels/ CLI chat and Feishu/Lark integrations
client/ Ratatui terminal UI
config/ Config loading, env substitution, path expansion
gateway/ Axum HTTP/WebSocket server and GatewayState wiring
mcp/ MCP client connections and tool wrappers
memory/ Memory manager and memory types
observability/ Agent/tool telemetry observer interfaces
providers/ OpenAI-compatible and Anthropic clients
scheduler/ Scheduled job runtime
session/ Session lifecycle, dialog commands, persistence integration
skills/ Skill loading and embedded built-in skill installation
storage/ SQLite schema and CRUD
tools/ Agent tool implementations
resources/
skills/ Built-in skills embedded at build time
templates/ Config, AGENTS.md, and USER.md templates released on first run
tests/ Unit and ignored integration tests
reference/ Third-party reference code; do not modify as project source
```
├── src/
│ ├── main.rs # CLI entrypoint (clap-based subcommands)
│ ├── lib.rs # Module declarations
│ ├── gateway/ # HTTP/WS server, GatewayState initialization
│ │ ├── mod.rs
│ │ ├── http.rs # Health endpoint
│ │ └── ws.rs # WebSocket handler
│ ├── client/ # TUI chat client
│ │ ├── mod.rs
│ │ └── tui/ # Ratatui-based terminal UI
│ ├── channels/ # Channel integrations
│ │ ├── base.rs # Channel trait
│ │ ├── cli_chat.rs # CLI WebSocket channel
│ │ ├── feishu.rs # Feishu/Lark webhook channel
│ │ ├── manager.rs # ChannelManager
│ │ └── slash_command.rs # Slash command parser
│ ├── bus/ # Async message bus
│ │ ├── mod.rs # MessageBus (tokio mpsc channels)
│ │ ├── message.rs # Message types
│ │ └── dispatcher.rs # OutboundDispatcher
│ ├── session/ # Session & dialog management
│ │ ├── mod.rs
│ │ ├── session.rs # Session, SessionManager
│ │ ├── session_id.rs # UnifiedSessionId
│ │ ├── commands.rs # SessionCommand enum
│ │ └── events.rs # SessionEvent, DialogInfo
│ ├── agent/ # LLM interaction loop
│ │ ├── mod.rs
│ │ ├── agent_loop.rs # AgentLoop (stateless)
│ │ ├── context_compressor.rs # Token estimation & summarization
│ │ └── system_prompt.rs # System prompt builder
│ ├── providers/ # LLM API clients
│ │ ├── mod.rs # Factory: create_provider()
│ │ ├── traits.rs # LLMProvider trait
│ │ ├── openai.rs # OpenAI-compatible client
│ │ └── anthropic.rs # Anthropic Messages API client
│ ├── tools/ # Agent tools
│ │ ├── mod.rs # create_default_tools()
│ │ ├── registry.rs # ToolRegistry
│ │ ├── traits.rs # Tool trait, ToolResult
│ │ ├── schema.rs # Cross-provider JSON Schema cleaner
│ │ ├── bash.rs # Shell command execution
│ │ ├── calculator.rs # Math expression evaluator
│ │ ├── chat_manager.rs # Session management tool
│ │ ├── cron.rs # Cron job management tools
│ │ ├── file_read.rs # File reader
│ │ ├── file_write.rs # File writer
│ │ ├── file_edit.rs # File editor (string substitution)
│ │ ├── get_skill.rs # Skill loader tool
│ │ ├── http_request.rs # HTTP request tool
│ │ ├── send_message.rs # Cross-channel messaging
│ │ └── web_fetch.rs # Web page fetcher
│ ├── skills/ # Skills loading from markdown files
│ │ └── mod.rs # SkillsLoader, Skill
│ ├── storage/ # SQLite persistence
│ │ ├── mod.rs # Storage, schema init
│ │ ├── session.rs # Session CRUD operations
│ │ ├── message.rs # Message persistence
│ │ ├── scheduler.rs # ScheduledJob, JobRun storage
│ │ └── error.rs # StorageError
│ ├── scheduler/ # Cron scheduler runtime
│ │ ├── mod.rs # Scheduler, next_run_for_schedule()
│ │ └── types.rs # Schedule enum (At/Every/Cron)
│ ├── observability/ # Telemetry observer pattern
│ │ └── mod.rs # Observer trait, ObserverEvent, MultiObserver
│ ├── protocol.rs # WebSocket message types (WsInbound/WsOutbound)
│ ├── config/ # Config loading & env substitution
│ │ └── mod.rs # Config, LLMProviderConfig, load_env_file()
│ └── logging.rs # Tracing subscriber init with file rotation
├── tests/
│ ├── test_integration.rs # LLM provider integration tests
│ ├── test_tool_calling.rs # Tool calling integration tests
│ ├── test_request_format.rs # Request format tests
│ ├── test_scheduler.rs # Scheduler unit tests
│ ├── test.env.example # Test environment template
│ └── test.env # Actual test keys (gitignored)
├── reference/ # Third-party reference code (do not modify)
├── resources/ # Assets embedded in binary
│ └── templates/ # Templates released to ~/.picobot/ on first run
├── config.example.json # Full config example
└── Cargo.toml
```
## Key Dependencies
| Crate | Purpose |
|-------|---------|
| `axum`, `tokio`, `tokio-tungstenite` | Gateway and WebSocket runtime |
| `sqlx` | SQLite persistence |
| `reqwest` | LLM and HTTP clients |
| `ratatui`, `crossterm`, `termimad` | Terminal UI |
| `rmcp` | MCP client support |
| `fantoccini` | Optional browser automation |
| `cron`, `chrono-tz` | Scheduling |
| `jieba-rs` | Chinese tokenization for memory search |
| `zstd`, `tar` | Embedded built-in skill packaging |
| `axum` + `tokio-tungstenite` | HTTP server & WebSocket |
| `sqlx` (SQLite) | Session/Message/Job persistence |
| `reqwest` (rustls) | LLM API & external HTTP calls |
| `ratatui` + `crossterm` | Terminal UI |
| `clap` | CLI argument parsing |
| `tracing` + `tracing-subscriber` | Structured logging |
| `cron` + `chrono-tz` | Cron schedule parsing |
| `meval` | Mathematical expression evaluation |
| `uuid` | Session/Dialog ID generation |
| `dirs` | Platform config directory resolution |

View File

@ -1,16 +0,0 @@
services:
picobot:
image: picobot:latest
container_name: picobot
restart: unless-stopped
ports:
- "19876:19876"
volumes:
- ~/.picobot/config.json:/app/.picobot/config.json:ro
- picobot_data:/app/.picobot
environment:
- RUST_LOG=info
command: gateway
volumes:
picobot_data:

View File

@ -1,359 +0,0 @@
# PicoBot 代码质量分析报告
审查日期2026-06-15
## 结论摘要
PicoBot 的总体架构方向是清晰的Gateway 负责装配Channel 只做收发MessageBus 解耦输入输出SessionManager 管理会话AgentLoop 保持无状态并执行工具Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
当前主要质量风险集中在三类:
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
3. 工具和后台任务的资源边界偏弱文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
## 修复状态
- 已修复CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id``chat_id`
- 已修复Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
- 待处理工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
## 主要发现
### 已修复CLI 会话路由会破坏会话连续性和多客户端隔离
位置:
- `src/channels/cli_chat.rs:113-126`
- `src/channels/cli_chat.rs:160-164`
- `src/channels/cli_chat.rs:225-249`
- `src/channels/cli_chat.rs:479-494`
- `src/session/session.rs:1305-1310`
问题:
`Client.current_session_id` 存的是完整 session id但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
影响:
- CLI 多轮对话可能落入不同 chat scope。
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
- 多个 CLI 客户端同时连接时存在串话。
建议:
- 将 client 状态拆成 `chat_id``current_session_id`,不要混用。
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
- `send()``OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
- `LoadSession` 应直接切换到指定 session或通过 `SwitchDialog` 使用其中的 `dialog_id`
- 为 CLI WebSocket 增加多客户端路由测试。
### 已修复Dialog 控制接口与协议承诺不一致
位置:
- `src/session/session.rs:996-997`
- `src/session/session.rs:1305-1310`
- `src/session/session.rs:1329-1349`
- `src/session/session.rs:1378-1384`
- `src/channels/cli_chat.rs:128-158`
问题:
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
- `/delete` 只创建新 session并没有删除当前 session。
- `get_current_dialog()` 固定返回 `Ok(None)`
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`
- `archive_dialog()` 是空操作。
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`
影响:
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
建议:
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`
- `list_dialogs()` 返回真实 current dialog并补上 archived 模型或移除 archived 参数。
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
位置:
- `src/tools/mod.rs:56-62`
- `src/tools/path_utils.rs:3-23`
- `src/tools/bash.rs:146-185`
问题:
文件工具默认通过 `FileReadTool::new()``FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize不能防御 `..`、符号链接等路径逃逸。
`bash` 默认工作目录是 `"."`Gateway 启动时切到 workspace这对相对路径有效但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
影响:
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP风险会放大。
建议:
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
- `resolve_path()` 使用 `std::fs::canonicalize``path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
### 中高优先级Session 锁内执行过多异步操作
位置:
- `src/session/session.rs:1001-1018`
- `src/session/session.rs:1604-1711`
问题:
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
影响:
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
建议:
- 锁内只做内存状态快照和必要的状态标记。
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
### 中优先级Bash 超时不会显式终止子进程
位置:
- `src/tools/bash.rs:150-174`
- `src/tools/bash.rs:180-207`
问题:
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
影响:
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
建议:
- 使用 `tokio::process::Child``kill_on_drop(true)`
- 超时分支显式 kill child 并 wait。
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
### 中优先级:文件读取对大二进制文件没有输出上限
位置:
- `src/tools/file_read.rs:121-131`
- `src/tools/file_read.rs:214-229`
问题:
`file_read``std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
影响:
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
建议:
- 先检查 metadata size超过阈值直接返回提示。
- 二进制文件默认只返回 mime、大小和建议操作需要内容时提供显式 `max_bytes` 参数。
- 对文本读取也改成流式按行读取,而不是整文件读入。
### 中优先级HTTP 私网防护只检查字面 host未做 DNS 解析校验
位置:
- `src/tools/http_request.rs:31-59`
问题:
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
影响:
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
建议:
- 请求前解析域名拒绝私网、loopback、link-local、multicast、unspecified 地址。
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
- 对 `http_request``web_fetch` 复用同一套 URL 安全策略。
### 中优先级:后台任务和主循环缺少监督与优雅关闭
位置:
- `src/bus/mod.rs:51-99`
- `src/gateway/mod.rs:187-244`
- `src/gateway/mod.rs:247-266`
问题:
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
影响:
- 某个后台 loop 异常退出后Gateway 不一定能发现。
- 关闭流程只能 stop channel无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
- bus channel 关闭时更像崩溃,而不是可恢复状态。
建议:
- 引入 runtime supervisor保存 JoinHandle 并集中处理退出原因。
- 用 `CancellationToken` 贯穿 Gateway 子任务。
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
### 中低优先级Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
位置:
- `src/scheduler/mod.rs:18-40`
问题:
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)``upcoming(tz)` 的当前时间。
影响:
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now影响较小但函数语义是错的。
建议:
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
- 增加固定时间输入的单元测试,避免受系统时间影响。
### 中低优先级:存在未接入或半接入代码,增加维护噪音
位置:
- `src/tools/pty.rs`
- `src/tools/mod.rs:1-20`
- `src/tools/mod.rs:49-88`
问题:
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty``create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
影响:
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
建议:
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
- 若暂不发布:移动到设计文档或 feature branch避免主干保留死代码。
## 架构评价
### 做得好的地方
- 模块分层方向清楚Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
- Provider 抽象简单直接OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
- Storage 集中初始化 schema便于部署单二进制应用。
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
### 主要架构债务
- SessionManager 承担过多职责会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
- 后台任务生命周期分散gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn缺少统一管理。
## 模块级分析
### gateway
`GatewayState::new()` 是清晰的装配中心配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
### channels
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel核心问题是会话身份混用和广播投递。
### bus
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
### session
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
- `manager.rs`SessionManager 状态和 dialog 生命周期
- `worker.rs`per-session agent worker 和 cancellation
- `commands.rs`slash command 执行
- `outbound.rs`OutboundMessenger 实现
- `restore.rs`storage 恢复与 tool call chain repair
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
### agent
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义`read_only()` 目前是工具自己声明副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard不应替代工具层的资源限制。
### providers
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response可能包含用户隐私、API 返回内容和文件内容。
### tools
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里默认值容易失控。
### storage
Storage schema 初始化实用但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”适合早期迭代不适合长期演进。建议引入 schema version 表或 sqlx migrations至少把每次迁移记录下来。
### skills
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
## 建议修复路线
### P0先修会话正确性
1. 修正 CLI `chat_id/current_session_id` 数据模型。
2. 修正 CLI 出站按 client/chat_id 投递。
3. 实现 `get_current_dialog()``list_dialogs()` current 返回。
4. 修正 `/delete``clear_history``archive` 的真实行为或从协议移除。
5. 增加 WebSocket session lifecycle 测试。
### P1收紧工具和资源边界
1. 文件工具默认限制 workspace路径 canonicalize。
2. bash 超时杀进程,必要时引入进程组。
3. file_read 增加文件大小上限和二进制输出上限。
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
5. 明确高危工具的配置开关。
### P2降低架构复杂度
1. 拆分 `session.rs``feishu.rs``storage/mod.rs``browser.rs`
2. 引入任务 supervisor 和统一 shutdown token。
3. 引入正式数据库迁移。
4. 增加工具注册快照测试,避免死代码和文档漂移。
## 建议测试补充
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
- `/delete` 后旧 session 软删除、新 session 成为 current。
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
- bash timeout 后检查子进程不存在。
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。

View File

@ -0,0 +1,40 @@
# 客户端代码整合设计
## 目标
将分散在 `src/cli/``src/client/` 的客户端代码整合到 `src/client/` 目录。
## 变更
### 目录结构
```
src/
├── client/ # 整合后的客户端模块
│ ├── mod.rs # 主程序入口 (run 函数)
│ ├── input.rs # InputHandler + InputCommand (从 cli/input.rs 合并)
│ └── channel.rs # CliChannel (从 cli/channel.rs 合并)
├── cli/ # 删除
└── protocol.rs # 保留
```
### 关键变更
| 变更 | 说明 |
|------|------|
| `InputEvent::Message(String)` | 简化为只携带文本内容,不再使用 `ChatMessage` |
| `cli` 模块删除 | 代码合并到 `client` |
| 解耦 | `client` 不再依赖 `bus::ChatMessage` |
## 实施步骤
1. 创建 `src/client/input.rs` - 从 `cli/input.rs` 合并,修改 `InputEvent::Message``String`
2. 创建 `src/client/channel.rs` - 从 `cli/channel.rs` 直接复制
3. 更新 `src/client/mod.rs` - 更新 import
4. 更新 `src/lib.rs` - 删除 `pub mod cli;`
5. 删除 `src/cli/` 目录
## 验证
- `cargo build` 通过
- 功能保持不变

View File

@ -0,0 +1,877 @@
# Phase 1: Storage 基础 实现计划
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** 创建 `src/storage/` 模块,实现 SQLite 持久化层,为后续 Session 扩展提供 Storage 基础设施。
**Architecture:** 使用 `sqlx` + `sqlite`,通过 `SqlitePool` 实现异步连接池,所有 Storage 操作均为 async`Storage` 内部管理连接池的生命周期。
**Tech Stack:** `sqlx` (sqlite, tokio), `serde`, `chrono` (时间戳), `tokio::time::sleep` (重试退避)
---
## Task 1: 添加依赖
**Files:**
- Modify: `Cargo.toml:36` (在 `[dependencies]` 末尾添加)
**Step 1: 添加 sqlx + sqlite 依赖**
`Cargo.toml` 末尾添加:
```toml
sqlx = { version = "0.8", features = ["sqlite", "tokio", "macros", "chrono"] }
```
**Step 2: 运行 cargo check 验证依赖**
Run: `cargo check 2>&1`
Expected: 无报错,依赖解析成功
**Step 3: Commit**
```bash
git add Cargo.toml
git commit -m "deps: 添加 sqlx + sqlite 依赖"
```
---
## Task 2: 创建 Storage Error 类型
**Files:**
- Create: `src/storage/error.rs`
**Step 1: 编写 StorageError 枚举和测试**
```rust
use thiserror::Error;
#[derive(Error, Debug)]
pub enum StorageError {
#[error("session not found: {0}")]
NotFound(String),
#[error("session already exists: {0}")]
AlreadyExists(String),
#[error("database error: {0}")]
Database(#[from] sqlx::Error),
#[error("serialization error: {0}")]
Serialization(String),
}
```
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1 | head -30`
Expected: 报错 "cannot find module `storage`"(因为模块未创建),这是预期的
**Step 3: Commit**
```bash
git add src/storage/error.rs
git commit -m "feat(storage): 添加 StorageError 类型"
```
---
## Task 3: 创建 Storage 模块骨架
**Files:**
- Create: `src/storage/mod.rs`
- Create: `src/storage/session.rs`
- Create: `src/storage/message.rs`
**Step 1: 创建 `src/storage/mod.rs`**
```rust
pub mod error;
pub mod session;
pub mod message;
pub use error::StorageError;
```
**Step 2: 创建 `src/storage/session.rs`(空壳)**
```rust
// Session CRUD 操作占位符
```
**Step 3: 创建 `src/storage/message.rs`(空壳)**
```rust
// Message CRUD 操作占位符
```
**Step 4: 在 `src/lib.rs` 中添加 storage 模块**
`src/lib.rs` 末尾添加:
```rust
pub mod storage;
```
**Step 5: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功(空壳模块)
**Step 6: Commit**
```bash
git add src/storage/ src/lib.rs
git commit -m "feat(storage): 创建 storage 模块骨架"
```
---
## Task 4: 实现 Storage 主结构
**Files:**
- Modify: `src/storage/mod.rs`
**Step 1: 编写 Storage 结构和初始化逻辑**
```rust
use sqlx::{Pool, Sqlite, SqlitePool};
use std::path::Path;
pub struct Storage {
pool: Pool<Sqlite>,
}
impl Storage {
/// 打开或创建数据库
pub async fn new(db_path: &Path) -> Result<Self, StorageError> {
let database_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePool::connect(&database_url).await?;
let storage = Self { pool };
storage.init_schema().await?;
Ok(storage)
}
/// 初始化数据库 schema
async fn init_schema(&self) -> Result<(), StorageError> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
dialog_id TEXT NOT NULL,
title TEXT NOT NULL DEFAULT '新对话',
created_at INTEGER NOT NULL,
last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0,
routing_info TEXT,
deleted_at INTEGER,
UNIQUE(channel, chat_id, dialog_id)
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_sessions_chat
ON sessions(channel, chat_id, deleted_at)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
media_refs TEXT,
tool_call_id TEXT,
tool_name TEXT,
tool_calls TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
ON messages(session_id, seq)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
/// 获取连接池引用(供内部 CRUD 使用)
pub(crate) fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
}
```
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/mod.rs
git commit -m "feat(storage): 实现 Storage 主结构和初始化"
```
---
## Task 5: 定义 SessionMeta 和 MessageMeta 数据结构
**Files:**
- Modify: `src/storage/session.rs`
- Modify: `src/storage/message.rs`
**Step 1: 在 `session.rs` 中定义 SessionMeta**
```rust
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMeta {
pub id: String,
pub channel: String,
pub chat_id: String,
pub dialog_id: String,
pub title: String,
pub created_at: i64,
pub last_active_at: i64,
pub message_count: i64,
pub routing_info: Option<String>,
pub deleted_at: Option<i64>,
}
```
**Step 2: 在 `message.rs` 中定义 MessageMeta**
```rust
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMeta {
pub id: String,
pub session_id: String,
pub seq: i64,
pub role: String,
pub content: String,
pub media_refs: Option<String>,
pub tool_call_id: Option<String>,
pub tool_name: Option<String>,
pub tool_calls: Option<String>,
pub created_at: i64,
}
```
**Step 3: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 4: Commit**
```bash
git add src/storage/session.rs src/storage/message.rs
git commit -m "feat(storage): 定义 SessionMeta 和 MessageMeta 数据结构"
```
---
## Task 6: 实现 Session CRUD 操作
**Files:**
- Modify: `src/storage/session.rs`
**Step 1: 编写 upsert_session**
```rust
use sqlx::Row;
use super::SessionMeta;
impl Storage {
pub async fn upsert_session(&self, meta: &SessionMeta) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
title = excluded.title,
last_active_at = excluded.last_active_at,
message_count = excluded.message_count,
routing_info = excluded.routing_info,
deleted_at = excluded.deleted_at
"#,
)
.bind(&meta.id)
.bind(&meta.channel)
.bind(&meta.chat_id)
.bind(&meta.dialog_id)
.bind(&meta.title)
.bind(meta.created_at)
.bind(meta.last_active_at)
.bind(meta.message_count)
.bind(&meta.routing_info)
.bind(meta.deleted_at)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn get_session(&self, id: &str) -> Result<SessionMeta, StorageError> {
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions WHERE id = ? AND deleted_at IS NULL
"#,
)
.bind(id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
Ok(SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})
}
pub async fn list_sessions(
&self,
channel: &str,
chat_id: &str,
limit: i64,
) -> Result<Vec<SessionMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
ORDER BY last_active_at DESC
LIMIT ?
"#,
)
.bind(channel)
.bind(chat_id)
.bind(limit)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})
.collect())
}
pub async fn touch_session(
&self,
id: &str,
message_count: i64,
last_active_at: i64,
) -> Result<(), StorageError> {
sqlx::query(
r#"
UPDATE sessions SET message_count = ?, last_active_at = ?
WHERE id = ?
"#,
)
.bind(message_count)
.bind(last_active_at)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
/// 查找 channel:chat_id 下最近活跃且未过期的 session
pub async fn find_active_session(
&self,
channel: &str,
chat_id: &str,
ttl_millis: i64,
) -> Result<Option<SessionMeta>, StorageError> {
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
ORDER BY last_active_at DESC
LIMIT 1
"#,
)
.bind(channel)
.bind(chat_id)
.bind(cutoff)
.fetch_optional(self.pool())
.await?;
match row {
Some(row) => Ok(Some(SessionMeta {
id: row.get("id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
dialog_id: row.get("dialog_id"),
title: row.get("title"),
created_at: row.get("created_at"),
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
deleted_at: row.get("deleted_at"),
})),
None => Ok(None),
}
}
}
```
> 注意:`Storage` 的 CRUD 方法需要能访问 `pool()`,但目前 `pool()``pub(crate)`。在 `mod.rs` 中为 `session.rs` 实现 `Storage` 的 CRUD所以同模块内可访问。
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/session.rs
git commit -m "feat(storage): 实现 Session CRUD 操作"
```
---
## Task 7: 实现 Message CRUD 操作
**Files:**
- Modify: `src/storage/message.rs`
**Step 1: 编写 Message CRUD**
```rust
use sqlx::Row;
use super::MessageMeta;
impl Storage {
pub async fn append_message(&self, session_id: &str, msg: &MessageMeta) -> Result<i64, StorageError> {
sqlx::query(
r#"
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&msg.id)
.bind(session_id)
.bind(msg.seq)
.bind(&msg.role)
.bind(&msg.content)
.bind(&msg.media_refs)
.bind(&msg.tool_call_id)
.bind(&msg.tool_name)
.bind(&msg.tool_calls)
.bind(msg.created_at)
.execute(self.pool())
.await?;
Ok(msg.seq)
}
pub async fn append_messages(
&self,
session_id: &str,
msgs: &[MessageMeta],
) -> Result<Vec<i64>, StorageError> {
let mut seqs = Vec::with_capacity(msgs.len());
for msg in msgs {
let seq = self.append_message(session_id, msg).await?;
seqs.push(seq);
}
Ok(seqs)
}
pub async fn load_messages(
&self,
session_id: &str,
from_seq: i64,
) -> Result<Vec<MessageMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at
FROM messages
WHERE session_id = ? AND seq >= ?
ORDER BY seq ASC
"#,
)
.bind(session_id)
.bind(from_seq)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| MessageMeta {
id: row.get("id"),
session_id: row.get("session_id"),
seq: row.get("seq"),
role: row.get("role"),
content: row.get("content"),
media_refs: row.get("media_refs"),
tool_call_id: row.get("tool_call_id"),
tool_name: row.get("tool_name"),
tool_calls: row.get("tool_calls"),
created_at: row.get("created_at"),
})
.collect())
}
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
.bind(session_id)
.execute(self.pool())
.await?;
Ok(())
}
}
```
> 注意:同样在 `mod.rs` 中实现,这样 `Storage` 的方法对 `message.rs` 中的 impl 可见。
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/message.rs
git commit -m "feat(storage): 实现 Message CRUD 操作"
```
---
## Task 8: 实现写入重试逻辑
**Files:**
- Modify: `src/storage/mod.rs`
**Step 1: 在 Storage 中添加带重试的 append_message**
`mod.rs``Storage` impl 块中添加:
```rust
use tokio::time::{sleep, Duration};
impl Storage {
/// 追加消息,带重试逻辑
/// 重试 3 次100/200/300ms 退避),仍失败返回错误
pub async fn append_message_with_retry(
&self,
session_id: &str,
msg: &MessageMeta,
) -> Result<i64, StorageError> {
let delays = [100, 200, 300];
for (i, delay) in delays.iter().enumerate() {
match self.append_message(session_id, msg).await {
Ok(seq) => return Ok(seq),
Err(e) if i < delays.len() - 1 => {
sleep(Duration::from_millis(*delay)).await;
tracing::warn!("Storage write failed, retrying: {}", e);
}
Err(e) => {
tracing::error!("Storage write failed after retries: {}", e);
return Err(e);
}
}
}
unreachable!()
}
}
```
> 注意:需要 `use sqlx::Row;``mod.rs` 中。
**Step 2: 验证编译**
Run: `cargo build --lib 2>&1`
Expected: 编译成功
**Step 3: Commit**
```bash
git add src/storage/mod.rs
git commit -m "feat(storage): 添加写入重试逻辑"
```
---
## Task 9: 编写 Storage 单元测试
**Files:**
- Modify: `src/storage/mod.rs`(添加测试模块)
**Step 1: 编写 Storage 集成测试**
`src/storage/mod.rs` 末尾添加:
```rust
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
use std::path::Path;
async fn create_test_storage() -> (Storage, impl Fn()) {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let storage = Storage::new(&db_path).await.unwrap();
let cleanup = || {
drop(dir);
};
(storage, cleanup)
}
#[tokio::test]
async fn test_upsert_and_get_session() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试会话".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
let loaded = storage.get_session(&meta.id).await.unwrap();
assert_eq!(loaded.title, "测试会话");
assert_eq!(loaded.channel, "cli_chat");
}
#[tokio::test]
async fn test_get_nonexistent_session() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let result = storage.get_session("nonexistent").await;
assert!(result.is_err());
matches!(result.unwrap_err(), StorageError::NotFound(_));
}
#[tokio::test]
async fn test_list_sessions() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
for i in 0..5 {
let meta = SessionMeta {
id: format!("cli_chat:sid123:dialog{}", i),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: format!("dialog{}", i),
title: format!("会话{}", i),
created_at: i as i64 * 1000,
last_active_at: i as i64 * 1000,
message_count: i,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
assert_eq!(sessions.len(), 5);
// 按 last_active_at DESC 排序
assert_eq!(sessions[0].dialog_id, "dialog4");
}
#[tokio::test]
async fn test_soft_delete() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
storage.soft_delete_session(&meta.id).await.unwrap();
let result = storage.get_session(&meta.id).await;
assert!(result.is_err());
matches!(result.unwrap_err(), StorageError::NotFound(_));
}
#[tokio::test]
async fn test_append_and_load_messages() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let session_meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&session_meta).await.unwrap();
let msg = MessageMeta {
id: "msg1".to_string(),
session_id: session_meta.id.clone(),
seq: 1,
role: "user".to_string(),
content: "你好".to_string(),
media_refs: None,
tool_call_id: None,
tool_name: None,
tool_calls: None,
created_at: 1000,
};
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
assert_eq!(seq, 1);
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded[0].content, "你好");
}
#[tokio::test]
async fn test_touch_session() {
let (storage, cleanup) = create_test_storage().await;
defer { cleanup(); }
let meta = SessionMeta {
id: "cli_chat:sid123:dialog1".to_string(),
channel: "cli_chat".to_string(),
chat_id: "sid123".to_string(),
dialog_id: "dialog1".to_string(),
title: "测试".to_string(),
created_at: 1000,
last_active_at: 1000,
message_count: 0,
routing_info: None,
deleted_at: None,
};
storage.upsert_session(&meta).await.unwrap();
storage.touch_session(&meta.id, 5, 2000).await.unwrap();
let loaded = storage.get_session(&meta.id).await.unwrap();
assert_eq!(loaded.message_count, 5);
assert_eq!(loaded.last_active_at, 2000);
}
}
```
> 需要在 `Cargo.toml` 中添加 `tempfile` 依赖(已存在)。`defer` 宏可自己实现一个简单的:`fn defer<F: FnOnce()>(f: F) { f() }`
**Step 2: 运行测试**
Run: `cargo test storage::tests --lib 2>&1`
Expected: 所有 7 个测试 PASS
**Step 3: Commit**
```bash
git add src/storage/mod.rs
git commit -m "test(storage): 编写 Storage 单元测试"
```
---
## 汇总
| Task | 改动文件 | 关键交付物 |
|------|----------|-----------|
| 1 | `Cargo.toml` | sqlx 依赖 |
| 2 | `src/storage/error.rs` | StorageError |
| 3 | `src/storage/{mod.rs,session.rs,message.rs}`, `src/lib.rs` | 模块骨架 |
| 4 | `src/storage/mod.rs` | Storage 结构 + init_schema |
| 5 | `src/storage/session.rs`, `message.rs` | SessionMeta, MessageMeta |
| 6 | `src/storage/session.rs` | Session CRUD |
| 7 | `src/storage/message.rs` | Message CRUD |
| 8 | `src/storage/mod.rs` | append_message_with_retry |
| 9 | `src/storage/mod.rs` | 7 个单元测试 |
**Phase 1 完成后:** Storage 模块可独立使用,具备完整的持久化能力,可安全地集成到 Session 和 SessionManager 中。

View File

@ -0,0 +1,278 @@
# Session 持久化设计方案
## 概述
为 PicoBot 添加 SQLite 持久化层,实现 Session 数据的持久化、完整 Dialog 生命周期管理、消息实时落盘、以及基于 TTL 的自动内存清理。
## 核心概念
```
UnifiedSessionId = {channel}:{chat_id}:{dialog_id}
Session = Dialog两者等价不再分层
```
每个 Session 独立管理自己的消息历史、LLM 配置和路由信息。
## 数据库 Schema
### sessions 表
```sql
CREATE TABLE sessions (
id TEXT PRIMARY KEY,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
dialog_id TEXT NOT NULL,
title TEXT NOT NULL DEFAULT '新对话',
created_at INTEGER NOT NULL,
last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0,
routing_info TEXT,
deleted_at INTEGER,
UNIQUE(channel, chat_id, dialog_id)
);
CREATE INDEX idx_sessions_chat ON sessions(channel, chat_id, deleted_at);
```
### messages 表
```sql
CREATE TABLE messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
media_refs TEXT,
tool_call_id TEXT,
tool_name TEXT,
tool_calls TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
);
CREATE INDEX idx_messages_session_seq ON messages(session_id, seq);
```
## Storage API
### Session 操作
| 方法 | 说明 |
|------|------|
| `new(db_path) -> Storage` | 打开/创建数据库 |
| `upsert_session(meta) -> Result<(), StorageError>` | 插入或更新 session 元数据 |
| `get_session(id) -> Result<SessionMeta, StorageError>` | 获取单个 session |
| `list_sessions(channel, chat_id, limit) -> Result<Vec<SessionMeta>>` | 最近 N 条 |
| `touch_session(id, message_count, last_active_at)` | 更新计数和最后活跃时间 |
| `soft_delete_session(id) -> Result<(), StorageError>` | 软删除 |
### Message 操作
| 方法 | 说明 |
|------|------|
| `append_message(session_id, msg) -> Result<i64, StorageError>` | 追加单条消息,返回 seq |
| `append_messages(session_id, msgs) -> Result<Vec<i64>, StorageError>` | 批量追加 |
| `load_messages(session_id, from_seq) -> Result<Vec<MessageMeta>>` | 从指定 seq 加载 |
| `clear_messages(session_id) -> Result<(), StorageError>` | 清除消息(保留 session |
### 写入失败处理
重试 3 次100/200/300ms 退避),仍失败则发送系统通知告警。
## Session 结构
```rust
pub struct Session {
pub id: UnifiedSessionId,
pub title: String,
pub created_at: i64,
pub last_active_at: i64,
pub message_count: i64, // 用户消息计数
pub total_message_count: i64, // 含系统消息
messages: Vec<ChatMessage>, // 内存消息历史
seq_counter: i64, // 下一个消息的 seq
provider_config: LLMProviderConfig,
provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
compressor: ContextCompressor,
user_tx: mpsc::Sender<WsOutbound>,
storage: Arc<Storage>, // 持久化 sink
routing_info: String, // JSON 路由信息
}
```
### 初始化流程
```
new() 或 from_storage()
注入 storage 引用
创建 provider, tools, compressor
从 Storage 加载 messagesfrom_seq = 0
设置 seq_counter = messages.len() + 1
返回 Session 实例
```
## handle_message 流程
```
handle_message(channel, chat_id, sender_id, content, media)
├── 1. 确定 dialog_id
│ │
│ ├── 显式传入 dialog_id → 使用
│ └── 无 dialog_id
│ ├── 查找 channel:chat_id 下最近活跃且未过期的 session
│ ├── 找到 → 使用该 session
│ └── 未找到 → 创建新 sessiondialog_id = 新随机 ID
├── 2. 获取或创建 Session
│ 有 → 更新 session_timestamps
│ 无 → 从 Storage 恢复 或 创建新 Session
├── 3. 追加用户消息并持久化
│ seq = seq_counter; seq_counter += 1
│ Storage.append_message()(失败重试 → 告警)
│ messages.push(user_msg)
│ message_count += 1
├── 4. 检查 title 自动生成
│ message_count == 10 且 title == 默认值 → LLM 生成 → 更新 title → 写回 Storage
├── 5. 注入 skills_prompt
├── 6. 新 session 注入欢迎消息(系统消息,不计入 message_count
├── 7. 上下文压缩(如需要)
├── 8. 调用 AgentLoop
├── 9. 持久化 Agent 响应
└── 10. 返回响应
```
## Dialog 生命周期命令
| 命令 | 行为 |
|------|------|
| `/new [标题]` | 创建新 dialog新随机 dialog_id新建 Session |
| `/sessions` | 列出 channel:chat_id 下最近 10 条 session按 last_active_at 倒序) |
| `/switch <dialog_id>` | 切换到指定 session从 Storage 恢复或内存命中) |
| `/rename <新标题>` | 重命名当前 session |
| `/delete` | 软删除当前 session内存移除 + Storage 标记 deleted_at |
| `/info` | 显示当前 session 信息 |
| `/compact` | 手动触发上下文压缩 |
## 路由信息
每种 Channel 在创建 Session 时注入路由信息:
```rust
// CLI
routing_info = json!({"type": "cli", "ws_sender_id": "xxx"})
// Feishu
routing_info = json!({"type": "feishu", "open_conversation_id": "oc_xxx", "tenant_key": "xxx"})
```
## Title 自动生成
调用时机:
1. Session 首次创建时(初始 title = "新对话"
2. `message_count` 达到 10 且 title 仍为默认值时,自动更新
生成 Prompt
```
给定以下对话历史生成一个简短的会话标题5-15 个中文字符),
概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
历史:
{messages}
```
## TTL 清理
- 内存 session 超时 → 释放内存Storage 记录保留
- 用户切换回该 session → 从 Storage 重新加载到内存
- Storage 中的 session 记录通过 `deleted_at` 软删除,不会物理删除
## 文件结构
```
src/
├── storage/
│ ├── mod.rs # Storage 主模块
│ ├── session.rs # Session CRUD
│ ├── message.rs # Message CRUD
│ └── error.rs # StorageError
└── session/
├── mod.rs # 导出 Session, SessionManager
├── session.rs # Session, SessionManager 实现
├── session_id.rs # UnifiedSessionId
├── commands.rs # SessionCommand
├── events.rs # SessionEvent, DialogInfo
└── error.rs # SessionError
```
## 实现顺序
### Phase 1: Storage 基础
1. 添加 `sqlx` + `sqlite` 依赖
2. 实现 `Storage` 结构(连接池、初始化)
3. Session CRUD + Message CRUD
4. 写入重试逻辑
5. 单元测试
### Phase 2: Session 扩展
1. 扩展 `Session` 结构(添加 storage、routing_info、计数字段、seq_counter
2. `from_storage()` 恢复逻辑
3. `add_message` 持久化集成
4. `send_system_notification` 接口
5. Title 自动生成
### Phase 3: SessionManager 完善
1. 注入 `Arc<Storage>`
2. 实现 `list_dialogs()`
3. 实现 `switch_dialog()`
4. 实现 `delete_dialog()` / `rename_dialog()`
5. 后台 TTL 清理任务
6. 集成测试
### Phase 4: 斜杠命令
1. 实现 `/sessions`
2. 实现 `/switch`
3. 实现 `/rename`
4. 实现 `/delete`
5. 端到端测试
## 配置项
```json
{
"session": {
"ttl_hours": 24,
"cleanup_interval_minutes": 60,
"auto_title_after_n_messages": 10,
"storage_retry_delays_ms": [100, 200, 300]
}
}
```
## 与现有代码的冲突点
| 冲突 | 处理方式 |
|------|----------|
| `DialogInfo``archived_at` | 删除该字段,改用 `deleted_at` |
| `SessionCommand::ArchiveDialog` | 删除 |
| `/new` 现有行为 | 改为创建新 session新 dialog_id |
| 现有 `Session` 无 storage/routing_info | 扩展结构,新增 `from_storage()` |
| `SessionManager` 需注入 `Arc<Storage>` | 扩展构造方法 |
| stub 方法 | 实现 |

View File

@ -0,0 +1,226 @@
# PicoBot Memory System Design
Date: 2026-05-07
## 1. Overview
Introduce a memory system that allows PicoBot agents to remember user preferences, project context, facts, and conversation history across sessions. The memory system is **unified with the existing context compression pipeline**: compression automatically produces `timeline` memory entries and advances a `last_consolidated_at` pointer to avoid redundant reprocessing.
### Design Principles
- **Compression is memory** (inspired by nanobot): when old messages are compressed, the summary is persisted — not discarded
- **FTS5 only** (no vector embeddings): keyword search via SQLite FTS5, sufficient for current scale
- **Extend existing infrastructure**: reuse `Storage` connection pool, `ContextCompressor`, `SystemPromptBuilder`
- **YAGNI**: no knowledge graph, no response cache, no namespace isolation, no audit trail
## 2. Core Architecture
```
ContextCompressor (existing) MemoryManager (new)
│ │
│ compress_if_needed() │ store / recall / forget
│ ├─ LLM summary → inject │
│ └─ store(timeline entry) ──────┘
│ └─ advance last_consolidated_at
SystemPromptBuilder ── recall(knowledge, limit=5) ──→ inject into system prompt
AgentLoop ── after_turn ──→ memory_store / memory_recall / memory_forget tools
```
## 3. Memory Categories
| Category | Purpose | Written By | Retrieved By |
|----------|---------|-----------|--------------|
| `knowledge` | Long-term facts, preferences, patterns, insights | Agent via `memory_store` tool | FTS5 → injected into system prompt every turn |
| `timeline` | Compressed conversation summaries | ContextCompressor automatically | FTS5 + time-range queries |
## 4. Storage Schema
### New table: `memories`
Added to the existing `Storage` initialization in `src/storage/mod.rs`:
```sql
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
key TEXT NOT NULL UNIQUE,
content TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'knowledge',
importance REAL NOT NULL DEFAULT 0.5,
session_id TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
key,
content,
content=memories,
content_rowid=rowid
);
```
### Modified table: `sessions`
```sql
ALTER TABLE sessions ADD COLUMN last_consolidated_at INTEGER;
```
## 5. Unified Compression-Memory Pipeline
### Trigger Conditions
Compression/consolidation fires when **any** of these conditions is met:
| Condition | Value | Rationale |
|-----------|-------|-----------|
| Token budget exceeds 50% threshold | `context_window / 2` | Primary trigger — context is getting full |
| Accumulated N turns without consolidation | 3 (configurable) | Catch-up for short messages that don't hit token threshold |
| Session idle | 10 minutes (configurable) | Important for async channels like Feishu |
### Flow
```
compress_if_needed(history, session_id):
1. Read last_consolidated_at from session
→ Only compress messages after that timestamp
2. If no messages to compress → return history unchanged
3. FTS5 recall(user_input, limit=recall_limit, category=knowledge)
→ Inject relevant facts into system prompt
4. LLM summarization of old messages → [Context Summary]
→ Inject into current conversation
5. Store summary as timeline entry:
key: "ctx_{session_id}_{uuid}"
content: "[YYYY-MM-DD HH:MM] summary text..."
category: timeline
6. UPDATE sessions.last_consolidated_at = now()
7. Return compressed history
```
### timeline Entry Format
Each timeline entry follows nanobot's convention:
```
[2026-05-07 14:30] User asked about Rust async patterns. Discussed tokio::select!,
semaphore-based rate limiting, and backpressure strategies. No code was written.
```
This format is grep-friendly and human-readable.
## 6. Retrieval Strategy
### Automatic Retrieval (every turn)
`SystemPromptBuilder.build_system_prompt()` calls:
```rust
memory.recall(query=user_message, limit=recall_limit, category=knowledge)
```
Results sorted by FTS5 BM25 score, injected as:
```
## Memory Context
- user_prefers_rust: User prefers Rust for all backend projects
- project_picobot_stack: PicoBot uses Rust, axum, sqlx, ratatui, tokio
- user_workflow: User prefers TDD workflow with cargo test --lib
```
### Agent-Initiated Retrieval
Agent uses `memory_recall` tool with optional `category`, `since`, `until` parameters.
### Fallback
If FTS5 returns empty results, fallback to `LIKE '%keyword%'` on `key` and `content` columns.
## 7. Agent Tools
| Tool | Parameters | Description |
|------|-----------|-------------|
| `memory_store` | `key: str`, `content: str`, `category: str`, `importance?: f64` | Write or update a memory entry. Key is semantic identifier (e.g., "user_language_pref") |
| `memory_recall` | `query: str`, `category?: str`, `since?: i64`, `until?: i64`, `limit?: usize` | Search memories by keyword and optional filters |
| `memory_forget` | `key: str` | Delete a memory entry by key |
## 8. Error Handling & Degradation
| Scenario | Strategy |
|----------|----------|
| Consolidation LLM call fails | Log warning, increment failure counter, do NOT block main flow |
| Consecutive failures >= 3 | Degrade: append raw message dump to timeline with `[RAW]` prefix, reset counter |
| FTS5 recall returns empty | Fallback to `LIKE '%keyword%'` query |
| `memory.enabled = false` | ContextCompressor works normally, no memory writes |
| MemoryManager uninitialized | ContextCompressor works with feature-gated memory write path |
## 9. Configuration
```json
{
"memory": {
"enabled": true,
"consolidation_provider": "openai",
"consolidation_model": "gpt-4o-mini",
"recall_limit": 5,
"consolidation_turn_threshold": 3,
"idle_consolidation_minutes": 10,
"timeline_retention_days": 90,
"max_failures_before_degrade": 3
}
}
```
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| `enabled` | bool | `false` | Master switch for memory system |
| `consolidation_provider` | string | — | Provider name for consolidation LLM calls |
| `consolidation_model` | string | — | Model name for consolidation |
| `recall_limit` | usize | `5` | Max knowledge entries injected into system prompt |
| `consolidation_turn_threshold` | usize | `3` | Turns before forced consolidation |
| `idle_consolidation_minutes` | u64 | `10` | Idle time before consolidation trigger |
| `timeline_retention_days` | u64 | `90` | Auto-cleanup age for timeline entries |
| `max_failures_before_degrade` | usize | `3` | Consecutive failures before raw archive fallback |
## 10. New Module Structure
```
src/
├── memory/
│ ├── mod.rs # MemoryManager, MemoryConfig
│ ├── types.rs # MemoryEntry, MemoryCategory, ConsolidationResult
│ └── consolidation.rs # Consolidation prompt + LLM call logic
├── storage/
│ └── memory.rs # SQLite CRUD for memories table + FTS5
├── tools/
│ ├── memory_store.rs # memory_store tool
│ ├── memory_recall.rs # memory_recall tool
│ └── memory_forget.rs # memory_forget tool
```
## 11. Integration Points (Existing Files Modified)
| File | Change |
|------|--------|
| `src/lib.rs` | Add `pub mod memory;` |
| `src/config/mod.rs` | Add `MemoryConfig` struct and deserialization |
| `src/storage/mod.rs` | Add `pub mod memory;`, init `memories` table and FTS5 in `init_schema()` |
| `src/storage/session.rs` | Add `last_consolidated_at` column read/write |
| `src/session/session.rs` | Add `last_consolidated_at: Option<i64>` field to Session |
| `src/agent/context_compressor.rs` | Add `memory: Option<Arc<MemoryManager>>` field, write timeline on compress |
| `src/agent/system_prompt.rs` | Add `memory_context` section via `MemoryManager::recall()` |
| `src/agent/agent_loop.rs` | No changes (tools registered via ToolRegistry) |
| `src/tools/mod.rs` | Register `memory_store`, `memory_recall`, `memory_forget` in `create_default_tools()` |
| `src/gateway/mod.rs` | Initialize `MemoryManager` in `GatewayState::new()`, pass to ContextCompressor |
## 12. Implementation Order
| # | Task | Dependencies |
|---|------|-------------|
| 1 | Types: `MemoryEntry`, `MemoryCategory`, `ConsolidationResult` | — |
| 2 | Config: `MemoryConfig` + deserialization | — |
| 3 | Storage: `memories` table + FTS5 + CRUD + search | #1 |
| 4 | `MemoryManager` API | #1, #2, #3 |
| 5 | Session: `last_consolidated_at` field | — |
| 6 | `ContextCompressor` memory integration | #4, #5 |
| 7 | `SystemPromptBuilder` memory context injection | #4 |
| 8 | Agent tools: `memory_store`, `memory_recall`, `memory_forget` | #4 |
| 9 | `GatewayState` initialization wiring | #4, #5, #6 |
| 10 | Unit tests | #1-#9 |

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,90 @@
# 启动增量恢复设计
## 问题
PicoBot 重启后,`Session::from_storage()` 全量加载 `messages` 表,恢复的 history 可能直接超出上下文窗口,首次 LLM 调用即触发压缩,浪费 token。
## 设计
### 核心思路
`last_compressed_message_at` 标记最后压缩时刻。恢复时:
- 加载该标记之后的原始消息
- 用该 session 的 Timeline 条目替代已压缩部分
- `seq_counter` 统一从 SQLite 查 `MAX(seq) + 1`
```
messages 表 memories(timeline)
┌──────────────────────────┐ ┌───────────────────────────┐
│ created_at = T1..T5 │ ← 跳过 │ session = feishu:oc:dialog │
│ (压缩已覆盖用Timeline替代)│ │ created_at 降序 │
├──────────────────────────┤ ├───────────────────────────┤
│ created_at > T6 │ ← 加载 │ 只取最近 3 条 │
└──────────────────────────┘ └───────────────────────────┘
```
### 数据变更
**`sessions` 表加列:**
```sql
last_compressed_message_at INTEGER
```
**`SessionMeta` / `Session` 加字段:** `last_compressed_message_at: Option<i64>`
### Storage 层新增方法
| 方法 | SQL |
|------|-----|
| `get_max_message_seq(session_id)` | `SELECT MAX(seq) FROM messages WHERE session_id = ?` |
| `load_messages_after_timestamp(session_id, after_ts)` | `WHERE created_at > ?` |
| `load_session_timelines(session_id, limit)` | `WHERE session_id = ? AND category = 'timeline' ORDER BY created_at DESC LIMIT ?` |
### 压缩跟踪
`compress_if_needed()` 返回值改为 `CompressionResult { history, created_timelines: bool }`
`compress_once()` 中 LLM 摘要路径才置 `true`Tier 2Tier 1/3 不产生 Timeline。
**记录时机**`handle_message` 正常流、溢出重试流、`/compact` 统一):
```rust
if result.created_timelines {
session.last_compressed_message_at = Some(now());
session.persist_session_meta().await;
}
```
### Session::from_storage() 恢复逻辑
有压缩标记时:
1. `load_session_timelines(limit=4)` → 取 3 条给 LLM第 4 条判"有更多"
2. 有更多 → 插入提示 user 消息
3. 逐条插入 Timeline 为 `[Previous Context]` user 消息
4. `load_messages_after_timestamp(after_ts)` → 原始尾消息
5. `repair_tool_call_chains`
无压缩标记 → 全量加载(现有行为)。
统一:`seq_counter = MAX(seq) + 1`
### 系统提示词
`Session.last_compressed_message_at` 非空时追加:
```
## 历史会话
之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。
```
## 改动清单
| # | 文件 | 改动 |
|---|------|------|
| 1 | `storage/session.rs` | `SessionMeta``last_compressed_message_at` |
| 2 | `storage/mod.rs` | DDL migration + upsert/get_session 加列 |
| 3 | `storage/mod.rs` | 新增 `get_max_message_seq`, `load_messages_after_timestamp` |
| 4 | `storage/memory.rs` | 新增 `load_session_timelines` |
| 5 | `agent/context_compressor.rs` | 返回值改为 `CompressionResult``created_timelines` |
| 6 | `session/session.rs` | `Session` 加字段,`persist_session_meta` 加字段 |
| 7 | `session/session.rs` | `from_storage()` 重写恢复逻辑 |
| 8 | `session/session.rs` | `handle_message()` 压缩后记录标记 |
| 9 | `session/session.rs` | `/compact` 命令压缩后记录标记 |
| 10 | `session/session.rs` | `build_system_prompt()` 注入 `last_compressed_message_at` |

View File

@ -0,0 +1,674 @@
# 启动增量恢复 Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** PicoBot 重启后不再全量加载 messages 表,而是基于 `last_compressed_message_at` 标记增量恢复,用 Timeline 替代已压缩部分。
**Architecture:** 在 `sessions` 表加 `last_compressed_message_at` 列,`compress_if_needed` 返回值增加 `created_timelines` 标志,恢复时按时间戳增量加载消息并用近 3 条 Timeline 前置,`seq_counter` 统一从 SQLite 查 MAX(seq)。
**Tech Stack:** Rust, sqlx (SQLite), tokio
---
### Task 1: SessionMeta 和数据库 DDL 加列
**Files:**
- Modify: `src/storage/session.rs:15`
- Modify: `src/storage/mod.rs:44-45` (DDL), `:172-180` (migration)
- Modify: `src/storage/mod.rs:317-326` (upsert_session SQL + ON CONFLICT)
- Modify: `src/storage/mod.rs:345-369` (get_session SELECT + struct)
- Modify: `src/storage/mod.rs:380-406`, `:454-479`, `:564-588`, `:728`, `:754` (list_sessions 及测试 mock)
**Step 1: 在 `src/storage/session.rs` SessionMeta 加字段**
`last_consolidated_at: Option<i64>` 后加一行:
```rust
pub last_compressed_message_at: Option<i64>,
```
**Step 2: DDL schema 加列**
`src/storage/mod.rs` 的 CREATE TABLE sessions 中 (line 44)`last_consolidated_at INTEGER` 后加逗号和:
```sql
last_compressed_message_at INTEGER
```
**Step 3: migration 加列**
`src/storage/mod.rs` line 182 之后(现有 migration 的 `); .ok();` 之后),添加新 migration
```rust
// Migration: add last_compressed_message_at column if not exists
sqlx::query(
r#"ALTER TABLE sessions ADD COLUMN last_compressed_message_at INTEGER"#,
)
.execute(&self.pool)
.await
.ok();
```
**Step 4: upsert_session SQL 加列**
`src/storage/mod.rs` line 317: INSERT 列列表加 `last_compressed_message_at`VALUES 加 `?`ON CONFLICT DO UPDATE SET 加 `last_compressed_message_at = excluded.last_compressed_message_at`。line 338 后加 `.bind(meta.last_compressed_message_at)`
**Step 5: get_session SELECT 加列**
`src/storage/mod.rs` line 348: SELECT 列加 `last_compressed_message_at`。line 368 后加:
```rust
last_compressed_message_at: row.get("last_compressed_message_at"),
```
**Step 6: 其他 SELECT sessions 的地方list_sessions 多个变体)**
同样补 `last_compressed_message_at` 到 SELECT 列和 struct 构造。以及测试中的 mock SessionMeta 构造line 728, 754 等)。
**Step 7: 编译检查**
```bash
cargo check 2>&1
```
**Step 8: Commit**
```bash
git add src/storage/session.rs src/storage/mod.rs
git commit -m "feat(storage): add last_compressed_message_at column to sessions"
```
---
### Task 2: Storage 新增加载方法
**Files:**
- Modify: `src/storage/mod.rs` (在 load_messages 之后)
- Modify: `src/storage/memory.rs` (在 cleanup_old_timelines 之后)
**Step 1: `get_max_message_seq`**
`src/storage/mod.rs``load_messages` 函数后面添加:
```rust
pub async fn get_max_message_seq(&self, session_id: &str) -> Result<i64, StorageError> {
let row = sqlx::query(
"SELECT COALESCE(MAX(seq), 0) as max_seq FROM messages WHERE session_id = ?",
)
.bind(session_id)
.fetch_one(self.pool())
.await?;
Ok(row.get::<i64, _>("max_seq"))
}
```
**Step 2: `load_messages_after_timestamp`**
```rust
pub async fn load_messages_after_timestamp(
&self,
session_id: &str,
after_ts: i64,
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
FROM messages
WHERE session_id = ? AND created_at > ?
ORDER BY seq ASC
"#,
)
.bind(session_id)
.bind(after_ts)
.fetch_all(self.pool())
.await?;
Ok(rows.into_iter().map(|row| crate::storage::message::MessageMeta {
id: row.get("id"),
session_id: row.get("session_id"),
seq: row.get("seq"),
role: row.get("role"),
content: row.get("content"),
media_refs: row.get("media_refs"),
tool_call_id: row.get("tool_call_id"),
tool_name: row.get("tool_name"),
tool_calls: row.get("tool_calls"),
source: row.get("source"),
created_at: row.get("created_at"),
}).collect())
}
```
**Step 3: `load_session_timelines`**
`src/storage/memory.rs``cleanup_old_timelines` 之后line 252 的 `}` 之前)添加:
```rust
pub async fn load_session_timelines(
&self,
session_id: &str,
limit: usize,
) -> Result<Vec<MemoryEntry>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, key, content, category, importance,
session_id, created_at, updated_at
FROM memories
WHERE session_id = ? AND category = 'timeline'
ORDER BY created_at DESC
LIMIT ?
"#,
)
.bind(session_id)
.bind(limit as i64)
.fetch_all(self.pool())
.await?;
parse_memory_rows(&rows)
}
```
**Step 4: 编译检查**
```bash
cargo check 2>&1
```
**Step 5: Commit**
```bash
git add src/storage/mod.rs src/storage/memory.rs
git commit -m "feat(storage): add load_messages_after_timestamp, load_session_timelines, get_max_message_seq"
```
---
### Task 3: ContextCompressor 引入 CompressionResult
**Files:**
- Modify: `src/agent/context_compressor.rs:172-274` (compress_if_needed)
- Modify: `src/agent/context_compressor.rs:320-402` (compress_once)
**Step 1: 定义 CompressionResult**
在 context_compressor.rs 中 `ContextCompressor` struct 定义之后添加:
```rust
pub struct CompressionResult {
pub history: Vec<ChatMessage>,
pub created_timelines: bool,
}
```
**Step 2: 修改 compress_if_needed 签名和返回**
`pub async fn compress_if_needed(&self, mut history: Vec<ChatMessage>) -> Result<Vec<ChatMessage>, AgentError>` 改为:
```rust
pub async fn compress_if_needed(
&self,
mut history: Vec<ChatMessage>,
) -> Result<CompressionResult, AgentError> {
```
内部的 `return Ok(history)` 改为 `return Ok(CompressionResult { history, created_timelines: false })`Tier 1 fast trim 和不需要压缩时)。
**Step 3: 修改 LLM summarization pass 部分**
在压缩循环中维护一个 `created_timelines` 标志:
```rust
let mut created_timelines = false;
for pass in 0..self.config.max_passes {
// ...
match self.compress_once(...).await {
Ok(Some(compressed)) => {
current_history = compressed;
created_timelines = true;
}
// ...
}
}
```
最后返回:
```rust
Ok(CompressionResult { history: current_history, created_timelines })
```
**Step 4: 更新所有 compress_if_needed 调用方**
所有 `compress_if_needed(history)` 改为 `compress_if_needed(history).await?.history`。在 `handle_message``/compact` 中还需要用到 `created_timelines`
**Step 5: 编译检查**
```bash
cargo check 2>&1
```
**Step 6: Commit**
```bash
git add src/agent/context_compressor.rs src/session/session.rs
git commit -m "feat(compressor): return CompressionResult with created_timelines flag"
```
---
### Task 4: Session 结构体和持久化
**Files:**
- Modify: `src/session/session.rs:52-74` (Session struct)
- Modify: `src/session/session.rs:76-121` (Session::new)
- Modify: `src/session/session.rs:298-320` (persist_session_meta)
**Step 1: Session struct 加字段**
`pub last_consolidated_at: Option<i64>` 后加:
```rust
pub last_compressed_message_at: Option<i64>,
```
**Step 2: Session::new 初始化**
`last_consolidated_at: None` 后加:
```rust
last_compressed_message_at: None,
```
**Step 3: persist_session_meta 加字段**
`last_consolidated_at: self.last_consolidated_at` 后加:
```rust
last_compressed_message_at: self.last_compressed_message_at,
```
**Step 4: 编译检查**
```bash
cargo check 2>&1
```
**Step 5: Commit**
```bash
git add src/session/session.rs
git commit -m "feat(session): add last_compressed_message_at field to Session and persist"
```
---
### Task 5: Session::from_storage() 增量恢复
**Files:**
- Modify: `src/session/session.rs:124-189` (from_storage)
**Step 1: 重写 from_storage**
替换现有实现为:
```rust
pub async fn from_storage(
id: UnifiedSessionId,
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
storage: StdArc<Storage>,
memory_manager: Arc<crate::memory::MemoryManager>,
) -> Result<Self, AgentError> {
let session_meta = storage.get_session(&id.to_string()).await
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
let mut provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
provider_box.set_storage(storage.clone());
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig {
protect_first_n: 2,
..Default::default()
};
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
compressor.set_session_id(Some(id.to_string()));
// Determine recovery strategy
let mut chat_messages: Vec<ChatMessage> = Vec::new();
if let Some(after_ts) = session_meta.last_compressed_message_at {
// Load last 4 timelines to determine if there are > 3
let timelines = storage
.load_session_timelines(&id.to_string(), 4)
.await
.unwrap_or_default();
let has_more_timelines = timelines.len() > 3;
// Insert hint if more timelines exist
if has_more_timelines {
chat_messages.push(ChatMessage::user(
"[Earlier conversation summaries exist. \
Use `timeline_recall` to search if needed.]"
));
}
// Insert latest 3 timelines as context (reversed: oldest first)
for tl in timelines.iter().take(3).rev() {
chat_messages.push(ChatMessage::user(format!(
"[Previous Context]\n{}", tl.content
)));
}
// Load raw messages after compressed timestamp
let tail = storage
.load_messages_after_timestamp(&id.to_string(), after_ts)
.await
.unwrap_or_default();
let mut tail_msgs: Vec<ChatMessage> = tail.into_iter().map(|m| {
ChatMessage {
id: m.id,
role: m.role,
content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at,
tool_call_id: m.tool_call_id,
tool_name: m.tool_name,
tool_calls: m.tool_calls
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
.filter(|v| !v.is_empty()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
}
}).collect();
repair_tool_call_chains(&mut tail_msgs);
chat_messages.extend(tail_msgs);
} else {
// No prior compression — load all messages (existing behavior)
let messages = storage.load_messages(&id.to_string(), 0).await
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
chat_messages = messages.into_iter().map(|m| {
ChatMessage {
id: m.id,
role: m.role,
content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at,
tool_call_id: m.tool_call_id,
tool_name: m.tool_name,
tool_calls: m.tool_calls
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
.filter(|v| !v.is_empty()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
}
}).collect();
repair_tool_call_chains(&mut chat_messages);
}
// seq_counter from actual DB max
let max_seq = storage
.get_max_message_seq(&id.to_string())
.await
.unwrap_or(0);
let seq_counter = max_seq + 1;
let total_message_count = session_meta.message_count;
Ok(Self {
id: id.clone(),
title: session_meta.title,
created_at: session_meta.created_at,
last_active_at: session_meta.last_active_at,
message_count: session_meta.message_count,
total_message_count,
messages: chat_messages,
seq_counter,
provider_config: provider_config.clone(),
provider: provider.clone(),
tools,
compressor,
storage: Some(storage),
routing_info: session_meta.routing_info.unwrap_or_default(),
last_consolidated_at: session_meta.last_consolidated_at,
last_compressed_message_at: session_meta.last_compressed_message_at,
memory_manager,
})
}
```
**Step 2: 编译检查**
```bash
cargo check 2>&1
```
**Step 3: Commit**
```bash
git add src/session/session.rs
git commit -m "feat(session): incremental recovery from storage using compressed timeline"
```
---
### Task 6: 系统提示词加历史会话提示
**Files:**
- Modify: `src/agent/system_prompt.rs:289-304` (MemorySection)
- Modify: `src/agent/system_prompt.rs:16-23` (PromptContext)
- Modify: `src/agent/system_prompt.rs:343-358` (build_system_prompt free function)
- Modify: `src/session/session.rs:411-426` (build_system_prompt)
**Step 1: PromptContext 加 has_compressed_history 字段**
```rust
pub struct PromptContext<'a> {
pub workspace_dir: &'a Path,
pub model_name: &'a str,
pub tools: &'a ToolRegistry,
pub session_id: Option<&'a str>,
pub memory_context: Option<&'a str>,
pub has_compressed_history: bool,
}
```
**Step 2: 加 HistorySection**
在 MemorySection 后面添加:
```rust
pub struct HistorySection;
impl PromptSection for HistorySection {
fn name(&self) -> &str {
"history"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
if ctx.has_compressed_history {
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文使用 `timeline_recall` 工具搜索。".to_string()
} else {
String::new()
}
}
}
```
**Step 3: 注册到 SystemPromptBuilder::with_defaults**
`with_defaults()` 的 sections vec 中 `Box::new(MemorySection)` 后加 `Box::new(HistorySection)`
**Step 4: 更新 build_system_prompt 签名和调用**
```rust
pub fn build_system_prompt(
workspace_dir: &Path,
model_name: &str,
tools: &ToolRegistry,
session_id: Option<&str>,
memory_context: Option<&str>,
has_compressed_history: bool,
) -> String {
let ctx = PromptContext {
workspace_dir,
model_name,
tools,
session_id,
memory_context,
has_compressed_history,
};
SystemPromptBuilder::with_defaults().build(&ctx)
}
```
**Step 5: 更新 Session::build_system_prompt**
```rust
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
let base_prompt = build_system_prompt(
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
&self.tools,
Some(&self.id.to_string()),
memory_context,
self.last_compressed_message_at.is_some(),
);
if skills_prompt.trim().is_empty() {
base_prompt
} else {
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
}
}
```
**Step 6: 更新所有其他 build_system_prompt 调用方**
搜索 `build_system_prompt(` 的所有调用位置,每个都要加 `false` 参数。主要有 `agent/agent_loop.rs` 中的两个调用。
**Step 7: 编译检查**
```bash
cargo check 2>&1
```
**Step 8: Commit**
```bash
git add src/agent/system_prompt.rs src/session/session.rs src/agent/agent_loop.rs
git commit -m "feat(system-prompt): add history section for archived conversation context"
```
---
### Task 7: handle_message 和 /compact 记录压缩标记
**Files:**
- Modify: `src/session/session.rs:1339-1355` (handle_message 压缩后)
- Modify: `src/session/session.rs:1372-1376` (handle_message 溢出重试)
- Modify: `src/session/session.rs:851-872` (/compact 命令)
**Step 1: handle_message 正常流**
`compress_if_needed(history).await?` 之后line 1346改为
```rust
let result = session_guard.compressor
.compress_if_needed(history)
.await?;
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
if let Err(e) = session_guard.persist_session_meta().await {
tracing::warn!(error = %e, "Failed to persist compressed message marker");
}
}
let mut history = result.history;
```
同时删除后面line 1350-1355单独的 `persist_session_meta` 调用(现在已合入上面的逻辑)。
**Step 2: handle_message 溢出重试流**
```rust
let raw = session_guard.get_history().to_vec();
let result = session_guard.compressor.compress_if_needed(raw).await?;
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
let _ = session_guard.persist_session_meta().await;
}
let mut retry = result.history;
retry.insert(0, ChatMessage::system(system_prompt));
agent.process(retry).await?
```
**Step 3: /compact 命令**
```rust
let result = session_guard.compressor
.compress_if_needed(history)
.await?;
let compressed_count = result.history.len();
if result.created_timelines {
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
let _ = session_guard.persist_session_meta().await;
}
session_guard.clear_history();
for msg in result.history {
session_guard.add_message(msg, false).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
}
```
同时确认 `compress_if_needed` 的 import 正常(已在 scope 中)。
**Step 4: 编译检查**
```bash
cargo check 2>&1
```
**Step 5: Commit**
```bash
git add src/session/session.rs
git commit -m "feat(session): record last_compressed_message_at after compression"
```
---
### Task 8: 全局编译和测试
**Step 1: 全局编译**
```bash
cargo check 2>&1
```
修复所有编译错误,确保全部文件一致。
**Step 2: 运行单元测试**
```bash
cargo test --lib 2>&1
```
**Step 3: 测试通过后 commit**
```bash
git add -A
git commit -m "chore: fix remaining compilation and test issues for incremental recovery"
```
**Step 4: 运行 lint**
```bash
cargo clippy --lib 2>&1 | head -50
```
修复任何 warning。
---
### Task 9: 验证 & 提交设计文档
**Step 1: 最终验证**
```bash
cargo test --lib 2>&1
```
**Step 2: Commit 设计文档**
```bash
git add docs/plans/2026-05-10-incremental-session-recovery-design.md
git commit -m "docs: add incremental session recovery design doc"
```

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@ always: true
---
# About PicoBot
PicoBot 是一个基于 Rust 的个人 AI 助手运行时,包含本地 Gateway、CLI TUI 客户端、飞书渠道、SQLite 会话持久化、长期记忆、定时任务、Skill 系统、MCP 工具接入和子 Agent 委托能力
PicoBot 是一个基于 Rust 的个人 AI 助手支持多渠道飞书、CLI、长记忆、定时任务、Skill 系统等
## 目录索引
@ -13,10 +13,10 @@ PicoBot 是一个基于 Rust 的个人 AI 助手运行时,包含本地 Gateway
| 文件 | 内容 |
|------|------|
| `references/config.md` | 配置字段详解providers、models、agents、gateway、client、channels、memory、mcp、browser |
| `references/db-schema.md` | 数据库表结构sessions、messages、memories、scheduled_jobs、llm_calls、background_tasks |
| `references/architecture.md` | 核心架构数据流、会话系统、上下文压缩、记忆系统、Skill 优先级、MCP、子 Agent |
| `references/faq.md` | 常见问题模型切换、渠道添加、Skill 安装、历史查询、定时任务、MCP 等 |
| `references/config.md` | 配置字段详解providers、models、agents、gateway、memory、channels、mcp |
| `references/db-schema.md` | 数据库表结构sessions、messages、memories、scheduled_jobs、llm_calls |
| `references/architecture.md` | 核心架构数据流、会话系统、上下文压缩、记忆系统、Skill 优先级机制 |
| `references/faq.md` | 常见问题模型切换、渠道添加、Skill 安装、历史查询、定时任务等 |
| `references/commands.md` | 常用命令:编译、启动网关、启动客户端、运行测试 |
| `assets/config.example.json` | config.json 完整示例 |

View File

@ -72,15 +72,5 @@
"timeline_retention_days": 90,
"max_failures_before_degrade": 3
},
"mcp": {
"servers": [],
"tool_timeout_secs": 180
},
"browser": {
"enabled": false,
"webdriver_url": "http://127.0.0.1:9515",
"headless": true,
"chrome_path": null
},
"workspace_dir": "~/.picobot/workspace"
}

View File

@ -17,9 +17,9 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
| `channels` | 外部集成飞书、CLI仅收发消息 |
| `bus` | 异步消息队列,纯队列不路由 |
| `session` | 会话生命周期管理、dialog 操作 |
| `agent` | LLM 调用循环、工具执行、上下文压缩、媒体处理、子 Agent |
| `agent` | LLM 调用循环、工具执行、上下文压缩 |
| `providers` | LLM API 客户端OpenAI 兼容、Anthropic |
| `tools` | Agent 工具bash、文件操作、搜索、HTTP、web、browser、memory、delegate 等) |
| `tools` | Agent 工具bash、文件操作、HTTP、web、get_skill 等) |
| `skills` | Skill 加载、管理和 prompt 构建 |
| `storage` | SQLite 持久化 |
| `scheduler` | Cron 作业调度 |
@ -37,8 +37,6 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
- AgentLoop 无状态,接收 dialog 事件调用 LLM、执行工具
- Providers 是纯 HTTP 客户端,无 bus/session/channel 感知
- Tools 接收原始参数,返回字符串结果
- MCP 工具在 Gateway 初始化时连接服务器、发现工具,并包装成普通 Tool 注册到 ToolRegistry
- 子 Agent 由 `delegate` 工具创建,复用 provider 配置和按需过滤后的工具集;后台任务结果通过 MessageBus 发回原会话
## 关键约束
@ -47,7 +45,6 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
- ChannelManager 持有 MessageBus 和所有 channel
- OutboundDispatcher 通过 ChannelManager 路由出站消息
- Config `.env` 加载使用 `unsafe { env::set_var(...) }`
- `browser` 工具只有在 `browser.enabled=true` 时注册,依赖 Chrome/Chromium 与 WebDriver
## 上下文压缩
@ -195,48 +192,3 @@ LLM 对话上下文接近 token 限制 (默认 128K × 70%) 时自动触发压
| 有压缩历史时 | `HistorySection` 提示 LLM 使用 `timeline_recall` |
| 压缩完成后 | 摘要自动存储为 Timeline 记忆 |
| 空闲时 | 可配置自动 consolidation`idle_consolidation_minutes` |
---
## MCP 工具集成
Gateway 初始化时读取 `config.mcp.servers`
1. 按服务器配置连接 `stdio``sse``streamable-http` 传输
2. 调用 MCP `list_tools`
3. 将每个 MCP tool 包装为 `McpToolWrapper`
4. 注册到当前 session 的 `ToolRegistry`
`/mcp` 斜杠命令会显示 MCP 服务器连接状态和工具列表。
---
## 子 Agent / delegate
`delegate` 工具用于把独立任务交给子 Agent
| 模式 | 行为 |
|------|------|
| `inline` | 当前轮阻塞等待子 Agent 返回 |
| `background` | 后台运行,完成后通过原 channel/chat 通知 |
| `parallel` | 多个子 Agent 并发执行并聚合结果 |
默认工具集是只读工具:`file_read``file_search``content_search``web_fetch``http_request``calculator`。调用时可通过 `allowed_tools` 显式放开其他工具。后台任务会写入 `background_tasks` 表,默认 24 小时后清理。
---
## 当前斜杠命令
| 命令 | 说明 |
|------|------|
| `/new` | 创建新对话 |
| `/sessions` | 列出最近对话 |
| `/switch <dialog_id>` | 切换到指定对话 |
| `/rename <title>` | 重命名当前对话 |
| `/delete` | 删除当前对话 |
| `/compact` | 手动触发上下文压缩 |
| `/info` | 显示当前对话信息 |
| `/dump` | 保存当前对话为 markdown |
| `/?`, `/help` | 显示帮助 |
| `/mcp` | 显示 MCP 状态 |
| `/stop` | 停止当前任务并清空消息队列 |

View File

@ -14,9 +14,8 @@
"client": {}, // 客户端配置
"channels": {}, // 渠道配置
"memory": {}, // 记忆系统配置
"workspace_dir": "", // 工作目录,默认 ~/.picobot/workspace
"mcp": {}, // MCP 服务器配置
"browser": {} // 可选浏览器自动化配置
"workspace_dir": // 工作目录,默认 ~/.picobot/workspace
"mcp": {} // MCP 服务器配置
}
```
@ -58,17 +57,8 @@
| `session_ttl_hours` | int | - | 会话过期小时数 |
| `session_db_path` | string | - | SQLite 数据库路径,默认在 workspace 下 |
| `cleanup_interval_minutes` | int | - | 清理间隔 |
| `max_concurrent_background_tasks` | int | 10 | delegate 后台子任务最大并发数 |
| `scheduler` | object | - | 调度器配置 |
### gateway.scheduler 字段
| 字段 | 类型 | 默认 | 说明 |
|------|------|------|------|
| `enabled` | bool | true | 是否启动调度器并注册 cron 工具 |
| `poll_interval_secs` | int | 60 | 检查到期任务的轮询间隔 |
| `max_concurrent` | int | 1 | 最大并发任务数,当前实现预留 |
## memory 字段
| 字段 | 类型 | 默认 | 说明 |
@ -104,21 +94,8 @@ MCP 服务器单条配置:
| 字段 | 说明 |
|------|------|
| `name` | 服务器名称 |
| `transport` | 传输方式: `stdio`、`sse`、`streamable-http` |
| `command` | 启动命令(stdio 模式) |
| `transport` | 传输方式: `Stdio`、`Sse`、`streamable-http` |
| `command` | 启动命令(Stdio 模式) |
| `args` | 命令参数 |
| `env` | 子进程环境变量 |
| `url` | URLsse / streamable-http 模式) |
| `headers` | HTTP 传输额外请求头 |
| `url` | URLSse / streamable-http 模式) |
| `tool_timeout_secs` | 单独的超时设置 |
## browser 字段
浏览器工具默认关闭,开启后注册 `browser` 工具。依赖 Chrome/Chromium 与 chromedriver/WebDriver。
| 字段 | 类型 | 默认 | 说明 |
|------|------|------|------|
| `enabled` | bool | false | 是否启用浏览器工具 |
| `webdriver_url` | string | http://127.0.0.1:9515 | WebDriver 服务地址 |
| `headless` | bool | true | 是否无头运行 |
| `chrome_path` | string | - | 自定义 Chrome/Chromium 路径 |

View File

@ -36,28 +36,6 @@
| `tool_calls` | TEXT | 工具调用参数 JSON |
| `source` | TEXT | 消息来源(跨会话消息时标记来源 session_id |
| `created_at` | INTEGER | 创建时间unix 秒) |
| `reasoning_content` | TEXT | provider 返回的推理内容(如有) |
## background_tasks 表
delegate 后台子任务表。`session_id` 不使用数据库外键,因为 session 使用软删除,关联关系由应用层维护。
| 字段 | 类型 | 说明 |
|------|------|------|
| `id` | TEXT PK | 后台任务 ID |
| `session_id` | TEXT | 所属会话 |
| `channel` | TEXT | 回传渠道 |
| `chat_id` | TEXT | 回传目标对话 |
| `prompt` | TEXT | 子任务提示 |
| `allowed_tools` | TEXT | 允许工具 JSON |
| `status` | TEXT | pending / running / completed / failed / cancelled |
| `result` | TEXT | 执行结果 |
| `error` | TEXT | 错误信息 |
| `tool_calls_count` | INTEGER | 工具调用次数 |
| `iterations` | INTEGER | Agent 迭代次数 |
| `started_at` | INTEGER | 开始时间 |
| `finished_at` | INTEGER | 结束时间 |
| `created_at` | INTEGER | 创建时间 |
## memories 表

View File

@ -124,51 +124,9 @@
---
## delegate — 子 Agent 委托
## file_read / file_write / file_edit / file_search — 文件操作
创建子 Agent 处理独立任务。
| 参数 | 必填 | 说明 |
|------|------|------|
| `action` | 是 | `run`, `check_task`, `cancel_task`, `list_tasks` |
| `prompt` | run 必填 | 子任务描述 |
| `mode` | 否 | `inline`, `background`, `parallel`,默认 `inline` |
| `allowed_tools` | 否 | 子 Agent 可用工具列表;默认只读工具集 |
| `max_iterations` | 否 | 最大迭代次数,默认 99 |
| `timeout_secs` | 否 | 超时秒数,默认 3600 |
| `tasks` | parallel 必填 | 并行子任务数组 |
| `task_id` | 查询/取消必填 | 后台任务 ID |
默认只读工具集:`file_read``file_search``content_search``web_fetch``http_request``calculator`
---
## browser — 浏览器自动化
仅在 `browser.enabled=true` 时注册。底层使用 WebDriver/Chrome。
| action | 说明 |
|--------|------|
| `open` | 打开 URL |
| `snapshot` | 获取页面结构快照 |
| `click`, `click_at` | 点击元素或坐标 |
| `fill`, `type`, `press` | 输入文本或按键 |
| `get_text`, `get_title`, `get_url` | 读取页面信息 |
| `screenshot` | 截图,可写入文件或返回 base64 |
| `focus`, `hover`, `scroll`, `wait` | 常见交互和等待 |
| `close` | 关闭浏览器会话 |
---
## MCP 工具
如果 `config.mcp.servers` 配置了 MCP 服务器Gateway 启动时会连接服务器、发现工具,并把 MCP 工具包装后注册到 ToolRegistry。使用 `/mcp` 查看当前连接状态和工具列表。
---
## file_read / file_write / file_edit / file_search / content_search — 文件操作和搜索
工作目录内的文件读写编辑、文件名搜索和内容搜索。详细的参数定义见各工具的 parameters_schema。
工作目录内的文件读写编辑和搜索。详细的参数定义见各工具的 parameters_schema。
## bash — 执行命令

View File

@ -72,15 +72,5 @@
"timeline_retention_days": 90,
"max_failures_before_degrade": 3
},
"mcp": {
"servers": [],
"tool_timeout_secs": 180
},
"browser": {
"enabled": false,
"webdriver_url": "http://127.0.0.1:9515",
"headless": true,
"chrome_path": null
},
"workspace_dir": "~/.picobot/workspace"
}

View File

@ -4,8 +4,10 @@ use crate::agent::system_prompt::build_system_prompt;
use crate::bus::message::ContentBlock;
use crate::bus::{ChatMessage, MediaRef};
use crate::config::LLMProviderConfig;
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
use crate::observability::{
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
};
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::tools::ToolRegistry;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
@ -226,7 +228,6 @@ pub struct AgentLoop {
pub struct AgentProcessResult {
pub final_response: ChatMessage,
pub emitted_messages: Vec<ChatMessage>,
pub total_tokens: Option<u32>,
}
impl AgentLoop {
@ -254,10 +255,7 @@ impl AgentLoop {
}
/// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone();
@ -280,13 +278,7 @@ impl AgentLoop {
}
/// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(
provider: Arc<dyn LLMProvider>,
max_iterations: usize,
model_name: String,
workspace_dir: PathBuf,
input_types: Vec<String>,
) -> Self {
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
Self {
provider,
tools: Arc::new(ToolRegistry::new()),
@ -348,9 +340,8 @@ impl AgentLoop {
}
/// Preemptive trim: truncate old tool results in-place when history is
/// approaching the context window limit. Old results (outside of `keep_recent`
/// zone) are replaced with a short placeholder; recent results are truncated
/// to `max_chars`.
/// approaching the context window limit. Only trims tool messages with
/// content > TRIM_CHARS, preserving the most recent KEEP messages.
fn preemptive_trim_old_tool_results(
&self,
messages: &mut [ChatMessage],
@ -367,11 +358,11 @@ impl AgentLoop {
if messages[i].content.len() <= max_chars {
continue;
}
let tool_name = messages[i].tool_name.as_deref().unwrap_or("unknown");
let chars = messages[i].content.len();
let removed = messages[i].content.len() - max_chars;
messages[i].content = format!(
"[Tool output ({}) — {} chars, omitted from context]",
tool_name, chars
"{}...\n\n[Output truncated - {} characters removed]",
&messages[i].content[..messages[i].content.ceil_char_boundary(max_chars)],
removed
);
modified += 1;
}
@ -386,12 +377,7 @@ impl AgentLoop {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
build_content_blocks(
&m.content,
&m.media_refs,
&self.input_types,
&self.media_registry,
)
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
};
Message {
@ -411,28 +397,14 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
// Build and inject system prompt if not present
let has_system = messages.first().is_some_and(|m| m.role == "system");
if !has_system {
let system_prompt = build_system_prompt(
&self.workspace_dir,
&self.model_name,
&self.tools,
None,
None,
false,
);
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
#[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt));
@ -441,7 +413,6 @@ impl AgentLoop {
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
let mut accumulated_tokens: u32 = 0;
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
@ -453,7 +424,9 @@ impl AgentLoop {
let estimated = estimate_tokens(&messages);
let danger = (self.context_window as f64 * 0.8) as usize;
if estimated > danger {
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
let trimmed = self.preemptive_trim_old_tool_results(
&mut messages, 2000, 4,
);
if trimmed > 0 {
#[cfg(debug_assertions)]
tracing::debug!(
@ -487,12 +460,11 @@ impl AgentLoop {
};
// Call LLM
let response = (*self.provider).chat(request).await.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
accumulated_tokens += response.usage.total_tokens;
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
#[cfg(debug_assertions)]
tracing::debug!(
@ -510,15 +482,12 @@ impl AgentLoop {
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
total_tokens: Some(accumulated_tokens),
});
}
// Execute tool calls — log and notify immediately
{
let tools_info: Vec<String> = response
.tool_calls
.iter()
let tools_info: Vec<String> = response.tool_calls.iter()
.map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
let s = format!("{}:{}", tc.name, args);
@ -547,9 +516,7 @@ impl AgentLoop {
// Log function call with name and arguments
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
@ -589,11 +556,7 @@ impl AgentLoop {
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
}
// Max iterations reached - ask LLM for a summary based on completed work
@ -602,7 +565,7 @@ impl AgentLoop {
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far.",
Please provide your best answer based on the work completed so far."
);
messages.push(summary_request);
@ -621,32 +584,24 @@ impl AgentLoop {
match (*self.provider).chat(request).await {
Ok(response) => {
accumulated_tokens += response.usage.total_tokens;
let mut assistant_message = ChatMessage::assistant(response.content);
assistant_message.reasoning_content = response.reasoning_content;
emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
total_tokens: Some(accumulated_tokens),
})
}
Err(e) => {
// Fallback if summary call fails
tracing::error!(error = %e, "Failed to get summary from LLM");
let final_message = ChatMessage::assistant(format!(
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
self.max_iterations
));
let final_message = ChatMessage::assistant(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
);
emitted_messages.push(final_message.clone());
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
total_tokens: if accumulated_tokens > 0 {
Some(accumulated_tokens)
} else {
None
},
})
}
}
@ -734,7 +689,10 @@ impl AgentLoop {
}
// Apply duration
ToolExecutionOutcome { duration, ..result }
ToolExecutionOutcome {
duration,
..result
}
}
/// Internal tool execution without event tracking.
@ -756,12 +714,18 @@ impl AgentLoop {
ToolExecutionOutcome::success(result.output)
} else {
let error = result.error.unwrap_or_default();
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
)
}
}
Err(e) => {
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
ToolExecutionOutcome::failure(
format!("Error: {}", e),
Some(e.to_string()),
)
}
}
}
@ -849,14 +813,8 @@ mod tests {
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].id,
"call_1"
);
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].name,
"calculator"
);
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
}
}

View File

@ -68,10 +68,6 @@ pub struct ContextCompressor {
memory: Arc<MemoryManager>,
/// Current session ID for timeline memory writes.
session_id: Option<String>,
/// Message count sent in the last LLM call (used to split known/new history).
last_sent_message_count: Option<usize>,
/// Real total_tokens from the last API response.
last_api_total_tokens: Option<u32>,
}
/// Result of context compression.
@ -80,15 +76,6 @@ pub struct CompressionResult {
pub created_timelines: bool,
}
/// Token budget state snapshot for diagnostics.
pub struct TokenInfo {
pub context_window: usize,
pub threshold: usize,
pub estimated_tokens: usize,
pub last_api_tokens: Option<u32>,
pub cache_active: bool,
}
impl ContextCompressor {
/// Create a new compressor with the given provider, context window size, and memory manager.
pub fn new(
@ -103,8 +90,6 @@ impl ContextCompressor {
provider,
memory,
session_id: None,
last_sent_message_count: None,
last_api_total_tokens: None,
}
}
@ -122,8 +107,6 @@ impl ContextCompressor {
provider,
memory,
session_id: None,
last_sent_message_count: None,
last_api_total_tokens: None,
}
}
@ -137,91 +120,39 @@ impl ContextCompressor {
self.context_window = window;
}
/// Record the API's reported token usage from the last completed turn.
/// `msg_count`: number of messages sent to LLM in that call.
/// `tokens`: `total_tokens` from the API response.
pub fn set_last_api_info(&mut self, msg_count: usize, tokens: Option<u32>) {
self.last_sent_message_count = Some(msg_count);
self.last_api_total_tokens = tokens;
}
/// Invalidate the cached API token info — called after compression modifies messages.
fn invalidate_token_cache(&mut self) {
self.last_sent_message_count = None;
self.last_api_total_tokens = None;
}
/// Hybrid token estimation: API-reported tokens for known history +
/// char/4 estimate for new messages since last API call.
fn token_estimate_with_history(&self, messages: &[ChatMessage]) -> usize {
match (self.last_api_total_tokens, self.last_sent_message_count) {
(Some(known), Some(known_count)) if messages.len() > known_count => {
let delta = &messages[known_count..];
known as usize + estimate_tokens(delta)
}
(Some(known), _) => known as usize,
_ => estimate_tokens(messages),
}
}
/// Always true — memory is always available (memory system is always on).
pub fn has_memory(&self) -> bool {
true
}
/// Get a snapshot of the current token budget state for diagnostics.
pub fn token_info(&self, messages: &[ChatMessage]) -> TokenInfo {
TokenInfo {
context_window: self.context_window,
threshold: self.threshold(),
estimated_tokens: self.token_estimate_with_history(messages),
last_api_tokens: self.last_api_total_tokens,
cache_active: self.last_api_total_tokens.is_some(),
}
}
/// Get the compression threshold in tokens.
pub fn threshold(&self) -> usize {
(self.context_window as f64 * self.threshold_ratio) as usize
}
/// Fast-path: trim oversized tool results without LLM call.
/// Old tool results (outside of `protect_tail` zone) are replaced with a
/// concise placeholder; recent results are truncated to `tool_result_trim_chars`.
/// Returns the number of messages modified.
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage], protect_tail: usize) -> usize {
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
let limit = self.config.tool_result_trim_chars;
let tail_start = messages.len().saturating_sub(protect_tail);
let mut modified = 0;
for (i, msg) in messages.iter_mut().enumerate() {
if msg.role != "tool" || msg.content.len() <= limit {
continue;
}
if i < tail_start {
let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
let chars = msg.content.len();
msg.content = format!(
"[Tool output ({}) — {} chars, omitted from context]",
tool_name, chars
);
} else {
for msg in messages.iter_mut() {
if msg.role == "tool" && msg.content.len() > limit {
let removed = msg.content.len() - limit;
msg.content = format!(
"{}...\n\n[Output truncated - {} characters removed]",
&msg.content[..msg.content.ceil_char_boundary(limit)],
removed
);
modified += 1;
}
modified += 1;
}
modified
}
/// Repair tool call chains after compression.
/// Phase 1: remove orphan tool results whose declaring tool_calls are missing.
/// Phase 2: strip tool_calls from assistants whose results are missing.
/// Remove orphan tool results whose declaring tool_calls have been compressed away.
/// Scans for tool messages with no preceding assistant tool_call, and removes them.
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut i = 0;
@ -234,58 +165,23 @@ impl ContextCompressor {
}
} else if messages[i].role == "tool"
&& let Some(ref tid) = messages[i].tool_call_id
&& !declared.contains(tid.as_str())
{
messages.remove(i);
continue;
}
&& !declared.contains(tid.as_str()) {
messages.remove(i);
continue;
}
i += 1;
}
let broken: Vec<usize> = messages
.iter()
.enumerate()
.filter_map(|(idx, msg)| {
if msg.role == "assistant"
&& let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty()
{
let all_present = tcs.iter().all(|tc| {
messages.iter().any(|m| {
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
})
});
if !all_present { Some(idx) } else { None }
} else {
None
}
})
.collect();
for idx in broken {
let msg = &mut messages[idx];
let tcs = msg.tool_calls.take().unwrap_or_default();
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results are no longer available]",
msg.content,
names.join(", ")
);
}
}
/// Main entry point - compresses history if over threshold.
pub async fn compress_if_needed(
&mut self,
&self,
mut history: Vec<ChatMessage>,
) -> Result<CompressionResult, AgentError> {
// Check if compression is needed
let tokens = self.token_estimate_with_history(&history);
let tokens = estimate_tokens(&history);
if tokens <= self.threshold() {
return Ok(CompressionResult {
history,
created_timelines: false,
});
return Ok(CompressionResult { history, created_timelines: false });
}
#[cfg(debug_assertions)]
@ -297,8 +193,8 @@ impl ContextCompressor {
);
// Fast trim pass first — modify history in place
let trimmed = self.fast_trim_tool_results(&mut history, self.config.protect_last_n);
let tokens_after = self.token_estimate_with_history(&history);
let trimmed = self.fast_trim_tool_results(&mut history);
let tokens_after = estimate_tokens(&history);
if trimmed > 0 {
#[cfg(debug_assertions)]
tracing::debug!(
@ -308,24 +204,24 @@ impl ContextCompressor {
);
}
if tokens_after <= self.threshold() {
self.invalidate_token_cache();
return Ok(CompressionResult {
history,
created_timelines: false,
});
return Ok(CompressionResult { history, created_timelines: false });
}
// LLM summarization pass
let mut current_history = history;
let mut created_timelines = false;
for pass in 0..self.config.max_passes {
let tokens = self.token_estimate_with_history(&current_history);
let tokens = estimate_tokens(&current_history);
if tokens <= self.threshold() {
break;
}
#[cfg(debug_assertions)]
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
tracing::debug!(
pass = pass + 1,
tokens = tokens,
"Compression pass"
);
match self.compress_once(&current_history).await {
Ok(Some(compressed)) => {
@ -345,52 +241,15 @@ impl ContextCompressor {
// Hard safety net: if still dangerously high after all passes,
// fall back to head+tail truncation so the LLM call doesn't overflow.
let final_tokens = self.token_estimate_with_history(&current_history);
let final_tokens = estimate_tokens(&current_history);
let danger_threshold = (self.context_window as f64 * 0.9) as usize;
if final_tokens > danger_threshold
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n
{
let mut tail_start = current_history.len() - self.config.protect_last_n;
// Align tail_start backwards to preserve tool chain boundaries:
// if an assistant with tool_calls has results spanning the cut,
// include the assistant in the tail.
if tail_start > 0 && tail_start < current_history.len() {
let mut scan = tail_start.saturating_sub(1);
loop {
let m = &current_history[scan];
if m.role == "assistant" {
if let Some(tcs) = &m.tool_calls
&& !tcs.is_empty()
{
let has_post = current_history[scan + 1..]
.iter()
.filter(|r| r.role == "tool")
.any(|r| {
tcs.iter()
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
});
if has_post {
tail_start = scan;
}
}
break;
}
if scan == 0 {
break;
}
scan -= 1;
}
}
// Skip orphan tool messages at the new head-tail boundary
while tail_start < current_history.len() && current_history[tail_start].role == "tool" {
tail_start += 1;
}
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
let tail_start = current_history.len() - self.config.protect_last_n;
let tail: Vec<_> = current_history[tail_start..].to_vec();
let dropped = current_history.len() - self.config.protect_first_n - tail.len();
let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n;
let mut truncated = head;
truncated.push(ChatMessage::user(format!(
@ -400,26 +259,6 @@ impl ContextCompressor {
)));
truncated.extend(tail);
// Strip tool_calls from any assistant in the head whose results
// were dropped (previously in the middle section).
for msg in &mut truncated[..self.config.protect_first_n] {
if msg.role == "assistant" {
if let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty()
{
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
msg.content,
names.join(", ")
);
msg.tool_calls = None;
}
}
}
Self::repair_tool_pairs(&mut truncated);
tracing::warn!(
final_tokens = final_tokens,
danger = danger_threshold,
@ -430,21 +269,14 @@ impl ContextCompressor {
current_history = truncated;
}
if created_timelines {
self.invalidate_token_cache();
}
#[cfg(debug_assertions)]
tracing::debug!(
final_tokens = self.token_estimate_with_history(&current_history),
final_tokens = estimate_tokens(&current_history),
final_msg_count = current_history.len(),
"Context compression completed"
);
Ok(CompressionResult {
history: current_history,
created_timelines,
})
Ok(CompressionResult { history: current_history, created_timelines })
}
/// Try to extract the actual context token limit from an LLM error message.
@ -467,21 +299,20 @@ impl ContextCompressor {
// Look for a number in the vicinity (up to 10 chars after marker)
if let Some(num_str) = find_number_nearby(after, 50)
&& let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n)
{
return Some(n);
}
&& (1024..=10_000_000).contains(&n) {
return Some(n);
}
}
}
// Also try: "XXXX token context" or "XXXX limit"
if let Some(num_str) = find_number_nearby(&lower, lower.len())
&& let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n)
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
{
return Some(n);
}
&& (1024..=10_000_000).contains(&n)
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
{
return Some(n);
}
None
}
@ -530,26 +361,19 @@ impl ContextCompressor {
// Persist compressed summary as timeline memory entry
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
let timeline_content = format!(
"[{}] Compressed {} conversation segments:\n{}",
ts,
between.len(),
summary
);
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
ts, between.len(), summary);
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
let mm = self.memory.clone();
let sid = self.session_id.clone();
tokio::spawn(async move {
if let Err(e) = mm
.store(
&key,
&timeline_content,
crate::memory::MemoryCategory::Timeline,
sid.as_deref(),
Some(0.3),
)
.await
{
if let Err(e) = mm.store(
&key,
&timeline_content,
crate::memory::MemoryCategory::Timeline,
sid.as_deref(),
Some(0.3),
).await {
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
}
});
@ -580,7 +404,10 @@ impl ContextCompressor {
}
/// Summarize a segment of messages using LLM.
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
async fn summarize_segment(
&self,
messages: &[ChatMessage],
) -> Result<String, AgentError> {
if messages.is_empty() {
return Ok(String::new());
}
@ -594,8 +421,7 @@ impl ContextCompressor {
"tool" => "Tool",
_ => m.role.as_str(),
};
let name = m
.tool_name
let name = m.tool_name
.as_ref()
.map(|n| format!(" ({})", n))
.unwrap_or_default();
@ -640,10 +466,7 @@ Be concise, aim for {} characters or less.
);
let request = ChatCompletionRequest {
messages: vec![
Message::system("You are a helpful assistant."),
Message::user(&prompt),
],
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
temperature: Some(0.3),
max_tokens: Some(1000),
tools: None,
@ -715,23 +538,13 @@ mod tests {
content: "[summarized]".into(),
reasoning_content: None,
tool_calls: vec![],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
})
}
fn ptype(&self) -> &str {
"mock"
}
fn name(&self) -> &str {
"mock"
}
fn model_id(&self) -> &str {
"mock"
}
fn ptype(&self) -> &str { "mock" }
fn name(&self) -> &str { "mock" }
fn model_id(&self) -> &str { "mock" }
}
fn mock_summarizer() -> Arc<dyn LLMProvider> {
@ -743,13 +556,11 @@ mod tests {
MM.get_or_init(|| {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let tmp = std::env::temp_dir()
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
})
})
.clone()
}).clone()
}
#[test]
@ -765,11 +576,7 @@ mod tests {
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
// raw = 19, with 1.2x = ~23
assert!(
tokens > 18 && tokens < 30,
"Expected ~23 tokens, got {}",
tokens
);
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
}
#[test]
@ -778,15 +585,14 @@ mod tests {
tool_result_trim_chars: 50,
..Default::default()
};
let compressor =
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let mut messages = vec![
ChatMessage::user("Hello"),
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
];
let modified = compressor.fast_trim_tool_results(&mut messages, 2);
let modified = compressor.fast_trim_tool_results(&mut messages);
assert_eq!(modified, 1);
assert!(messages[1].content.len() < 100);
}
@ -813,18 +619,14 @@ mod tests {
max_passes: 0,
..Default::default()
};
let mut compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
let messages = vec![
ChatMessage::user("Hi"),
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
];
let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
let result = compressor.compress_if_needed(messages).await.unwrap().history;
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
assert!(
@ -848,19 +650,18 @@ mod tests {
// - B2B (L275): last user message lost when it is the final history message
//
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
let tmp =
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig {
tool_result_trim_chars: 2000,
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
protect_last_n: 2,
max_passes: 1,
..Default::default()
};
let mut compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
// History: 9 messages, last message is user Q4.
// user_indices (skip 1) = [1, 3, 6, 8]
@ -869,43 +670,25 @@ mod tests {
let big = "x".repeat(3000);
let messages = vec![
ChatMessage::system("You are a helper."), // 0: protected
ChatMessage::user("Q1"), // 1: first user
ChatMessage::tool("t1", "bash", &big), // 2
ChatMessage::user("Q2"), // 3
ChatMessage::assistant("thinking"), // 4
ChatMessage::tool("t2", "bash", &big), // 5
ChatMessage::user("Q3"), // 6
ChatMessage::assistant("thinking"), // 7
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
ChatMessage::user("Q1"), // 1: first user
ChatMessage::tool("t1", "bash", &big), // 2
ChatMessage::user("Q2"), // 3
ChatMessage::assistant("thinking"), // 4
ChatMessage::tool("t2", "bash", &big), // 5
ChatMessage::user("Q3"), // 6
ChatMessage::assistant("thinking"), // 7
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
];
let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
let result = compressor.compress_if_needed(messages).await.unwrap().history;
// B2A: "Q1" must appear exactly once
let q1_count = result
.iter()
.filter(|m| m.role == "user" && m.content == "Q1")
.count();
assert_eq!(
q1_count, 1,
"Q1 should appear exactly once, got {}",
q1_count
);
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
// B2B: "Q4" must NOT be lost
let q4_count = result
.iter()
.filter(|m| m.role == "user" && m.content == "Q4")
.count();
assert_eq!(
q4_count, 1,
"Q4 should appear exactly once (not lost), got {}",
q4_count
);
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
let _ = std::fs::remove_file(&tmp);
}
@ -919,16 +702,16 @@ mod tests {
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig {
tool_result_trim_chars: 500, // trim reduces but not enough
tool_result_trim_chars: 500, // trim reduces but not enough
protect_first_n: 1,
protect_last_n: 2,
max_passes: 0, // no LLM summarization → will exceed danger
max_passes: 0, // no LLM summarization → will exceed danger
..Default::default()
};
// context_window=100, danger_threshold=90.
// Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
// Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
let mut compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
let big = "x".repeat(3000);
let messages = vec![
@ -941,23 +724,13 @@ mod tests {
ChatMessage::tool("t3", "bash", &big),
];
let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
let result = compressor.compress_if_needed(messages).await.unwrap().history;
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
assert!(
result.len() < 7,
"expected truncation reduction, got {} messages",
result.len()
);
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
// Truncation notice should be present
let has_notice = result
.iter()
.any(|m| m.content.contains("Context truncation"));
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
assert!(has_notice, "hard truncation notice missing");
let _ = std::fs::remove_file(&tmp);
@ -972,9 +745,9 @@ mod tests {
let mut messages = vec![
ChatMessage::user("Q1"),
ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
ChatMessage::assistant("done"), // declares tc2
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
ChatMessage::assistant("done"), // declares tc2
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
];
// Set tool_call_id on tool messages and tool_calls on assistant
messages[2].tool_call_id = Some("tc1".into());
@ -989,16 +762,8 @@ mod tests {
// orphan should be removed; legitimate should stay
assert_eq!(messages.len(), 4);
assert!(
messages
.iter()
.all(|m| m.tool_call_id != Some("tc1".into()))
);
assert!(
messages
.iter()
.any(|m| m.tool_call_id == Some("tc2".into()))
);
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
}
#[test]

View File

@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
}
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{Engine as _, engine::general_purpose::STANDARD};
use base64::{engine::general_purpose::STANDARD, Engine as _};
let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new();

View File

@ -1,16 +1,8 @@
pub mod agent_loop;
pub mod context_compressor;
pub mod media_handler;
pub mod sub_agent;
pub mod system_prompt;
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use context_compressor::{ContextCompressor, estimate_tokens};
pub use sub_agent::{
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
TaskNotification, TaskStatus,
};
pub use system_prompt::{
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
build_system_prompt,
};
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};

View File

@ -1,623 +0,0 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::agent::AgentError;
use crate::agent::AgentLoop;
use crate::agent::system_prompt::build_sub_agent_system_prompt;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{LLMProvider, create_provider};
use crate::skills::SkillsLoader;
use crate::tools::ToolRegistry;
tokio::task_local! {
pub(crate) static DELEGATE_CONTEXT: DelegateContext;
}
/// Read the delegate context from the current task. Returns an error if not set.
pub fn get_delegate_context() -> Result<DelegateContext, String> {
DELEGATE_CONTEXT
.try_with(|ctx| ctx.clone())
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
}
const DEFAULT_MAX_ITERATIONS: usize = 99;
const DEFAULT_TIMEOUT_SECS: u64 = 3600;
const MAX_INLINE_RESULT_CHARS: usize = 8000;
const DEFAULT_READONLY_TOOLS: &[&str] = &[
"file_read",
"file_search",
"content_search",
"web_fetch",
"http_request",
"calculator",
];
#[derive(Debug, Clone)]
pub struct SubAgentConfig {
pub prompt: String,
pub mode: ExecutionMode,
pub allowed_tools: Option<Vec<String>>,
pub max_iterations: Option<usize>,
pub timeout_secs: Option<u64>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionMode {
Inline,
Background,
Parallel,
}
#[derive(Debug, Clone)]
pub struct SubAgentResult {
pub task_id: String,
pub content: String,
pub content_truncated: bool,
pub status: TaskStatus,
pub tool_calls_count: usize,
pub iterations: usize,
pub duration_ms: u64,
}
#[derive(Debug, Clone)]
pub enum TaskStatus {
Completed,
Failed(String),
Cancelled,
TimedOut,
}
#[derive(Debug, Clone)]
pub struct TaskNotification {
pub task_id: String,
pub session_id: String,
pub channel: String,
pub chat_id: String,
pub status: TaskStatus,
pub result_summary: String,
}
#[derive(Debug, Clone)]
pub struct DelegateContext {
pub session_id: String,
pub channel: String,
pub chat_id: String,
}
#[derive(Debug)]
pub enum SubAgentError {
TooManyTasks(usize),
ProviderCreation(String),
Storage(String),
Other(String),
}
impl std::fmt::Display for SubAgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
Self::Storage(e) => write!(f, "storage error: {}", e),
Self::Other(e) => write!(f, "{}", e),
}
}
}
impl std::error::Error for SubAgentError {}
pub struct SubAgentManager {
provider_config: LLMProviderConfig,
full_tools: Arc<ToolRegistry>,
storage: Option<Arc<crate::storage::Storage>>,
active_tasks: Arc<DashMap<String, CancellationToken>>,
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
max_concurrent_background_tasks: usize,
skills_loader: Option<Arc<SkillsLoader>>,
}
impl SubAgentManager {
pub fn new(
provider_config: LLMProviderConfig,
full_tools: Arc<ToolRegistry>,
storage: Option<Arc<crate::storage::Storage>>,
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
max_concurrent_background_tasks: usize,
skills_loader: Option<Arc<SkillsLoader>>,
) -> Self {
Self {
provider_config,
full_tools,
storage,
active_tasks: Arc::new(DashMap::new()),
notify_tx,
max_concurrent_background_tasks,
skills_loader,
}
}
pub fn filter_tools(&self, allowed: &Option<Vec<String>>) -> Arc<ToolRegistry> {
let allowed_set: HashSet<&str> = match allowed {
Some(list) => list.iter().map(|s| s.as_str()).collect(),
None => DEFAULT_READONLY_TOOLS.iter().copied().collect(),
};
let filtered = ToolRegistry::new();
for (name, tool) in self.full_tools.iter() {
if allowed_set.contains(name.as_str()) && name != "delegate" {
filtered.register_raw(name, tool);
}
}
Arc::new(filtered)
}
fn get_skills_prompt(&self, tools: &ToolRegistry) -> Option<String> {
let has_get_skill = tools.iter().iter().any(|(name, _)| name == "get_skill");
if has_get_skill {
if let Some(ref loader) = self.skills_loader {
let prompt = loader.build_skills_prompt();
if !prompt.is_empty() {
return Some(prompt);
}
}
}
None
}
pub fn build_sub_agent(
&self,
config: &SubAgentConfig,
tools: Arc<ToolRegistry>,
) -> Result<AgentLoop, AgentError> {
let mut provider = create_provider(self.provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
if let Some(ref s) = self.storage {
provider.set_storage(s.clone());
}
let provider: Arc<dyn LLMProvider> = Arc::from(provider);
let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
let workspace_dir = self.provider_config.workspace_dir.clone();
let model_name = self.provider_config.model_id.clone();
let input_types = self.provider_config.input_types.clone();
let agent = AgentLoop::with_provider_and_tools(
provider,
tools,
max_iterations,
model_name,
workspace_dir,
input_types,
)
.with_context_window(self.provider_config.token_limit);
Ok(agent)
}
pub async fn run_inline(
&self,
config: SubAgentConfig,
) -> Result<SubAgentResult, SubAgentError> {
let task_id = generate_task_id();
let tools = self.filter_tools(&config.allowed_tools);
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none()
|| config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt(
&config.prompt,
&timeout_human,
&tools,
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
skills_prompt,
http_get_only,
);
let agent = self
.build_sub_agent(&config, tools)
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
let history = vec![
ChatMessage::system(system_prompt),
ChatMessage::user(&config.prompt),
];
let start = Instant::now();
let result = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
agent.process(history),
)
.await;
let duration_ms = start.elapsed().as_millis() as u64;
match result {
Ok(Ok(agent_result)) => {
let (content, truncated) =
truncate_sub_agent_result(&agent_result.final_response.content);
let tool_calls_count = agent_result
.emitted_messages
.iter()
.filter(|m| m.tool_calls.is_some())
.count();
let iterations = agent_result
.emitted_messages
.iter()
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
.count();
Ok(SubAgentResult {
task_id,
content,
content_truncated: truncated,
status: TaskStatus::Completed,
tool_calls_count,
iterations,
duration_ms,
})
}
Ok(Err(e)) => Ok(SubAgentResult {
task_id,
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed(e.to_string()),
tool_calls_count: 0,
iterations: 0,
duration_ms,
}),
Err(_elapsed) => Ok(SubAgentResult {
task_id,
content: String::new(),
content_truncated: false,
status: TaskStatus::TimedOut,
tool_calls_count: 0,
iterations: 0,
duration_ms,
}),
}
}
pub async fn run_parallel(
&self,
configs: Vec<SubAgentConfig>,
) -> Result<Vec<SubAgentResult>, SubAgentError> {
let futures: Vec<_> = configs
.into_iter()
.map(|config| {
let mgr = self; // &self borrow, all tasks share the same manager
async move { mgr.run_inline(config).await }
})
.collect();
let results = futures_util::future::join_all(futures).await;
Ok(results.into_iter().collect::<Result<Vec<_>, _>>()?)
}
pub async fn run_background(
&self,
config: SubAgentConfig,
ctx: DelegateContext,
) -> Result<String, SubAgentError> {
if self.active_tasks.len() >= self.max_concurrent_background_tasks {
return Err(SubAgentError::TooManyTasks(
self.max_concurrent_background_tasks,
));
}
let task_id = generate_task_id();
let cancel_token = CancellationToken::new();
self.active_tasks
.insert(task_id.clone(), cancel_token.clone());
// Write DB: pending
if let Some(ref storage) = self.storage {
let allowed_tools_json = config
.allowed_tools
.as_ref()
.and_then(|v| serde_json::to_string(v).ok());
let record = crate::storage::BackgroundTask {
id: task_id.clone(),
session_id: ctx.session_id.clone(),
channel: ctx.channel.clone(),
chat_id: ctx.chat_id.clone(),
prompt: config.prompt.clone(),
allowed_tools: allowed_tools_json,
status: "pending".to_string(),
result: None,
error: None,
tool_calls_count: 0,
iterations: 0,
started_at: None,
finished_at: None,
created_at: chrono::Utc::now().timestamp_millis(),
};
storage
.create_background_task(&record)
.await
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
}
let tools = self.filter_tools(&config.allowed_tools);
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none()
|| config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt(
&config.prompt,
&timeout_human,
&tools,
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
skills_prompt,
http_get_only,
);
let provider_config = self.provider_config.clone();
let storage = self.storage.clone();
let notify_tx = self.notify_tx.clone();
let active_tasks = Arc::clone(&self.active_tasks);
let tid = task_id.clone();
let sess_id = ctx.session_id.clone();
let ch = ctx.channel.clone();
let cid = ctx.chat_id.clone();
let prompt = config.prompt.clone();
tokio::spawn(async move {
let started_at = chrono::Utc::now().timestamp_millis();
// Update DB: running
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid,
"running",
None,
None,
Some(started_at),
None,
)
.await;
}
let mut provider = create_provider(provider_config.clone()).ok();
if let Some(ref mut p) = provider {
if let Some(ref s) = storage {
p.set_storage(s.clone());
}
}
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
let result = match provider_result {
Some(provider) => {
let agent = AgentLoop::with_provider_and_tools(
provider,
tools,
DEFAULT_MAX_ITERATIONS,
provider_config.model_id.clone(),
provider_config.workspace_dir.clone(),
provider_config.input_types.clone(),
)
.with_context_window(provider_config.token_limit);
let history = vec![
ChatMessage::system(system_prompt),
ChatMessage::user(&prompt),
];
tokio::select! {
r = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
agent.process(history),
) => {
match r {
Ok(Ok(agent_result)) => SubAgentResult {
task_id: tid.clone(),
content: agent_result.final_response.content,
content_truncated: false,
status: TaskStatus::Completed,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
Ok(Err(e)) => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed(e.to_string()),
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
Err(_) => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::TimedOut,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
}
}
_ = cancel_token.cancelled() => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Cancelled,
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
}
}
None => SubAgentResult {
task_id: tid.clone(),
content: String::new(),
content_truncated: false,
status: TaskStatus::Failed("provider creation failed".into()),
tool_calls_count: 0,
iterations: 0,
duration_ms: 0,
},
};
let finished_at = chrono::Utc::now().timestamp_millis();
let duration_ms = (finished_at - started_at) as u64;
let (status_str, error_val) = match &result.status {
TaskStatus::Completed => ("completed".to_string(), None),
TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())),
TaskStatus::Cancelled => ("cancelled".to_string(), None),
TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())),
};
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid,
&status_str,
Some(&result.content),
error_val.as_deref(),
Some(started_at),
Some(finished_at),
)
.await;
}
let _ = notify_tx.send(TaskNotification {
task_id: tid.clone(),
session_id: sess_id,
channel: ch,
chat_id: cid,
status: result.status,
result_summary: summarize_for_notification(&result.content, duration_ms),
});
active_tasks.remove(&tid);
});
Ok(task_id)
}
pub async fn cancel_task(&self, task_id: &str) -> Result<bool, SubAgentError> {
if let Some((_, token)) = self.active_tasks.remove(task_id) {
token.cancel();
if let Some(ref s) = self.storage {
s.update_background_task_status(
task_id,
"cancelled",
None,
None,
None,
Some(chrono::Utc::now().timestamp_millis()),
)
.await
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
}
Ok(true)
} else if let Some(ref s) = self.storage {
match s.get_background_task(task_id).await {
Ok(task) => match task.status.as_str() {
"pending" | "running" => {
tracing::warn!(task_id, "task in DB but not in active_tasks");
Ok(false)
}
_ => Ok(false),
},
Err(_) => Ok(false),
}
} else {
Ok(false)
}
}
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.get_background_task(task_id).await.ok()
} else {
None
}
}
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.list_background_tasks(session_id)
.await
.unwrap_or_default()
} else {
vec![]
}
}
pub async fn cancel_by_session(&self, session_id: &str) {
// Cancel all running tasks for a session by checking DB
if let Some(ref s) = self.storage {
if let Ok(tasks) = s.list_background_tasks(session_id).await {
for task in &tasks {
if task.status == "pending" || task.status == "running" {
let _ = self.cancel_task(&task.id).await;
}
}
}
}
}
pub fn active_task_count(&self) -> usize {
self.active_tasks.len()
}
}
fn generate_task_id() -> String {
Uuid::new_v4().to_string()[..8].to_string()
}
fn format_duration(seconds: u64) -> String {
if seconds < 60 {
format!("{}s", seconds)
} else if seconds < 3600 {
format!("{}m", seconds / 60)
} else {
format!("{}h", seconds / 3600)
}
}
fn truncate_sub_agent_result(content: &str) -> (String, bool) {
if content.len() <= MAX_INLINE_RESULT_CHARS {
(content.to_string(), false)
} else {
let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS);
(
format!(
"{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]",
&content[..truncate_at],
content.len()
),
true,
)
}
}
fn summarize_for_notification(content: &str, _duration_ms: u64) -> String {
const MAX_SUMMARY_BYTES: usize = 500;
if content.len() <= MAX_SUMMARY_BYTES {
content.to_string()
} else {
let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES);
format!("{}...", &content[..truncate_at])
}
}

View File

@ -3,7 +3,11 @@
//! This module provides a modular framework for building system prompts
//! using the SystemPromptBuilder pattern.
//!
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic
//!
//! Configuration files loaded from ~/.picobot/:
//! - AGENTS.md — agent identity and behavior
//! - USER.md — user preferences and profile
use crate::tools::ToolRegistry;
use std::path::Path;
@ -51,35 +55,10 @@ impl SystemPromptBuilder {
Box::new(CrossChannelSection),
Box::new(MemorySection),
Box::new(HistorySection),
Box::new(DelegationSection),
],
}
}
/// Create a builder with sub-agent specific sections.
pub fn with_sub_agent_defaults(
task: &str,
timeout: &str,
skills_prompt: Option<String>,
http_get_only: bool,
) -> Self {
let mut sections: Vec<Box<dyn PromptSection>> = vec![
Box::new(SubAgentIdentitySection {
task: task.to_string(),
timeout: timeout.to_string(),
}),
Box::new(ToolHonestySection),
Box::new(SafetySection),
Box::new(SubAgentToolsSection { http_get_only }),
Box::new(WorkspaceSection),
Box::new(DateTimeSection),
];
if let Some(sp) = skills_prompt {
sections.push(Box::new(SubAgentSkillsSection { skills_prompt: sp }));
}
Self { sections }
}
/// Add a custom section to the builder.
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
self.sections.push(section);
@ -196,10 +175,10 @@ impl PromptSection for UserProfileSection {
if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) =
load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS)
{
output.push_str(&content);
return output;
}
{
output.push_str(&content);
return output;
}
// No USER.md found, return empty
String::new()
@ -220,10 +199,10 @@ impl PromptSection for AgentProfileSection {
if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) =
load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS)
{
output.push_str(&content);
return output;
}
{
output.push_str(&content);
return output;
}
String::new()
}
@ -374,120 +353,6 @@ impl PromptSection for HistorySection {
}
}
/// Sub-agent delegation principles.
pub struct DelegationSection;
impl PromptSection for DelegationSection {
fn name(&self) -> &str {
"delegation"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 子 Agent 委托原则\n\n\
使 delegate Agent\n\
\n\
### \n\
- 使 mode=\"parallel\"\n\
- 使 mode=\"background\"\n\
- \n\
\n\
### \n\
- **** Agent \n\
- **** file_readfile_searchweb_fetch bashfile_writefile_edit\n\
- **** delegate Agent\n\
- **** Agent \n\
\n\
### Skill \n\
- skill allowed_tools get_skill\n\
- prompt Agent 使 get_skill \n\
- \"使用 get_skill action='get' skill_name='pdf' 加载 PDF 处理技能后完成任务\"\n\
\n\
### \n\
- prompt \n\
- prompt \"跳过 .tmp 文件\"\n\
- \n\
\n\
### \n\
- 使 mode=\"parallel\",任务定义在 tasks 数组中\n\
- \n\
- 5 \n\
\n\
### \n\
- 30s 使 mode=\"background\"\n\
- ".to_string()
}
}
// === Sub-Agent Prompt Sections ===
/// Sub-agent identity and task instructions.
pub struct SubAgentIdentitySection {
pub task: String,
pub timeout: String,
}
impl PromptSection for SubAgentIdentitySection {
fn name(&self) -> &str {
"sub_agent_identity"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
format!(
"## 子 Agent\n\n\
Agent Agent Agent\n\
\n\
## \n\n\
{}\n\
\n\
## \n\
- \n\
- 使\n\
- 使 delegate \n\
- \n\
- \n\
- {}",
self.task, self.timeout,
)
}
}
/// Sub-agent available tools description.
pub struct SubAgentToolsSection {
pub http_get_only: bool,
}
impl PromptSection for SubAgentToolsSection {
fn name(&self) -> &str {
"sub_agent_tools"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
let mut s = String::from("## 可用工具\n\n");
s.push_str(&ctx.tools.describe_for_prompt());
if self.http_get_only {
s.push_str(
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
);
}
s
}
}
/// Sub-agent skills information, injected when get_skill tool is available.
pub struct SubAgentSkillsSection {
pub skills_prompt: String,
}
impl PromptSection for SubAgentSkillsSection {
fn name(&self) -> &str {
"sub_agent_skills"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
self.skills_prompt.clone()
}
}
// === Helper Functions ===
/// Get user config directory (~/.picobot/).
@ -544,28 +409,6 @@ pub fn build_system_prompt(
SystemPromptBuilder::with_defaults().build(&ctx)
}
/// Build a system prompt for a sub-agent with all relevant operational sections.
pub fn build_sub_agent_system_prompt(
task: &str,
timeout_human: &str,
tools: &ToolRegistry,
workspace_dir: &Path,
model_name: &str,
skills_prompt: Option<String>,
http_get_only: bool,
) -> String {
let ctx = PromptContext {
workspace_dir,
model_name,
tools,
session_id: None,
memory_context: None,
has_compressed_history: false,
};
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
.build(&ctx)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,8 +1,8 @@
use std::sync::Arc;
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::ChannelManager;
use crate::channels::base::{Channel, ChannelError};
use crate::channels::ChannelManager;
/// OutboundDispatcher consumes outbound messages from the MessageBus
/// and dispatches them to the appropriate Channel

View File

@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::providers::ToolCall;
@ -23,9 +23,7 @@ pub struct ImageUrlBlock {
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text {
text: content.into(),
}
Self::Text { text: content.into() }
}
pub fn image_url(url: impl Into<String>) -> Self {
@ -51,10 +49,10 @@ pub struct MediaRef {
#[derive(Debug, Clone)]
pub struct MediaItem {
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub mime_type: Option<String>,
pub original_key: Option<String>, // Feishu file_key for download
pub original_key: Option<String>, // Feishu file_key for download
}
impl MediaItem {
@ -163,10 +161,7 @@ impl ChatMessage {
}
}
pub fn assistant_with_tool_calls(
content: impl Into<String>,
tool_calls: Vec<ToolCall>,
) -> Self {
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
@ -211,11 +206,7 @@ impl ChatMessage {
}
}
pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(),

View File

@ -2,13 +2,10 @@ pub mod dispatcher;
pub mod message;
pub use dispatcher::OutboundDispatcher;
pub use message::{
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
OutboundMessage, SourceKind,
};
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::{mpsc, Mutex};
// ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication
@ -52,8 +49,7 @@ impl MessageBus {
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self
.inbound_rx
let msg = self.inbound_rx
.lock()
.await
.recv()

View File

@ -1,10 +1,10 @@
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use async_trait::async_trait;
use tokio::sync::{mpsc, Mutex};
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
use super::base::{Channel, ChannelError};
@ -14,7 +14,6 @@ use super::base::{Channel, ChannelError};
pub(crate) struct Client {
sender: mpsc::Sender<WsOutbound>,
chat_id: String,
current_session_id: Mutex<Option<String>>,
}
@ -42,28 +41,23 @@ impl CliChatChannel {
}
/// Register a new client connection, returns (session_id, client)
pub(crate) async fn register_client(
&self,
sender: mpsc::Sender<WsOutbound>,
) -> (String, Arc<Client>) {
// Each WebSocket connection gets a stable chat scope. All user input and
// dialog controls for this client stay inside that scope unless the
// protocol explicitly carries a full session id.
let chat_id = crate::util::short_id();
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
// Generate connection ID (used as chat_id) - use short ID
let connection_id = crate::util::short_id();
let client = Arc::new(Client {
sender,
chat_id: chat_id.clone(),
current_session_id: Mutex::new(None),
});
self.clients.lock().await.push(client.clone());
// Create initial session via control message
let session_id = match self.create_session_via_control(&chat_id, None).await {
Ok((id, _title)) => id,
let session_id = match self.create_session_via_control(&connection_id, None).await {
Ok(id) => id,
Err(e) => {
tracing::error!(error = %e, "Failed to create initial session");
UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
// Fall back to old format for backward compatibility
connection_id.clone()
}
};
@ -79,19 +73,21 @@ impl CliChatChannel {
/// Handle an inbound message from a client
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
match parse_inbound(raw_msg) {
Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
Ok(()) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to handle inbound message");
let _ = client
.sender
.send(WsOutbound::Error {
code: "INTERNAL_ERROR".to_string(),
message: e.to_string(),
})
.await;
Ok(inbound) => {
match self.handle_ws_inbound(client.clone(), inbound).await {
Ok(()) => {}
Err(e) => {
tracing::warn!(error = %e, "Failed to handle inbound message");
let _ = client
.sender
.send(WsOutbound::Error {
code: "INTERNAL_ERROR".to_string(),
message: e.to_string(),
})
.await;
}
}
},
}
Err(e) => {
tracing::warn!(error = %e, "Failed to parse inbound message");
let _ = client
@ -105,30 +101,22 @@ impl CliChatChannel {
}
}
async fn handle_ws_inbound(
&self,
client: Arc<Client>,
inbound: WsInbound,
) -> Result<(), ChannelError> {
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> {
let bus = {
let guard = self.bus.lock().unwrap();
guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
};
let mut current_session_guard = client.current_session_id.lock().await;
match inbound {
WsInbound::UserInput {
content, chat_id, ..
} => {
WsInbound::UserInput { content, chat_id, .. } => {
// All messages (including slash commands) go through the normal inbound flow
// SessionManager handles session creation/reuse internally
let msg = InboundMessage {
channel: self.name().to_string(),
sender_id: "cli".to_string(),
chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
chat_id: chat_id.unwrap_or_else(crate::util::short_id),
content,
timestamp: crate::bus::message::current_timestamp(),
media: Vec::new(),
@ -137,56 +125,19 @@ impl CliChatChannel {
};
bus.publish_inbound(msg).await?;
}
WsInbound::ClearHistory {
chat_id,
session_id,
} => {
WsInbound::ClearHistory { chat_id, session_id } => {
let target = session_id
.or(chat_id)
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let session_id = if let Some(session_id) = session_id {
UnifiedSessionId::parse(&session_id).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
} else if let Some(chat_id) = chat_id {
let (current_tx, mut current_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::GetCurrentDialog {
channel: "cli_chat".to_string(),
chat_id,
},
reply_tx: current_tx,
})
.await?;
match current_rx.recv().await {
Some(Ok(SessionEvent::CurrentDialog {
session_id: Some(session_id),
})) => session_id,
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
return Err(ChannelError::Other("No active session".to_string()));
}
Some(Ok(_)) => {
return Err(ChannelError::Other(
"Unexpected response type".to_string(),
));
}
Some(Err(e)) => return Err(e),
None => {
return Err(ChannelError::Other("Control channel closed".to_string()));
}
}
} else {
let target = current_session_guard
.clone()
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
UnifiedSessionId::parse(&target).ok_or_else(|| {
ChannelError::Other("Invalid session ID format".to_string())
})?
};
let target = session_id.to_string();
let session_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::ClearHistory { session_id },
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
@ -207,21 +158,24 @@ impl CliChatChannel {
}
}
WsInbound::CreateSession { title } => {
let (new_id, created_title) = self
.create_session_via_control(&client.chat_id, title.as_deref())
.await?;
// Use current session's chat_id if available, otherwise generate new one
let chat_id = current_session_guard.clone()
.unwrap_or_else(crate::util::short_id);
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title: created_title,
title: title.unwrap_or_default(),
})
.await;
}
WsInbound::ListSessions { include_archived } => {
// List dialogs for the current chat
let chat_id = client.chat_id.clone();
let chat_id = current_session_guard.clone()
.unwrap_or_else(|| "".to_string());
let chat_id_for_response = chat_id.clone();
let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::ListDialogs {
@ -230,18 +184,13 @@ impl CliChatChannel {
include_archived,
},
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogList {
dialogs,
current_dialog_id,
})) => {
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => {
// Convert DialogInfo to SessionSummary for backward compatibility
let sessions: Vec<crate::protocol::SessionSummary> = dialogs
.into_iter()
.map(|d| crate::protocol::SessionSummary {
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| {
crate::protocol::SessionSummary {
session_id: d.session_id.to_string(),
title: d.title,
channel_name: d.session_id.channel.clone(),
@ -249,14 +198,11 @@ impl CliChatChannel {
message_count: d.message_count,
last_active_at: d.last_active_at,
archived_at: d.archived_at,
})
.collect();
}
}).collect();
let current_session_id = current_dialog_id.map(|did| {
UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string()
});
if let Some(ref session_id) = current_session_id {
*current_session_guard = Some(session_id.clone());
}
let _ = client
.sender
.send(WsOutbound::SessionList {
@ -277,35 +223,39 @@ impl CliChatChannel {
}
}
WsInbound::LoadSession { session_id } => {
// LoadSession: parse the session_id and get current dialog info
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&session_id)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
return Err(ChannelError::Other(
"Session does not belong to this client".to_string(),
));
}
bus.publish_control(ControlMessage {
op: SessionCommand::SwitchDialog {
op: SessionCommand::GetCurrentDialog {
channel: unified_id.channel.clone(),
chat_id: unified_id.chat_id.clone(),
dialog_id: unified_id.dialog_id.clone(),
},
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
*current_session_guard = Some(session_id.to_string());
let _ = client
.sender
.send(WsOutbound::SessionLoaded {
session_id: session_id.to_string(),
title: "Session".to_string(),
message_count: 0,
})
.await;
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => {
if let Some(current_session_id) = current_session_id_opt {
*current_session_guard = Some(current_session_id.to_string());
let _ = client
.sender
.send(WsOutbound::SessionLoaded {
session_id: current_session_id.to_string(),
title: "Session".to_string(), // TODO: get actual title
message_count: 0, // TODO: get actual count
})
.await;
} else {
let _ = client
.sender
.send(WsOutbound::Error {
code: "NO_CURRENT_DIALOG".to_string(),
message: "No current dialog".to_string(),
})
.await;
}
}
Some(Ok(_)) => {
// Unexpected response type
@ -325,30 +275,23 @@ impl CliChatChannel {
}
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
ChannelError::Other("No active session".to_string())
})?;
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::RenameDialog {
session_id: unified_id,
title: title.clone(),
},
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() },
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
let _ = client
.sender
.send(WsOutbound::SessionRenamed {
session_id: session_id.to_string(),
title,
})
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title })
.await;
}
Some(Ok(_)) => {
@ -363,43 +306,24 @@ impl CliChatChannel {
}
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let was_current = current_session_guard.as_deref() == Some(&target);
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
ChannelError::Other("No active session".to_string())
})?;
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::ArchiveDialog {
session_id: unified_id,
},
op: SessionCommand::ArchiveDialog { session_id: unified_id },
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
let _ = client
.sender
.send(WsOutbound::SessionArchived {
session_id: session_id.to_string(),
})
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() })
.await;
if was_current {
let (new_id, title) = self
.create_session_via_control(&client.chat_id, None)
.await?;
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title,
})
.await;
}
}
Some(Ok(_)) => {
// Unexpected response type
@ -413,42 +337,35 @@ impl CliChatChannel {
}
}
WsInbound::DeleteSession { session_id } => {
let target = session_id
.or(current_session_guard.clone())
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
ChannelError::Other("No active session".to_string())
})?;
let (reply_tx, mut reply_rx) = mpsc::channel(1);
let unified_id = UnifiedSessionId::parse(&target)
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
bus.publish_control(ControlMessage {
op: SessionCommand::DeleteDialog {
session_id: unified_id,
},
op: SessionCommand::DeleteDialog { session_id: unified_id },
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
let _ = client
.sender
.send(WsOutbound::SessionDeleted {
session_id: session_id.to_string(),
})
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() })
.await;
// If deleting current session, create a new one
if current_session_guard.as_deref() == Some(&target) {
drop(reply_rx);
if let Ok((new_id, title)) =
self.create_session_via_control(&client.chat_id, None).await
{
if let Ok(new_id) = self.create_session_via_control(&target, None).await {
*current_session_guard = Some(new_id.clone());
let _ = client
.sender
.send(WsOutbound::SessionCreated {
session_id: new_id,
title,
title: String::new(),
})
.await;
}
@ -471,45 +388,32 @@ impl CliChatChannel {
bus.publish_control(ControlMessage {
op: SessionCommand::GetSlashCommands {
channel: "cli_chat".to_string(),
chat_id: client.chat_id.clone(),
chat_id: "".to_string(),
},
reply_tx,
})
.await?;
}).await?;
if let Some(result) = reply_rx.recv().await {
match result {
Ok(SessionEvent::SlashCommandsList { commands }) => {
// Convert to SlashCommand to SlashCommandInfo
let command_infos: Vec<SlashCommandInfo> = commands
.into_iter()
.map(|cmd| SlashCommandInfo {
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| {
SlashCommandInfo {
name: cmd.name.to_string(),
description: cmd.description.to_string(),
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
})
.collect();
let _ = client
.sender
.send(WsOutbound::SlashCommandsList {
commands: command_infos,
})
.await;
}
}).collect();
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await;
}
Ok(SessionEvent::Error { code, message }) => {
let _ = client
.sender
.send(WsOutbound::Error { code, message })
.await;
let _ = client.sender.send(WsOutbound::Error { code, message }).await;
}
Err(e) => {
let _ = client
.sender
.send(WsOutbound::Error {
code: "GET_COMMANDS_ERROR".to_string(),
message: e.to_string(),
})
.await;
let _ = client.sender.send(WsOutbound::Error {
code: "GET_COMMANDS_ERROR".to_string(),
message: e.to_string()
}).await;
}
_ => {}
}
@ -523,34 +427,29 @@ impl CliChatChannel {
}
/// Create a session via control message and return the session_id
async fn create_session_via_control(
&self,
chat_id: &str,
title: Option<&str>,
) -> Result<(String, String), ChannelError> {
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> {
let bus = {
let guard = self.bus.lock().unwrap();
guard
.clone()
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
};
let (reply_tx, mut reply_rx) = mpsc::channel(1);
bus.publish_control(ControlMessage {
op: SessionCommand::CreateDialog {
channel: "cli_chat".to_string(),
chat_id: chat_id.to_string(),
chat_id: connection_id.to_string(),
title: title.map(String::from),
},
reply_tx,
})
.await?;
}).await?;
match reply_rx.recv().await {
Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
Ok((session_id.to_string(), title))
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => {
Ok(session_id.to_string())
}
Some(Ok(_)) => {
Err(ChannelError::Other("Unexpected response type".to_string()))
}
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
Some(Err(e)) => Err(e),
None => Err(ChannelError::Other("Control channel closed".to_string())),
}
@ -580,11 +479,7 @@ impl Channel for CliChatChannel {
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let clients = self.clients.lock().await.clone();
for client in clients {
if client.chat_id != msg.chat_id {
continue;
}
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
{
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
WsOutbound::SystemNotification {
content: msg.content.clone(),
}

File diff suppressed because it is too large Load Diff

View File

@ -24,10 +24,7 @@ impl ChannelManager {
}
}
pub fn with_bus(
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
bus: Arc<MessageBus>,
) -> Self {
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
cli_chat_channel,
@ -42,10 +39,7 @@ impl ChannelManager {
/// Register a channel with the manager
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels
.write()
.await
.insert(name.to_string(), channel);
self.channels.write().await.insert(name.to_string(), channel);
}
/// Get CLI chat channel
@ -62,19 +56,14 @@ impl ChannelManager {
// Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled {
let channel =
FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
})?;
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir)
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
self.channels
.write()
.await
.insert("feishu".to_string(), Arc::new(channel));
tracing::info!(
"Feishu channel registered (media_dir: {}/media/feishu)",
workspace_dir.display()
);
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display());
} else {
tracing::info!("Feishu channel disabled in config");
}
@ -129,10 +118,7 @@ impl ChannelManager {
if let Some(channel) = self.get_channel(channel_name).await {
channel.send(msg).await
} else {
Err(ChannelError::Other(format!(
"Channel not found: {}",
channel_name
)))
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
}
}
}

View File

@ -1,11 +1,11 @@
pub mod base;
pub mod cli_chat;
pub mod feishu;
pub mod cli_chat;
pub mod manager;
pub mod slash_command;
pub use base::{Channel, ChannelError};
pub use cli_chat::CliChatChannel;
pub use feishu::FeishuChannel;
pub use manager::ChannelManager;
pub use slash_command::{command_matches, parse_slash_command};
pub use feishu::FeishuChannel;
pub use cli_chat::CliChatChannel;
pub use slash_command::{parse_slash_command, command_matches};

View File

@ -16,9 +16,7 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
/// 检查内容是否匹配指定命令
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
let trimmed = content.trim();
aliases
.iter()
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
}
#[cfg(test)]
@ -29,10 +27,7 @@ mod tests {
fn test_parse_slash_command() {
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
assert_eq!(
parse_slash_command("/new hello world"),
Some(("new", "hello world"))
);
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world")));
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
assert_eq!(parse_slash_command("/?"), Some(("?", "")));

View File

@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
use crossterm::{
event::{self, Event},
execute,
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
};
use futures_util::{SinkExt, StreamExt};
use ratatui::{Terminal, prelude::CrosstermBackend};
use ratatui::{prelude::CrosstermBackend, Terminal};
use std::io;
use tokio_tungstenite::{connect_async, tungstenite::Message};
@ -104,10 +104,7 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
WsOutbound::SessionCreated { session_id, .. } => {
app.set_current_session(Some(session_id));
}
WsOutbound::SessionList {
sessions,
current_session_id,
} => {
WsOutbound::SessionList { sessions, current_session_id } => {
app.set_sessions(sessions);
if let Some(id) = current_session_id {
app.set_current_session(Some(id));

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::{App, MessageRole};
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
text::Line,
widgets::{Block, Borders, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,8 +1,8 @@
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
widgets::{Block, Borders, Clear, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Style},
widgets::{Block, Borders, Paragraph},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
widgets::{Block, Borders, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {
@ -11,7 +11,9 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
.sessions
.iter()
.map(|session| {
let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
let is_current = app
.current_session_id
.as_ref() == Some(&session.session_id);
let archived = session.archived_at.is_some();
let mut content = if is_current {

View File

@ -1,18 +1,15 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
widgets::{Block, Borders, Paragraph},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {
let (title, style) = if app.pending_quit {
let msg = if let Some(session_id) = &app.current_session_id {
format!(
"PicoBot | Session: {} | Press Ctrl+C again to quit",
session_id
)
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id)
} else {
"PicoBot | Press Ctrl+C again to quit".to_string()
};

View File

@ -1,6 +1,6 @@
use crate::client::tui::app::{App, MessageRole};
use crate::protocol::WsInbound;
use crate::protocol::serialize_inbound;
use crate::protocol::WsInbound;
use crossterm::event::{KeyCode, KeyEvent};
use futures_util::SinkExt;
@ -48,10 +48,7 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
// Handle Ctrl+C for quit (double press to exit)
let is_ctrl_c = key.code == KeyCode::Char('c')
&& key
.modifiers
.contains(crossterm::event::KeyModifiers::CONTROL);
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL);
if is_ctrl_c {
if app.handle_ctrl_c_for_quit() {
return;
@ -66,11 +63,9 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
}
KeyCode::Char(c) => {
app.input_insert_char(c);
// Show command menu when input starts with /
if !app.show_command_menu
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
{
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) {
app.show_command_menu = true;
app.selected_command_idx = 0;
} else if app.show_command_menu && !app.input.starts_with('/') {
@ -79,7 +74,7 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
}
KeyCode::Backspace => {
app.input_delete_char();
// Hide menu if input no longer starts with /
if app.show_command_menu && !app.input.starts_with('/') {
app.show_command_menu = false;
@ -126,9 +121,7 @@ async fn process_input(app: &mut App, input: String) {
sender_id: None,
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
.await;
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await;
}
}
}

View File

@ -1,8 +1,8 @@
use crate::client::tui::app::App;
use crate::client::tui::components::*;
use ratatui::{
Frame,
layout::{Constraint, Direction, Layout, Rect},
Frame,
};
pub fn render_ui(f: &mut Frame, app: &App) {

View File

@ -152,26 +152,10 @@ pub struct GatewayConfig {
pub cleanup_interval_minutes: Option<u64>,
#[serde(default, rename = "session_db_path")]
pub session_db_path: Option<String>,
#[serde(default, rename = "max_concurrent_background_tasks")]
pub max_concurrent_background_tasks: usize,
#[serde(default)]
pub scheduler: Option<SchedulerConfig>,
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
max_concurrent_background_tasks: 10,
scheduler: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
/// Whether the scheduler is enabled
@ -225,6 +209,19 @@ fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string()
}
impl Default for GatewayConfig {
fn default() -> Self {
Self {
host: default_gateway_host(),
port: default_gateway_port(),
session_ttl_hours: None,
cleanup_interval_minutes: None,
session_db_path: None,
scheduler: None,
}
}
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
@ -273,16 +270,12 @@ impl Default for MemoryConfig {
impl MemoryConfig {
/// Resolve consolidation provider name, falling back to the main agent's provider.
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
self.consolidation_provider
.clone()
.unwrap_or_else(|| default.to_string())
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
}
/// Resolve consolidation model name, falling back to the main agent's model.
pub fn resolve_consolidation_model(&self, default: &str) -> String {
self.consolidation_model
.clone()
.unwrap_or_else(|| default.to_string())
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
}
}
@ -370,18 +363,10 @@ impl Default for BrowserConfig {
}
}
fn default_recall_limit() -> usize {
5
}
fn default_idle_consolidation_minutes() -> u64 {
10
}
fn default_timeline_retention_days() -> u64 {
90
}
fn default_max_failures_before_degrade() -> usize {
3
}
fn default_recall_limit() -> usize { 5 }
fn default_idle_consolidation_minutes() -> u64 { 10 }
fn default_timeline_retention_days() -> u64 { 90 }
fn default_max_failures_before_degrade() -> usize { 3 }
#[derive(Debug, Clone)]
pub struct LLMProviderConfig {
@ -481,11 +466,7 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::ConfigNotFound(path) => write!(
f,
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
path
),
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),

View File

@ -1,19 +1,19 @@
pub mod http;
pub mod ws;
use axum::{Router, routing};
use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::config::{Config, ensure_workspace_dir, expand_path};
use crate::channels::base::{Channel, ChannelError};
use crate::config::{Config, expand_path, ensure_workspace_dir};
use crate::logging;
use crate::mcp;
use crate::memory::MemoryManager;
use crate::scheduler::Scheduler;
use crate::session::SessionManager;
use crate::scheduler::Scheduler;
pub struct GatewayState {
pub config: Config,
@ -32,13 +32,8 @@ impl GatewayState {
let workspace_path = ensure_workspace_dir(&workspace_path)?;
// Switch current working directory to workspace
std::env::set_current_dir(&workspace_path).map_err(|e| {
format!(
"Failed to switch to workspace directory {}: {}",
workspace_path.display(),
e
)
})?;
std::env::set_current_dir(&workspace_path)
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?;
tracing::info!("Using workspace directory: {}", workspace_path.display());
@ -57,9 +52,8 @@ impl GatewayState {
workspace_path.join("picobot.db")
};
let storage = Arc::new(
crate::storage::Storage::new(&db_path)
.await
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
crate::storage::Storage::new(&db_path).await
.map_err(|e| format!("failed to initialize session storage: {}", e))?
);
tracing::info!("Session storage: {}", db_path.display());
@ -97,16 +91,13 @@ impl GatewayState {
bus.clone(),
memory_manager,
browser_config,
config.gateway.max_concurrent_background_tasks,
)?;
let session_manager = Arc::new(session_manager);
// Create ChannelManager and init channels
let cli_chat_channel = Arc::new(CliChatChannel::new());
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
channel_manager
.init(&config, workspace_path.clone())
.await
channel_manager.init(&config, workspace_path.clone()).await
.map_err(|e| format!("Failed to init channels: {}", e))?;
// Register send_message tool with available channel names
@ -115,12 +106,9 @@ impl GatewayState {
session_manager.register_outbound_tool(available_channels);
// Register chat_manager tool
session_manager
.tools()
.register(crate::tools::ChatManagerTool::new(
storage.clone(),
valid_channels.clone(),
));
session_manager.tools().register(
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
);
// Initialize MCP servers — connect and register discovered tools
if !config.mcp.servers.is_empty() {
@ -141,27 +129,24 @@ impl GatewayState {
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
if scheduler_config.enabled {
// Register cron tools
session_manager
.tools()
.register(crate::tools::cron::CronAddTool::new(
storage.clone(),
valid_channels,
));
session_manager
.tools()
.register(crate::tools::cron::CronListTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronEnableTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronDisableTool::new(storage.clone()));
session_manager
.tools()
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
session_manager.tools().register(
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
);
session_manager.tools().register(
crate::tools::cron::CronListTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronRemoveTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronEnableTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronDisableTool::new(storage.clone()),
);
session_manager.tools().register(
crate::tools::cron::CronUpdateTool::new(storage.clone()),
);
tracing::info!("Cron tools registered");
}
@ -282,103 +267,71 @@ impl GatewayState {
}
/// Handle control messages (session management operations)
async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
async fn handle_control_message(
session_manager: &SessionManager,
msg: ControlMessage,
) {
use crate::session::{SessionCommand::*, SessionEvent};
let reply_tx = msg.reply_tx;
let result: Result<SessionEvent, ChannelError> = match msg.op {
CreateDialog {
channel,
chat_id,
title,
} => session_manager
.create_dialog(&channel, &chat_id, title.as_deref())
.await
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string())),
ListDialogs {
channel,
chat_id,
include_archived,
} => session_manager
.list_dialogs(&channel, &chat_id, include_archived)
.await
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
dialogs,
current_dialog_id,
})
.map_err(|e| ChannelError::Other(e.to_string())),
GetCurrentDialog { channel, chat_id } => session_manager
.get_current_dialog(&channel, &chat_id)
.await
.map(|session_id| SessionEvent::CurrentDialog { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
SwitchDialog {
channel,
chat_id,
dialog_id,
} => session_manager
.switch_dialog(&channel, &chat_id, &dialog_id)
.await
.map(|session_id| SessionEvent::DialogSwitched { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
RenameDialog { session_id, title } => session_manager
.rename_dialog(&session_id, &title)
.await
.map(|()| SessionEvent::DialogRenamed { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string())),
ArchiveDialog { session_id } => session_manager
.archive_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogArchived { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
DeleteDialog { session_id } => session_manager
.delete_dialog(&session_id)
.await
.map(|()| SessionEvent::DialogDeleted { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
ClearHistory { session_id } => session_manager
.clear_dialog_history(&session_id)
.await
.map(|()| SessionEvent::HistoryCleared { session_id })
.map_err(|e| ChannelError::Other(e.to_string())),
GetSlashCommands {
channel: _,
chat_id: _,
} => {
CreateDialog { channel, chat_id, title } => {
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ListDialogs { channel, chat_id, include_archived } => {
session_manager.list_dialogs(&channel, &chat_id, include_archived).await
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
GetCurrentDialog { channel, chat_id } => {
session_manager.get_current_dialog(&channel, &chat_id).await
.map(|session_id| SessionEvent::CurrentDialog { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
SwitchDialog { channel, chat_id, dialog_id } => {
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await
.map(|session_id| SessionEvent::DialogSwitched { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
RenameDialog { session_id, title } => {
session_manager.rename_dialog(&session_id, &title).await
.map(|()| SessionEvent::DialogRenamed { session_id, title })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ArchiveDialog { session_id } => {
session_manager.archive_dialog(&session_id)
.map(|()| SessionEvent::DialogArchived { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
DeleteDialog { session_id } => {
session_manager.delete_dialog(&session_id).await
.map(|()| SessionEvent::DialogDeleted { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
ClearHistory { session_id } => {
session_manager.clear_dialog_history(&session_id)
.map(|()| SessionEvent::HistoryCleared { session_id })
.map_err(|e| ChannelError::Other(e.to_string()))
}
GetSlashCommands { channel: _, chat_id: _ } => {
let commands = session_manager.get_slash_commands().to_vec();
Ok(SessionEvent::SlashCommandsList { commands })
}
ExecuteSlashCommand {
command,
args,
channel,
chat_id,
current_session_id,
} => session_manager
.execute_slash_command(
&command,
args.as_deref(),
&channel,
&chat_id,
current_session_id.as_ref(),
)
.await
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
new_session_id: new_id,
message: msg,
})
.map_err(|e| ChannelError::Other(e.to_string())),
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => {
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref())
.await
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg })
.map_err(|e| ChannelError::Other(e.to_string()))
}
};
let _ = reply_tx.send(result).await;
}
}
pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
logging::init_logging();
tracing::info!("Starting PicoBot Gateway");

View File

@ -1,12 +1,12 @@
use super::GatewayState;
use crate::protocol::WsOutbound;
use crate::protocol::serialize_outbound;
use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::protocol::serialize_outbound;
use crate::protocol::WsOutbound;
use super::GatewayState;
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async move {
@ -25,11 +25,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
// Send session established message
let _ = sender
.send(WsOutbound::SessionEstablished {
session_id: session_id.clone(),
})
.await;
let _ = sender.send(WsOutbound::SessionEstablished {
session_id: session_id.clone(),
}).await;
tracing::info!(session_id = %session_id, "CLI session established");
@ -39,10 +37,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg)
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err()
{
break;
}
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
break;
}
}
});

View File

@ -1,17 +1,17 @@
pub mod agent;
pub mod bus;
pub mod channels;
pub mod client;
pub mod config;
pub mod providers;
pub mod bus;
pub mod agent;
pub mod gateway;
pub mod session;
pub mod client;
pub mod protocol;
pub mod channels;
pub mod logging;
pub mod mcp;
pub mod memory;
pub mod observability;
pub mod protocol;
pub mod providers;
pub mod scheduler;
pub mod session;
pub mod skills;
pub mod storage;
pub mod tools;

View File

@ -1,7 +1,11 @@
use std::path::PathBuf;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{
EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
fmt,
layer::SubscriberExt,
util::SubscriberInitExt,
fmt::time::LocalTime,
EnvFilter,
};
/// Get the default log directory path: ~/.picobot/logs
@ -23,20 +27,20 @@ pub fn init_logging() {
// Create log directory if it doesn't exist
if !log_dir.exists()
&& let Err(e) = std::fs::create_dir_all(&log_dir)
{
eprintln!(
"Warning: Failed to create log directory {}: {}",
log_dir.display(),
e
);
}
&& let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
}
// Create file appender with daily rotation
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
let file_appender = RollingFileAppender::new(
Rotation::DAILY,
&log_dir,
"picobot.log",
);
// Build subscriber with both console and file output
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let file_layer = fmt::layer()
.with_writer(file_appender)
@ -62,7 +66,8 @@ pub fn init_logging() {
/// Initialize logging without file output (console only)
pub fn init_logging_console_only() {
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer()
.with_timer(LocalTime::rfc_3339())

View File

@ -1,9 +1,8 @@
use clap::{CommandFactory, Parser};
use clap::{Parser, CommandFactory};
#[derive(Parser)]
#[command(name = "picobot")]
#[command(about = "A CLI chatbot", long_about = None)]
#[command(version = "1.1.0")]
enum Command {
/// Connect to gateway
Chat {

View File

@ -92,19 +92,24 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
parts.push(text.text.clone());
}
RawContent::Image(image) => {
parts.push(format!("[image: {}]", image.mime_type,));
parts.push(format!(
"[image: {}]",
image.mime_type,
));
}
RawContent::Resource(resource) => match &resource.resource {
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
parts.push(format!(
"[resource text: {}]",
text.chars().take(200).collect::<String>(),
));
RawContent::Resource(resource) => {
match &resource.resource {
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
parts.push(format!(
"[resource text: {}]",
text.chars().take(200).collect::<String>(),
));
}
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
}
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
},
}
_ => {
parts.push("[unsupported content]".to_string());
}
@ -220,8 +225,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
cmd.env(k, v);
}
let service =
().serve(
let service = ()
.serve(
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
)
.await
@ -256,14 +261,14 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
} else {
StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(url.to_string())
.custom_headers(headers_map),
.custom_headers(headers_map)
)
};
let service =
().serve(transport)
.await
.context("failed to connect to HTTP/SSE MCP server")?;
let service = ()
.serve(transport)
.await
.context("failed to connect to HTTP/SSE MCP server")?;
let peer = service.peer().clone();

View File

@ -102,11 +102,7 @@ mod tests {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
let mm = Arc::new(MemoryManager::new(
storage,
"default".into(),
"default".into(),
));
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
(mm, dir)
}
@ -135,9 +131,15 @@ mod tests {
async fn test_upsert_overwrites() {
let (mm, _dir) = setup_memory_manager().await;
mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
.await
.unwrap();
mm.store(
"dup_key",
"original",
MemoryCategory::Knowledge,
None,
None,
)
.await
.unwrap();
mm.store(
"dup_key",
"updated",
@ -245,12 +247,7 @@ mod tests {
// Recall scoped to session A — should get only tl_a
let scoped = mm
.recall(
"summary",
10,
Some(MemoryCategory::Timeline),
Some("chan:chat:dialog_a"),
)
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
.await
.unwrap();
assert_eq!(scoped.len(), 1);

View File

@ -20,7 +20,10 @@ pub enum ObserverEvent {
success: bool,
},
/// Emitted when the agent starts processing.
AgentStart { provider: String, model: String },
AgentStart {
provider: String,
model: String,
},
/// Emitted when the agent finishes processing.
AgentEnd {
provider: String,
@ -91,11 +94,7 @@ impl ToolExecutionOutcome {
}
/// Create a failed outcome with duration.
pub fn failure_with_duration(
output: String,
error_reason: Option<String>,
duration: Duration,
) -> Self {
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
Self {
output,
success: false,

View File

@ -4,24 +4,23 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
})
.collect()
blocks.iter().map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => {
convert_image_url_to_anthropic(&image_url.url)
}
}).collect()
}
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
@ -198,13 +197,8 @@ impl LLMProvider for AnthropicProvider {
};
let content = if let Some(ref tc_id) = m.tool_call_id {
// Tool result: wrap as tool_result content block
let output = m
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
let output = m.content.iter()
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
.collect::<Vec<_>>()
.join("");
vec![serde_json::json!({
@ -250,18 +244,19 @@ impl LLMProvider for AnthropicProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let resp = req_builder.json(&body).send().await
.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let status = resp.status();
let body_text = resp.text().await?;
@ -286,38 +281,32 @@ impl LLMProvider for AnthropicProvider {
"LLM API returned error"
);
if let Some(ref storage) = self.storage {
let _ = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
Some(&error_msg),
start.elapsed().as_millis() as u64,
)
.await;
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&body_text), Some(&error_msg),
start.elapsed().as_millis() as u64,
).await;
}
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
}
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s
.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
.await;
});
}
err_msg
})?;
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
.map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
});
}
err_msg
})?;
let mut content = String::new();
let mut reasoning = None;
@ -354,35 +343,21 @@ impl LLMProvider for AnthropicProvider {
reasoning_content: reasoning,
tool_calls,
usage: Usage {
prompt_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.input_tokens)
.unwrap_or(0),
completion_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.output_tokens)
.unwrap_or(0),
total_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.input_tokens + u.output_tokens)
.unwrap_or(0),
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
},
};
if let Some(ref storage) = self.storage {
let _ = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
)
.await;
let _ = storage.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)

View File

@ -1,15 +1,12 @@
pub mod anthropic;
pub mod openai;
pub mod traits;
pub mod openai;
pub mod anthropic;
pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider;
pub use self::anthropic::AnthropicProvider;
use crate::config::LLMProviderConfig;
pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
ToolFunction, Usage,
};
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() {

View File

@ -1,35 +1,29 @@
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1
&& let ContentBlock::Text { text } = &blocks[0]
{
return Value::String(text.clone());
}
Value::Array(
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
})
.collect(),
)
&& let ContentBlock::Text { text } = &blocks[0] {
return Value::String(text.clone());
}
Value::Array(blocks.iter().map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
}).collect())
}
pub struct OpenAIProvider {
@ -207,14 +201,10 @@ impl LLMProvider for OpenAIProvider {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
&& let Some(url_str) = item
.get("image_url")
.and_then(|u| u.get("url"))
.and_then(|v| v.as_str())
{
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
}
}
}
@ -234,18 +224,19 @@ impl LLMProvider for OpenAIProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let resp = req_builder.json(&body).send().await
.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let status = resp.status();
let text = resp.text().await?;
@ -262,48 +253,37 @@ impl LLMProvider for OpenAIProvider {
"LLM API returned error"
);
if let Some(ref storage) = self.storage
&& let Err(e) = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&text),
Some(&error),
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
&& let Err(e) = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), Some(&error),
start.elapsed().as_millis() as u64,
).await {
tracing::warn!("failed to persist LLM call: {}", e);
}
return Err(error.into());
}
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
if let Err(e) = s
.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
.await
{
tracing::warn!("failed to persist LLM call (decode error): {}", e);
}
});
}
err_msg
})?;
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
.map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
tracing::warn!("failed to persist LLM call (decode error): {}", e);
}
});
}
err_msg
})?;
let first_choice = openai_resp
.choices
.into_iter()
.next()
let first_choice = openai_resp.choices.into_iter().next()
.ok_or("no choices in response")?;
let content = first_choice
@ -320,8 +300,7 @@ impl LLMProvider for OpenAIProvider {
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Null),
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
})
.collect();
@ -339,19 +318,13 @@ impl LLMProvider for OpenAIProvider {
};
if let Some(ref storage) = self.storage
&& let Err(e) = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&text),
None,
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
&& let Err(e) = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), None,
start.elapsed().as_millis() as u64,
).await {
tracing::warn!("failed to persist LLM call: {}", e);
}
Ok(response)
}
@ -413,9 +386,6 @@ mod tests {
assert_eq!(tool_calls[0]["id"], "call_1");
assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(
tool_calls[0]["function"]["arguments"],
"{\"expression\":\"1+1\"}"
);
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
}
}

View File

@ -1,6 +1,6 @@
use crate::bus::message::ContentBlock;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
@ -61,11 +61,7 @@ impl Message {
}
}
pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: "tool".to_string(),
content: vec![ContentBlock::text(content)],

View File

@ -5,11 +5,11 @@ use std::time::Instant;
use tokio::time;
use crate::config::SchedulerConfig;
use crate::session::SessionManager;
use crate::session::session::HandleResult;
use crate::storage::JobRun;
use crate::session::SessionManager;
use crate::storage::ScheduledJob;
use crate::storage::Storage;
use crate::storage::JobRun;
pub use types::Schedule;
@ -89,11 +89,7 @@ impl Scheduler {
let now = now_ms();
let due = match self
.storage
.due_scheduled_jobs(now, self.config.max_concurrent)
.await
{
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
Ok(jobs) => jobs,
Err(e) => {
tracing::error!("scheduler: failed to query due jobs: {}", e);
@ -111,11 +107,7 @@ impl Scheduler {
let start = Instant::now();
let started_at = now_ms();
if let Err(e) = self
.storage
.touch_scheduled_job_last_run(&job.id, started_at)
.await
{
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
continue;
}
@ -143,10 +135,7 @@ impl Scheduler {
match result {
Ok(HandleResult::AgentResponse(output)) => {
let output_truncated = if output.len() > 8000 {
format!(
"{}...[truncated]",
&output[..output.ceil_char_boundary(8000)]
)
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
} else {
output.clone()
};
@ -166,11 +155,7 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
}
if let Err(e) = self
.storage
.set_scheduled_job_last_status(&job.id, "ok", None)
.await
{
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
}
@ -214,11 +199,9 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
}
if let Err(e2) = self
.storage
.set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
.await
{
if let Err(e2) = self.storage.set_scheduled_job_last_status(
&job.id, "error", Some(&error_str),
).await {
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
}
@ -248,23 +231,17 @@ impl Scheduler {
self.storage.remove_scheduled_job(&job.id).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
} else {
self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
}
}
Schedule::Every { .. } | Schedule::Cron { .. } => {
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
self.storage
.set_scheduled_job_next_run(&job.id, next)
.await?;
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
} else {
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
}
}
}

View File

@ -22,20 +22,32 @@ pub enum SessionCommand {
dialog_id: String,
},
/// Get the current dialog for a chat
GetCurrentDialog { channel: String, chat_id: String },
GetCurrentDialog {
channel: String,
chat_id: String,
},
/// Rename a dialog
RenameDialog {
session_id: UnifiedSessionId,
title: String,
},
/// Archive a dialog
ArchiveDialog { session_id: UnifiedSessionId },
ArchiveDialog {
session_id: UnifiedSessionId,
},
/// Delete a dialog
DeleteDialog { session_id: UnifiedSessionId },
DeleteDialog {
session_id: UnifiedSessionId,
},
/// Clear dialog history
ClearHistory { session_id: UnifiedSessionId },
ClearHistory {
session_id: UnifiedSessionId,
},
/// Get list of available slash commands
GetSlashCommands { channel: String, chat_id: String },
GetSlashCommands {
channel: String,
chat_id: String,
},
/// Execute a slash command
ExecuteSlashCommand {
command: String,
@ -48,11 +60,7 @@ pub enum SessionCommand {
impl SessionCommand {
/// Create a CreateDialog command
pub fn create_dialog(
channel: impl Into<String>,
chat_id: impl Into<String>,
title: Option<String>,
) -> Self {
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self {
Self::CreateDialog {
channel: channel.into(),
chat_id: chat_id.into(),
@ -61,11 +69,7 @@ impl SessionCommand {
}
/// Create a ListDialogs command
pub fn list_dialogs(
channel: impl Into<String>,
chat_id: impl Into<String>,
include_archived: bool,
) -> Self {
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self {
Self::ListDialogs {
channel: channel.into(),
chat_id: chat_id.into(),

View File

@ -1,5 +1,5 @@
use super::session::SlashCommand;
use super::session_id::UnifiedSessionId;
use super::session::SlashCommand;
/// Dialog information returned by SessionManager
#[derive(Debug, Clone)]
@ -30,20 +30,30 @@ pub enum SessionEvent {
session_id: Option<UnifiedSessionId>,
},
/// Dialog switched successfully
DialogSwitched { session_id: UnifiedSessionId },
DialogSwitched {
session_id: UnifiedSessionId,
},
/// Dialog renamed
DialogRenamed {
session_id: UnifiedSessionId,
title: String,
},
/// Dialog archived
DialogArchived { session_id: UnifiedSessionId },
DialogArchived {
session_id: UnifiedSessionId,
},
/// Dialog deleted
DialogDeleted { session_id: UnifiedSessionId },
DialogDeleted {
session_id: UnifiedSessionId,
},
/// Dialog history cleared
HistoryCleared { session_id: UnifiedSessionId },
HistoryCleared {
session_id: UnifiedSessionId,
},
/// List of available slash commands
SlashCommandsList { commands: Vec<SlashCommand> },
SlashCommandsList {
commands: Vec<SlashCommand>,
},
/// Slash command executed successfully
SlashCommandExecuted {
new_session_id: Option<UnifiedSessionId>,
@ -60,5 +70,8 @@ pub enum SessionEvent {
message_count: usize,
},
/// Error occurred
Error { code: String, message: String },
Error {
code: String,
message: String,
},
}

View File

@ -1,11 +1,11 @@
pub mod commands;
pub mod error;
pub mod commands;
pub mod events;
pub mod session;
pub mod session_id;
pub use commands::SessionCommand;
pub use error::SessionError;
pub use events::{DialogInfo, SessionEvent};
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
pub use commands::SessionCommand;
pub use events::{SessionEvent, DialogInfo};
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS};
pub use session_id::UnifiedSessionId;

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@
///
/// For simple cases where only one dialog exists per chat:
/// - `dialog_id` defaults to `"default"`
use serde::{Deserialize, Serialize};
pub const DEFAULT_DIALOG_ID: &str = "default";
@ -21,11 +22,7 @@ pub struct UnifiedSessionId {
impl UnifiedSessionId {
/// Create a new UnifiedSessionId
pub fn new(
channel: impl Into<String>,
chat_id: impl Into<String>,
dialog_id: impl Into<String>,
) -> Self {
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self {
Self {
channel: channel.into(),
chat_id: chat_id.into(),

View File

@ -1,6 +1,6 @@
use std::path::Path;
use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS};
pub fn install_builtin_skills(target_dir: &Path) {
for skill in EMBEDDED_SKILLS {
@ -22,7 +22,8 @@ pub fn install_builtin_skills(target_dir: &Path) {
}
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
let decompressed = zstd::decode_all(skill.data)
.map_err(|e| format!("zstd decode: {}", e))?;
let mut archive = tar::Archive::new(decompressed.as_slice());
archive

View File

@ -120,11 +120,7 @@ impl SkillsLoader {
let count = loaded.len();
let mut replaced = 0usize;
for skill in loaded {
if let Some(existing) = state
.loaded_skills
.iter_mut()
.find(|s| s.name == skill.name)
{
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
*existing = skill;
replaced += 1;
} else {
@ -142,42 +138,33 @@ impl SkillsLoader {
// Load from workspace skills dir (highest priority) — replace same-name skills
if let Some(ref ws_dir) = self.workspace_skills_dir
&& ws_dir.exists()
{
let loaded = self.load_skills_from_dir(ws_dir);
let count = loaded.len();
let mut replaced = 0usize;
for skill in loaded {
if let Some(existing) = state
.loaded_skills
.iter_mut()
.find(|s| s.name == skill.name)
{
*existing = skill;
replaced += 1;
} else {
state.loaded_skills.push(skill);
&& ws_dir.exists() {
let loaded = self.load_skills_from_dir(ws_dir);
let count = loaded.len();
let mut replaced = 0usize;
for skill in loaded {
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
*existing = skill;
replaced += 1;
} else {
state.loaded_skills.push(skill);
}
}
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
}
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
}
state.last_load_time = SystemTime::now();
if state.loaded_skills.is_empty() {
tracing::debug!("No skills found in any skills directory");
} else {
tracing::info!(
count = state.loaded_skills.len(),
"Loaded {} skills total",
state.loaded_skills.len()
);
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
}
}
@ -228,20 +215,18 @@ impl SkillsLoader {
let mut max_mtime = None;
if let Ok(metadata) = std::fs::metadata(dir)
&& let Ok(mtime) = metadata.modified()
{
max_mtime = Some(mtime);
}
&& let Ok(mtime) = metadata.modified() {
max_mtime = Some(mtime);
}
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if let Ok(metadata) = std::fs::metadata(&path)
&& let Ok(mtime) = metadata.modified()
&& max_mtime.is_none_or(|current| mtime > current)
{
max_mtime = Some(mtime);
}
&& max_mtime.is_none_or(|current| mtime > current) {
max_mtime = Some(mtime);
}
}
}
@ -259,12 +244,7 @@ impl SkillsLoader {
pub fn get_always_skills(&self) -> Vec<Skill> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state
.loaded_skills
.iter()
.filter(|s| s.always)
.cloned()
.collect()
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
}
/// Get a specific skill by name (checks for changes first)
@ -278,8 +258,7 @@ impl SkillsLoader {
pub fn list_skills(&self) -> Vec<(String, String)> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state
.loaded_skills
state.loaded_skills
.iter()
.map(|s| (s.name.clone(), s.description.clone()))
.collect()
@ -300,21 +279,15 @@ impl SkillsLoader {
prompt.push_str("### 目录说明\n\n");
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill\n");
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
prompt.push_str(
"- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n",
);
prompt.push_str(
"安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n",
);
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n");
prompt.push_str("安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n");
// Always skills summary
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
if !always_skills.is_empty() {
prompt.push_str("### 常用技能\n\n");
for skill in &always_skills {
let path_str = skill
.path
.as_ref()
let path_str = skill.path.as_ref()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string());
prompt.push_str(&format!(
@ -327,12 +300,8 @@ impl SkillsLoader {
// Usage instructions
prompt.push_str("### 使用方法\n\n");
prompt.push_str(
"- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
);
prompt.push_str(
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
);
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n");
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n");
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
// Always skills full content
@ -369,23 +338,25 @@ impl SkillsLoader {
}
match std::fs::read_to_string(&skill_file) {
Ok(content) => match self.parse_skill(&path, &content) {
Some(skill) => {
tracing::debug!(
skill = %skill.name,
path = %skill_file.display(),
always = skill.always,
"Loaded skill"
);
skills.push(skill);
Ok(content) => {
match self.parse_skill(&path, &content) {
Some(skill) => {
tracing::debug!(
skill = %skill.name,
path = %skill_file.display(),
always = skill.always,
"Loaded skill"
);
skills.push(skill);
}
None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
}
None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
},
}
Err(e) => {
tracing::warn!(
path = %skill_file.display(),
@ -476,6 +447,7 @@ impl Default for SkillsLoader {
}
}
/// Extract first non-empty, non-heading line as description
fn extract_description(content: &str) -> String {
content

View File

@ -1,19 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackgroundTask {
pub id: String,
pub session_id: String,
pub channel: String,
pub chat_id: String,
pub prompt: String,
pub allowed_tools: Option<String>,
pub status: String,
pub result: Option<String>,
pub error: Option<String>,
pub tool_calls_count: i64,
pub iterations: i64,
pub started_at: Option<i64>,
pub finished_at: Option<i64>,
pub created_at: i64,
}

View File

@ -241,11 +241,12 @@ impl super::Storage {
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
let cutoff_str = cutoff.to_rfc3339();
let result =
sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
.bind(&cutoff_str)
.execute(self.pool())
.await?;
let result = sqlx::query(
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
)
.bind(&cutoff_str)
.execute(self.pool())
.await?;
Ok(result.rows_affected())
}
@ -275,7 +276,9 @@ impl super::Storage {
}
}
fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
fn parse_memory_rows(
rows: &[sqlx::sqlite::SqliteRow],
) -> Result<Vec<MemoryEntry>, StorageError> {
rows.iter()
.map(|row| {
Ok(MemoryEntry {

View File

@ -1,17 +1,15 @@
pub mod background_task;
pub mod error;
pub mod memory;
pub mod message;
pub mod scheduler;
pub mod session;
pub use background_task::BackgroundTask;
pub use error::StorageError;
pub use scheduler::{JobRun, ScheduledJob};
use sqlx::{Pool, Row, Sqlite, SqlitePool};
use tokio::time::{sleep, Duration};
use std::path::Path;
use tokio::time::{Duration, sleep};
pub struct Storage {
pub(crate) pool: Pool<Sqlite>,
@ -42,7 +40,6 @@ impl Storage {
last_active_at INTEGER NOT NULL,
message_count INTEGER DEFAULT 0,
routing_info TEXT,
archived_at INTEGER,
deleted_at INTEGER,
last_consolidated_at INTEGER,
last_compressed_message_at INTEGER,
@ -93,58 +90,20 @@ impl Storage {
.await?;
// Migration: add source column if upgrading from older schema
sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
.execute(&self.pool)
.await
.ok();
sqlx::query(
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
)
.execute(&self.pool)
.await
.ok();
// Migration: add reasoning_content column if upgrading from older schema
sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
.execute(&self.pool)
.await
.ok();
// Background tasks table — for async sub-agent tasks.
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
// Session and task association is maintained at the application level.
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS background_tasks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
channel TEXT NOT NULL,
chat_id TEXT NOT NULL,
prompt TEXT NOT NULL,
allowed_tools TEXT,
status TEXT NOT NULL DEFAULT 'pending',
result TEXT,
error TEXT,
tool_calls_count INTEGER DEFAULT 0,
iterations INTEGER DEFAULT 0,
started_at INTEGER,
finished_at INTEGER,
created_at INTEGER NOT NULL
)
"#,
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status)
"#,
)
.execute(&self.pool)
.await?;
.await
.ok();
sqlx::query(
r#"
@ -213,19 +172,11 @@ impl Storage {
.await?;
// Rebuild FTS5 index for any existing records
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
.execute(&self.pool)
.await?;
// Migration: add last_consolidated_at column if not exists
sqlx::query(
r#"
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
"#,
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
)
.execute(&self.pool)
.await
.ok();
.await?;
// Migration: add last_consolidated_at column if not exists
sqlx::query(
@ -265,10 +216,7 @@ impl Storage {
.await?;
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
tracing::warn!(
"Failed to init scheduler schema (tables may already exist): {}",
e
);
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
}
Ok(())
@ -382,20 +330,16 @@ impl Storage {
&self.pool
}
pub async fn upsert_session(
&self,
meta: &crate::storage::session::SessionMeta,
) -> Result<(), StorageError> {
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
title = excluded.title,
last_active_at = excluded.last_active_at,
message_count = excluded.message_count,
routing_info = excluded.routing_info,
archived_at = excluded.archived_at,
deleted_at = excluded.deleted_at,
last_consolidated_at = excluded.last_consolidated_at,
last_compressed_message_at = excluded.last_compressed_message_at
@ -410,7 +354,6 @@ impl Storage {
.bind(meta.last_active_at)
.bind(meta.message_count)
.bind(&meta.routing_info)
.bind(meta.archived_at)
.bind(meta.deleted_at)
.bind(meta.last_consolidated_at)
.bind(meta.last_compressed_message_at)
@ -420,13 +363,10 @@ impl Storage {
Ok(())
}
pub async fn get_session(
&self,
id: &str,
) -> Result<crate::storage::session::SessionMeta, StorageError> {
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions WHERE id = ? AND deleted_at IS NULL
"#,
)
@ -445,7 +385,6 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -457,21 +396,18 @@ impl Storage {
channel: &str,
chat_id: &str,
limit: i64,
include_archived: bool,
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
AND (? OR archived_at IS NULL)
ORDER BY last_active_at DESC
LIMIT ?
"#,
)
.bind(channel)
.bind(chat_id)
.bind(include_archived)
.bind(limit)
.fetch_all(self.pool())
.await?;
@ -488,7 +424,6 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -519,22 +454,13 @@ impl Storage {
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
sqlx::query(
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
)
.bind(now)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
@ -546,9 +472,9 @@ impl Storage {
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
let row = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
ORDER BY last_active_at DESC
LIMIT 1
"#,
@ -569,7 +495,6 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -578,11 +503,7 @@ impl Storage {
}
}
pub async fn append_message(
&self,
session_id: &str,
msg: &crate::storage::message::MessageMeta,
) -> Result<i64, StorageError> {
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
sqlx::query(
r#"
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
@ -709,15 +630,16 @@ impl Storage {
offset: i64,
limit: i64,
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
let count_row =
sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
.fetch_one(self.pool())
.await?;
let count_row = sqlx::query(
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL",
)
.fetch_one(self.pool())
.await?;
let total: i64 = count_row.get("total");
let rows = sqlx::query(
r#"
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
FROM sessions
WHERE deleted_at IS NULL
ORDER BY last_active_at DESC
@ -741,7 +663,6 @@ impl Storage {
last_active_at: row.get("last_active_at"),
message_count: row.get("message_count"),
routing_info: row.get("routing_info"),
archived_at: row.get("archived_at"),
deleted_at: row.get("deleted_at"),
last_consolidated_at: row.get("last_consolidated_at"),
last_compressed_message_at: row.get("last_compressed_message_at"),
@ -807,10 +728,7 @@ impl Storage {
where_extra.push_str(" AND created_at > ?");
}
let count_sql = format!(
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
where_extra
);
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
let select_sql = format!(
r#"
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
@ -898,148 +816,6 @@ impl Storage {
}
unreachable!()
}
// ── Background Task CRUD ──
pub async fn create_background_task(
&self,
task: &crate::storage::background_task::BackgroundTask,
) -> Result<(), StorageError> {
sqlx::query(
r#"
INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&task.id)
.bind(&task.session_id)
.bind(&task.channel)
.bind(&task.chat_id)
.bind(&task.prompt)
.bind(&task.allowed_tools)
.bind(&task.status)
.bind(&task.result)
.bind(&task.error)
.bind(task.tool_calls_count)
.bind(task.iterations)
.bind(task.started_at)
.bind(task.finished_at)
.bind(task.created_at)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn update_background_task_status(
&self,
id: &str,
status: &str,
result: Option<&str>,
error: Option<&str>,
started_at: Option<i64>,
finished_at: Option<i64>,
) -> Result<(), StorageError> {
sqlx::query(
r#"
UPDATE background_tasks
SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error),
started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at)
WHERE id = ?
"#,
)
.bind(status)
.bind(result)
.bind(error)
.bind(started_at)
.bind(finished_at)
.bind(id)
.execute(self.pool())
.await?;
Ok(())
}
pub async fn get_background_task(
&self,
id: &str,
) -> Result<crate::storage::background_task::BackgroundTask, StorageError> {
let row = sqlx::query(
r#"
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
tool_calls_count, iterations, started_at, finished_at, created_at
FROM background_tasks WHERE id = ?
"#,
)
.bind(id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
Ok(crate::storage::background_task::BackgroundTask {
id: row.get("id"),
session_id: row.get("session_id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
prompt: row.get("prompt"),
allowed_tools: row.get("allowed_tools"),
status: row.get("status"),
result: row.get("result"),
error: row.get("error"),
tool_calls_count: row.get("tool_calls_count"),
iterations: row.get("iterations"),
started_at: row.get("started_at"),
finished_at: row.get("finished_at"),
created_at: row.get("created_at"),
})
}
pub async fn list_background_tasks(
&self,
session_id: &str,
) -> Result<Vec<crate::storage::background_task::BackgroundTask>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
tool_calls_count, iterations, started_at, finished_at, created_at
FROM background_tasks
WHERE session_id = ?
ORDER BY created_at DESC
"#,
)
.bind(session_id)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|row| crate::storage::background_task::BackgroundTask {
id: row.get("id"),
session_id: row.get("session_id"),
channel: row.get("channel"),
chat_id: row.get("chat_id"),
prompt: row.get("prompt"),
allowed_tools: row.get("allowed_tools"),
status: row.get("status"),
result: row.get("result"),
error: row.get("error"),
tool_calls_count: row.get("tool_calls_count"),
iterations: row.get("iterations"),
started_at: row.get("started_at"),
finished_at: row.get("finished_at"),
created_at: row.get("created_at"),
})
.collect())
}
pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result<usize, StorageError> {
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms;
let result = sqlx::query(
"DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?",
)
.bind(cutoff)
.execute(self.pool())
.await?;
Ok(result.rows_affected() as usize)
}
}
#[cfg(test)]
@ -1068,7 +844,6 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -1105,18 +880,14 @@ mod tests {
last_active_at: i as i64 * 1000,
message_count: i,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
let sessions = storage
.list_sessions("cli_chat", "sid123", 10, false)
.await
.unwrap();
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
assert_eq!(sessions.len(), 5);
// 按 last_active_at DESC 排序
assert_eq!(sessions[0].dialog_id, "dialog4");
@ -1136,7 +907,6 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -1164,7 +934,6 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -1186,10 +955,7 @@ mod tests {
created_at: 1000,
};
let seq = storage
.append_message(&session_meta.id, &msg)
.await
.unwrap();
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
assert_eq!(seq, 1);
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
@ -1211,7 +977,6 @@ mod tests {
last_active_at: 1000,
message_count: 0,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,

View File

@ -165,11 +165,7 @@ impl crate::storage::Storage {
}
/// Update next_run_at and last_run_at for a job.
pub async fn set_scheduled_job_next_run(
&self,
id: &str,
next_run_at: i64,
) -> anyhow::Result<()> {
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
let now = now_ms();
sqlx::query(
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
@ -335,9 +331,7 @@ mod tests {
async fn setup_storage() -> Storage {
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
let storage = Storage { pool };
Storage::init_scheduler_schema(storage.pool())
.await
.unwrap();
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
storage
}
@ -456,10 +450,7 @@ mod tests {
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage
.set_scheduled_job_enabled("job-toggle", false)
.await
.unwrap();
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
assert!(!got.enabled);
}
@ -470,55 +461,31 @@ mod tests {
let t = now();
let jobs = vec![
ScheduledJob {
id: "due".into(),
name: "due".into(),
schedule: Schedule::At { at: t },
prompt: "1".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
id: "due".into(), name: "due".into(),
schedule: Schedule::At { at: t }, prompt: "1".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t - 1000, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
},
ScheduledJob {
id: "future".into(),
name: "future".into(),
schedule: Schedule::At { at: t + 99999999 },
prompt: "2".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 99999999,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
id: "future".into(), name: "future".into(),
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t + 99999999, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
},
ScheduledJob {
id: "disabled-due".into(),
name: "disabled due".into(),
schedule: Schedule::At { at: t },
prompt: "3".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: false,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
id: "disabled-due".into(), name: "disabled due".into(),
schedule: Schedule::At { at: t }, prompt: "3".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: false, delete_after_run: false,
next_run_at: t - 1000, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
},
];
for j in &jobs {
@ -534,39 +501,24 @@ mod tests {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-run".into(),
name: "run test".into(),
id: "job-run".into(), name: "run test".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let run = super::JobRun {
id: 0,
job_id: "job-run".into(),
started_at: t,
finished_at: t + 500,
status: "ok".into(),
output: Some("hello".into()),
error: None,
duration_ms: 500,
id: 0, job_id: "job-run".into(),
started_at: t, finished_at: t + 500,
status: "ok".into(), output: Some("hello".into()),
error: None, duration_ms: 500,
};
storage.record_scheduled_job_run(&run).await.unwrap();
let runs = storage
.list_scheduled_job_runs("job-run", 10)
.await
.unwrap();
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
assert_eq!(runs.len(), 1);
assert_eq!(runs[0].status, "ok");
assert_eq!(runs[0].output.as_deref(), Some("hello"));
@ -577,34 +529,22 @@ mod tests {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-update".into(),
name: "old name".into(),
id: "job-update".into(), name: "old name".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "old prompt".into(),
channel: "feishu".into(),
chat_id: "oc_1".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
prompt: "old prompt".into(), channel: "feishu".into(),
chat_id: "oc_1".into(), model: None,
enabled: true, delete_after_run: false,
next_run_at: t, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage
.update_scheduled_job(
"job-update",
Some("new prompt".into()),
Some(Schedule::Every { every_ms: 60000 }),
None,
None,
None,
)
.await
.unwrap();
storage.update_scheduled_job(
"job-update",
Some("new prompt".into()),
Some(Schedule::Every { every_ms: 60000 }),
None, None, None,
).await.unwrap();
let got = storage.get_scheduled_job("job-update").await.unwrap();
assert_eq!(got.prompt, "new prompt");
}

View File

@ -11,7 +11,6 @@ pub struct SessionMeta {
pub last_active_at: i64,
pub message_count: i64,
pub routing_info: Option<String>,
pub archived_at: Option<i64>,
pub deleted_at: Option<i64>,
pub last_consolidated_at: Option<i64>,
pub last_compressed_message_at: Option<i64>,

View File

@ -167,7 +167,10 @@ impl Tool for BashTool {
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
error: Some(format!(
"Command timed out after {} seconds",
timeout_secs
)),
}),
}
}
@ -246,7 +249,10 @@ mod tests {
#[tokio::test]
async fn test_pwd_command() {
let tool = BashTool::new();
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
let result = tool
.execute(json!({ "command": "pwd" }))
.await
.unwrap();
assert!(result.success);
}
@ -254,10 +260,7 @@ mod tests {
#[tokio::test]
async fn test_ls_command() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": "ls -la /tmp" }))
.await
.unwrap();
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
assert!(result.success);
}

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::Context;
use async_trait::async_trait;
use base64::Engine;
use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT};
use fantoccini::key::Key;
use fantoccini::{Client, ClientBuilder, Locator};
use serde::{Deserialize, Serialize};
@ -63,9 +63,7 @@ impl BrowserTool {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BrowserAction {
Open {
url: String,
},
Open { url: String },
Snapshot {
#[serde(default)]
interactive_only: bool,
@ -74,20 +72,10 @@ pub enum BrowserAction {
#[serde(default)]
depth: Option<i64>,
},
Click {
selector: String,
},
Fill {
selector: String,
value: String,
},
Type {
selector: Option<String>,
text: String,
},
GetText {
selector: String,
},
Click { selector: String },
Fill { selector: String, value: String },
Type { selector: Option<String>, text: String },
GetText { selector: String },
GetTitle,
GetUrl,
Screenshot {
@ -96,9 +84,7 @@ pub enum BrowserAction {
#[serde(default)]
return_base64: bool,
},
Focus {
selector: String,
},
Focus { selector: String },
Wait {
#[serde(default)]
selector: Option<String>,
@ -107,16 +93,9 @@ pub enum BrowserAction {
#[serde(default)]
text: Option<String>,
},
Press {
key: String,
},
Hover {
selector: String,
},
ClickAt {
x: u32,
y: u32,
},
Press { key: String },
Hover { selector: String },
ClickAt { x: u32, y: u32 },
Scroll {
direction: String,
#[serde(default)]
@ -141,8 +120,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.get("interactive_only")
.and_then(Value::as_bool)
.unwrap_or(true),
compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
depth: args.get("depth").and_then(|v| v.as_i64()),
compact: args
.get("compact")
.and_then(Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(|v| v.as_i64()),
}),
"click" => {
let selector = args
@ -214,7 +198,10 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.and_then(|v| v.as_str())
.map(String::from),
ms: args.get("ms").and_then(|v| v.as_u64()),
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
text: args
.get("text")
.and_then(|v| v.as_str())
.map(String::from),
}),
"press" => {
let key = args
@ -252,13 +239,11 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
let x = args
.get("x")
.and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
as u32;
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32;
let y = args
.get("y")
.and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
as u32;
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32;
Ok(BrowserAction::ClickAt { x, y })
}
other => anyhow::bail!("Unsupported browser action: {}", other),
@ -503,11 +488,7 @@ impl BrowserState {
}
Err(e) => return Err(e.into()),
}
tracing::debug!(
action = "fill",
output_len = value.len(),
"Browser action completed"
);
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed");
Ok(ToolResult {
success: true,
output: format!("Filled {} with {}", selector, value),
@ -592,10 +573,7 @@ impl BrowserState {
error: None,
})
}
BrowserAction::Screenshot {
path,
return_base64,
} => {
BrowserAction::Screenshot { path, return_base64 } => {
let client = self.active_client()?;
let png = client.screenshot().await?;
let save_path = path.unwrap_or_else(|| {
@ -610,25 +588,14 @@ impl BrowserState {
tokio::fs::write(&save_path, &png).await?;
if return_base64 {
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
tracing::debug!(
action = "screenshot",
output_len = b64.len(),
"Browser action completed"
);
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed");
return Ok(ToolResult {
success: true,
output: format!(
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
save_path, b64
),
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64),
error: None,
});
}
tracing::debug!(
action = "screenshot",
output_len = save_path.len(),
"Browser action completed"
);
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed");
Ok(ToolResult {
success: true,
output: format!("Screenshot saved to {}", save_path),
@ -644,18 +611,18 @@ impl BrowserState {
vec![serde_json::to_value(el)?],
)
.await?;
tracing::debug!(
action = "focus",
output_len = selector.len(),
"Browser action completed"
);
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed");
Ok(ToolResult {
success: true,
output: format!("Focused {}", selector),
error: None,
})
}
BrowserAction::Wait { selector, ms, text } => {
BrowserAction::Wait {
selector,
ms,
text,
} => {
if let Some(sel) = selector {
let client = self.active_client()?;
wait_for_selector(client, &sel).await?;
@ -752,21 +719,9 @@ impl BrowserState {
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
let id_str = if id.is_empty() {
String::new()
} else {
format!("#{id}")
};
let type_str = if el_type.is_empty() {
String::new()
} else {
format!("[type={el_type}]")
};
let text_str = if text.is_empty() {
String::new()
} else {
format!(" ({text})")
};
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") };
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") };
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") };
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
}
None => format!("Clicked at ({}, {})", x, y),
@ -1135,7 +1090,10 @@ fn css_attr_escape(input: &str) -> String {
}
fn xpath_contains_text(text: &str) -> String {
format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
format!(
"//*[contains(normalize-space(.), {})]",
xpath_literal(text)
)
}
fn xpath_literal(input: &str) -> String {
@ -1182,10 +1140,7 @@ fn webdriver_key(key: &str) -> String {
"pagedown" => Key::PageDown.to_string(),
"space" => " ".to_string(),
other => {
tracing::warn!(
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
other
);
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other);
other.to_string()
}
}

View File

@ -659,7 +659,10 @@ mod tests {
#[tokio::test]
async fn test_evaluate_missing_expression() {
let tool = CalculatorTool::new();
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
let result = tool
.execute(json!({"function": "evaluate"}))
.await
.unwrap();
assert!(!result.success);
}

View File

@ -126,10 +126,7 @@ impl ChatManagerTool {
let start_num = offset + 1;
let end_num = offset + sessions.len() as i64;
let mut output = format!(
"全部会话 (共 {} 个,第 {}-{} 个):\n",
total, start_num, end_num
);
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num);
for s in &sessions {
let ago = format_duration_ago(now_ms - s.last_active_at);
@ -303,10 +300,9 @@ mod tests {
last_active_at: now - i * 3600_000,
message_count: i * 5,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
};
storage.upsert_session(&meta).await.unwrap();
}
@ -339,7 +335,6 @@ mod tests {
last_active_at: now,
message_count: 3,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -351,11 +346,7 @@ mod tests {
id: format!("msg{}", i),
session_id: session_id.to_string(),
seq: i as i64 + 1,
role: if i == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
content: format!("消息内容 {}", i),
reasoning_content: None,
media_refs: None,
@ -401,7 +392,6 @@ mod tests {
last_active_at: now,
message_count: 5,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -413,11 +403,7 @@ mod tests {
id: format!("msg{}", i),
session_id: session_id.to_string(),
seq: i as i64 + 1,
role: if i % 2 == 0 {
"user".to_string()
} else {
"assistant".to_string()
},
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
content: format!("消息内容 {}", i),
reasoning_content: None,
media_refs: None,
@ -461,7 +447,6 @@ mod tests {
last_active_at: now,
message_count: 5,
routing_info: None,
archived_at: None,
deleted_at: None,
last_consolidated_at: None,
last_compressed_message_at: None,
@ -507,7 +492,10 @@ mod tests {
let (storage, _dir) = create_test_storage().await;
let tool = ChatManagerTool::new(storage, vec![]);
let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
let result = tool
.execute(json!({ "action": "unknown" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Unknown action"));
}

View File

@ -31,7 +31,10 @@ impl ContentSearchTool {
for (i, line) in lines.iter().enumerate() {
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
let omitted = lines.len() - i;
output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
output.push_str(&format!(
"\n... ({} matches omitted) ...",
omitted
));
break;
}
if !output.is_empty() {
@ -110,40 +113,18 @@ impl Tool for ContentSearchTool {
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
let case_sensitive = args
.get("case_sensitive")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let context_lines = args
.get("context_lines")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
let result = self
.run_search(
pattern,
&dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await;
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
match result {
Ok(lines) => {
let count = lines.len();
let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 条匹配", count));
Ok(ToolResult {
success: true,
output,
error: None,
})
Ok(ToolResult { success: true, output, error: None })
}
Err(e) => Ok(ToolResult {
success: false,
@ -165,52 +146,22 @@ impl ContentSearchTool {
max_results: usize,
) -> anyhow::Result<Vec<String>> {
if which::which("rg").is_ok() {
match self
.search_with_rg(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
Ok(lines) => return Ok(lines),
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
}
}
if which::which("grep").is_ok() {
match self
.search_with_grep(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}
Ok(_) => {},
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
}
}
tracing::warn!(
"No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
);
self.search_with_rust(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
}
async fn search_with_rg(
@ -225,10 +176,8 @@ impl ContentSearchTool {
let mut cmd = Command::new("rg");
cmd.arg("-n")
.arg("--no-heading")
.arg("--color")
.arg("never")
.arg("--max-count")
.arg(max_results.to_string())
.arg("--color").arg("never")
.arg("--max-count").arg(max_results.to_string())
.arg(pattern)
.arg(dir)
.stdout(Stdio::piped())
@ -244,9 +193,12 @@ impl ContentSearchTool {
cmd.arg("--glob").arg(fp);
}
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() && output.status.code() != Some(1) {
let stderr = String::from_utf8_lossy(&output.stderr);
@ -254,8 +206,7 @@ impl ContentSearchTool {
}
let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text
.lines()
let lines: Vec<String> = text.lines()
.take(max_results)
.map(|l| l.to_string())
.collect();
@ -291,13 +242,15 @@ impl ContentSearchTool {
cmd.arg("--include").arg(fp);
}
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text
.lines()
let lines: Vec<String> = text.lines()
.take(max_results)
.map(|l| l.to_string())
.collect();
@ -327,9 +280,7 @@ impl ContentSearchTool {
if case_sensitive {
regex::Regex::new(&re_str)
} else {
regex::RegexBuilder::new(&re_str)
.case_insensitive(true)
.build()
regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
}
});
@ -340,14 +291,7 @@ impl ContentSearchTool {
};
let mut results = Vec::new();
grep_dir(
Path::new(dir),
Path::new(dir),
&re,
file_re.as_ref(),
&mut results,
max_results,
)?;
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
Ok(results)
}
@ -406,19 +350,16 @@ fn grep_dir(
if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.')
&& name.len() > 1
{
continue;
}
&& name.starts_with('.') && name.len() > 1 {
continue;
}
grep_dir(base, &path, re, file_re, results, max)?;
} else if path.is_file() {
if let Some(file_re) = file_re
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& !file_re.is_match(name)
{
continue;
}
&& !file_re.is_match(name) {
continue;
}
if let Ok(content) = std::fs::read_to_string(&path) {
for (line_num, line) in content.lines().enumerate() {
@ -450,16 +391,8 @@ mod tests {
#[tokio::test]
async fn test_content_search_rust_fallback() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("main.rs"),
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
)
.unwrap();
fs::write(
dir.path().join("lib.rs"),
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
)
.unwrap();
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
let tool = ContentSearchTool::new();

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{Value, json};
use serde_json::{json, Value};
use uuid::Uuid;
use crate::scheduler::{Schedule, next_run_for_schedule};
use crate::scheduler::{next_run_for_schedule, Schedule};
use crate::storage::{ScheduledJob, Storage};
use crate::tools::traits::{Tool, ToolResult};
@ -229,7 +229,10 @@ impl Tool for CronListTool {
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
let filter = args
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("all");
let jobs = self.storage.list_scheduled_jobs().await?;
let filtered: Vec<&ScheduledJob> = match filter {
@ -394,9 +397,7 @@ impl Tool for CronEnableTool {
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
let next = next_run_for_schedule(&job.schedule, now_ms());
self.storage
.set_scheduled_job_enabled(&job_id, true)
.await?;
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
if let Some(n) = next {
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
}
@ -463,9 +464,7 @@ impl Tool for CronDisableTool {
.get_scheduled_job(&job_id)
.await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
self.storage
.set_scheduled_job_enabled(&job_id, false)
.await?;
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
Ok(ToolResult {
success: true,
@ -581,9 +580,7 @@ impl Tool for CronUpdateTool {
if args.get("schedule").is_some() {
let job = self.storage.get_scheduled_job(&job_id).await?;
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
self.storage
.set_scheduled_job_next_run(&job_id, next)
.await?;
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
}
}
@ -768,7 +765,9 @@ mod tests {
let job = ScheduledJob {
id: "job-update-tool".into(),
name: "old".into(),
schedule: Schedule::Every { every_ms: 3600000 },
schedule: Schedule::Every {
every_ms: 3600000,
},
prompt: "old prompt".into(),
channel: "feishu".into(),
chat_id: "oc_1".into(),

View File

@ -1,390 +0,0 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus};
use crate::tools::traits::{Tool, ToolResult};
pub struct DelegateTool {
sub_agent_manager: Arc<SubAgentManager>,
}
impl DelegateTool {
pub fn new(sub_agent_manager: Arc<SubAgentManager>) -> Self {
Self { sub_agent_manager }
}
}
#[async_trait]
impl Tool for DelegateTool {
fn name(&self) -> &str {
"delegate"
}
fn description(&self) -> &str {
"子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\
inline ()background ()\
parallel ( Agent )\
"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["run", "check_task", "cancel_task", "list_tasks"],
"description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务"
},
"prompt": {
"type": "string",
"description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件。action=run 时必填"
},
"mode": {
"type": "string",
"enum": ["inline", "background", "parallel"],
"description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline"
},
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator"
},
"max_iterations": {
"type": "integer",
"description": "最大迭代次数,默认 99"
},
"timeout_secs": {
"type": "integer",
"description": "超时秒数,默认 36001小时"
},
"tasks": {
"type": "array",
"description": "并行模式下的多个子任务(仅 mode=parallel 时使用)",
"items": {
"type": "object",
"properties": {
"prompt": { "type": "string", "description": "子任务描述" },
"allowed_tools": {
"type": "array",
"items": { "type": "string" },
"description": "该子任务的工具列表"
}
},
"required": ["prompt"]
}
},
"task_id": {
"type": "string",
"description": "后台任务IDaction=check_task/cancel_task 时必填)"
}
},
"required": ["action"]
})
}
fn read_only(&self) -> bool {
false
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = args["action"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
match action {
"run" => self.handle_run(&args).await,
"check_task" => self.handle_check_task(&args).await,
"cancel_task" => self.handle_cancel_task(&args).await,
"list_tasks" => self.handle_list_tasks(&args).await,
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
action
)),
}),
}
}
}
impl DelegateTool {
fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result<SubAgentConfig> {
let prompt = args["prompt"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
.to_string();
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
let timeout_secs = args["timeout_secs"].as_u64();
Ok(SubAgentConfig {
prompt,
mode: ExecutionMode::Inline,
allowed_tools,
max_iterations,
timeout_secs,
})
}
async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let mode_str = args["mode"].as_str().unwrap_or("inline");
let mode = match mode_str {
"inline" => ExecutionMode::Inline,
"background" => ExecutionMode::Background,
"parallel" => ExecutionMode::Parallel,
_ => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"unknown mode: {}. Supported: inline, background, parallel",
mode_str
)),
});
}
};
match mode {
ExecutionMode::Inline => {
let config = self.parse_config_from_args(args)?;
let result = self
.sub_agent_manager
.run_inline(config)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
match result.status {
TaskStatus::Completed => Ok(ToolResult {
success: true,
output: result.content,
error: None,
}),
TaskStatus::Failed(err) => Ok(ToolResult {
success: false,
output: result.content,
error: Some(err),
}),
TaskStatus::TimedOut => Ok(ToolResult {
success: false,
output: result.content,
error: Some("sub-agent timed out".into()),
}),
TaskStatus::Cancelled => Ok(ToolResult {
success: false,
output: result.content,
error: Some("sub-agent cancelled".into()),
}),
}
}
ExecutionMode::Background => {
let config = self.parse_config_from_args(args)?;
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
anyhow::anyhow!("delegate context not available: not in an agent worker")
})?;
let task_id = self
.sub_agent_manager
.run_background(config, ctx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
Ok(ToolResult {
success: true,
output: format!("后台任务已启动。\ntask_id: {}", task_id),
error: None,
})
}
ExecutionMode::Parallel => {
let tasks = args["tasks"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?;
let mut configs = Vec::new();
for task in tasks {
let prompt = task["prompt"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
.to_string();
let allowed_tools: Option<Vec<String>> =
task["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
configs.push(SubAgentConfig {
prompt,
mode: ExecutionMode::Inline,
allowed_tools,
max_iterations: args["max_iterations"].as_u64().map(|v| v as usize),
timeout_secs: args["timeout_secs"].as_u64(),
});
}
let has_args_allowed = args["allowed_tools"].as_array().is_some();
for c in &mut configs {
if c.allowed_tools.is_none() && has_args_allowed {
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
}
}
let results = self
.sub_agent_manager
.run_parallel(configs)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
let mut output = String::new();
for (i, r) in results.iter().enumerate() {
let status_icon = match r.status {
TaskStatus::Completed => "",
TaskStatus::Failed(_) => "",
TaskStatus::TimedOut => "⏱️ 超时",
TaskStatus::Cancelled => "🚫 已取消",
};
output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon));
if !r.content.is_empty() {
output.push_str(&r.content);
output.push_str("\n\n");
}
if let TaskStatus::Failed(ref err) = r.status {
output.push_str(&format!("错误: {}\n\n", err));
}
}
let all_success = results
.iter()
.all(|r| matches!(r.status, TaskStatus::Completed));
Ok(ToolResult {
success: all_success,
output: output.trim().to_string(),
error: None,
})
}
}
}
async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let task_id = args["task_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
match self.sub_agent_manager.check_task(task_id).await {
Some(task) => {
let status_icon = match task.status.as_str() {
"completed" => "✅ 已完成",
"failed" => "❌ 失败",
"cancelled" => "🚫 已取消",
"running" => "🔄 运行中",
"pending" => "⏳ 等待中",
_ => task.status.as_str(),
};
let mut output = format!(
"任务 ID: {}\n状态: {}\n任务: {}",
task.id, status_icon, task.prompt
);
if let Some(ref result) = task.result {
output.push_str(&format!("\n\n结果:\n{}", result));
}
if let Some(ref error) = task.error {
output.push_str(&format!("\n错误: {}", error));
}
if let Some(started) = task.started_at {
if let Some(finished) = task.finished_at {
let duration = (finished - started) as f64 / 1000.0;
output.push_str(&format!("\n耗时: {:.1}s", duration));
}
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
None => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("task not found: {}", task_id)),
}),
}
}
async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let task_id = args["task_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
match self.sub_agent_manager.cancel_task(task_id).await {
Ok(true) => Ok(ToolResult {
success: true,
output: format!("后台任务 {} 已取消", task_id),
error: None,
}),
Ok(false) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)),
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("取消失败: {}", e)),
}),
}
}
async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result<ToolResult> {
let ctx = crate::agent::sub_agent::get_delegate_context()
.map_err(|_| anyhow::anyhow!("delegate context not available"))?;
let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await;
if tasks.is_empty() {
return Ok(ToolResult {
success: true,
output: "没有后台任务".to_string(),
error: None,
});
}
let mut output = String::from("后台任务列表:\n\n");
for task in &tasks {
let status_icon = match task.status.as_str() {
"completed" => "",
"failed" => "",
"cancelled" => "🚫",
"running" => "🔄",
"pending" => "",
_ => "",
};
output.push_str(&format!(
"{} {} - {} - {} (created: {})\n",
status_icon,
&task.id[..std::cmp::min(8, task.id.len())],
task.prompt.chars().take(60).collect::<String>(),
task.status,
task.created_at,
));
}
output.push_str(&format!("\n{} 个任务", tasks.len()));
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}

View File

@ -243,8 +243,8 @@ impl Tool for FileEditTool {
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
use std::io::Write;
#[tokio::test]
async fn test_edit_simple() {

View File

@ -181,7 +181,10 @@ impl Tool for FileReadTool {
}
result = lines[..end_idx].join("\n");
let truncated = original_len - result.len();
result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
result.push_str(&format!(
"\n\n... ({} chars truncated) ...",
truncated
));
}
if end < total {
@ -193,7 +196,10 @@ impl Tool for FileReadTool {
end + 1
));
} else {
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
result.push_str(&format!(
"\n\n(End of file — {} lines total)",
total
));
}
if let Some(label) = encoding_label {
@ -208,7 +214,7 @@ impl Tool for FileReadTool {
}
None => {
// Truly binary file — base64 encode
use base64::{Engine, engine::general_purpose::STANDARD};
use base64::{engine::general_purpose::STANDARD, Engine};
let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
@ -272,8 +278,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
use std::io::Write;
#[tokio::test]
async fn test_read_simple_file() {
@ -332,7 +338,10 @@ mod tests {
#[tokio::test]
async fn test_is_directory() {
let tool = FileReadTool::new();
let result = tool.execute(json!({ "path": "." })).await.unwrap();
let result = tool
.execute(json!({ "path": "." }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file"));

View File

@ -101,29 +101,17 @@ impl Tool for FileSearchTool {
};
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let case_sensitive = args
.get("case_sensitive")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
let result = self
.run_search(pattern, &dir, case_sensitive, max_results)
.await;
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
match result {
Ok(lines) => {
let count = lines.len();
let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 个文件", count));
Ok(ToolResult {
success: true,
output,
error: None,
})
Ok(ToolResult { success: true, output, error: None })
}
Err(e) => Ok(ToolResult {
success: false,
@ -151,12 +139,9 @@ impl FileSearchTool {
};
if !fd_cmd.is_empty() {
match self
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
.await
{
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await {
Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}
Ok(_) => {},
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
}
}
@ -164,14 +149,13 @@ impl FileSearchTool {
if which::which("find").is_ok() {
match self.search_with_find(pattern, dir, max_results).await {
Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}
Ok(_) => {},
Err(e) => tracing::warn!("find failed: {}, falling back", e),
}
}
tracing::warn!("No fd/find available, using built-in file search (slower)");
self.search_with_rust(pattern, dir, case_sensitive, max_results)
.await
self.search_with_rust(pattern, dir, case_sensitive, max_results).await
}
async fn search_with_fd(
@ -183,15 +167,11 @@ impl FileSearchTool {
fd_cmd: &str,
) -> anyhow::Result<Vec<String>> {
let mut cmd = Command::new(fd_cmd);
cmd.arg("--search-path")
.arg(dir)
.arg("--glob")
.arg(pattern)
.arg("--color")
.arg("never")
cmd.arg("--search-path").arg(dir)
.arg("--glob").arg(pattern)
.arg("--color").arg("never")
.arg("--strip-cwd-prefix")
.arg("--max-results")
.arg(max_results.to_string())
.arg("--max-results").arg(max_results.to_string())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
@ -199,9 +179,12 @@ impl FileSearchTool {
cmd.arg("--ignore-case");
}
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
@ -209,8 +192,7 @@ impl FileSearchTool {
}
let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text
.lines()
let lines: Vec<String> = text.lines()
.filter(|l| !l.is_empty())
.map(|l| l.to_string())
.collect();
@ -233,13 +215,15 @@ impl FileSearchTool {
.stdout(Stdio::piped())
.stderr(Stdio::null());
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout);
let mut lines: Vec<String> = text
.lines()
let mut lines: Vec<String> = text.lines()
.filter(|l| !l.is_empty())
.map(|l| {
let p = Path::new(l);
@ -270,13 +254,7 @@ impl FileSearchTool {
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
let mut results = Vec::new();
walk_dir(
Path::new(dir),
Path::new(dir),
&re,
&mut results,
max_results,
)?;
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?;
Ok(results)
}
}
@ -333,18 +311,15 @@ fn walk_dir(
if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.')
&& name.len() > 1
{
continue;
}
&& name.starts_with('.') && name.len() > 1 {
continue;
}
walk_dir(base, &path, re, results, max)?;
} else if path.is_file() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& re.is_match(name)
{
results.push(rel.to_string_lossy().to_string());
}
&& re.is_match(name) {
results.push(rel.to_string_lossy().to_string());
}
if results.len() >= max {
return Ok(());
}

View File

@ -90,14 +90,13 @@ impl Tool for FileWriteTool {
// Create parent directories if needed
if let Some(parent) = resolved.parent()
&& !parent.exists()
&& let Err(e) = std::fs::create_dir_all(parent)
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create parent directory: {}", e)),
});
}
&& let Err(e) = std::fs::create_dir_all(parent) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create parent directory: {}", e)),
});
}
match std::fs::write(&resolved, content) {
Ok(_) => Ok(ToolResult {
@ -169,7 +168,10 @@ mod tests {
#[tokio::test]
async fn test_write_missing_path() {
let tool = FileWriteTool::new();
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
let result = tool
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("path"));

View File

@ -129,9 +129,7 @@ impl GetSkillTool {
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
for s in &skills {
let always_mark = if s.always { " [常驻]" } else { "" };
let path_str = s
.path
.as_ref()
let path_str = s.path.as_ref()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string());
output.push_str(&format!(
@ -150,10 +148,10 @@ impl GetSkillTool {
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use tempfile::tempdir;
#[tokio::test]
async fn test_get_existing_skill() {

View File

@ -50,7 +50,10 @@ impl HttpRequestTool {
}
if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!("Host '{}' is not in allowed_domains", host));
return Err(format!(
"Host '{}' is not in allowed_domains",
host
));
}
Ok(url.to_string())
@ -77,10 +80,11 @@ impl HttpRequestTool {
for (key, value) in obj {
if let Some(str_val) = value.as_str()
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
&& let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
{
header_map.insert(name, val);
}
&& let Ok(val) =
reqwest::header::HeaderValue::from_str(str_val)
{
header_map.insert(name, val);
}
}
}
@ -187,9 +191,7 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
allowed_domains.iter().any(|domain| {
host == domain
|| host
.strip_suffix(domain)
.is_some_and(|prefix| prefix.ends_with('.'))
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
})
}
@ -200,11 +202,7 @@ fn is_private_host(host: &str) -> bool {
}
// Check .local TLD
if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
if host.rsplit('.').next().is_some_and(|label| label == "local") {
return true;
}
@ -226,7 +224,9 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
}
}
@ -278,7 +278,10 @@ impl Tool for HttpRequestTool {
}
};
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
let method_str = args
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str());

View File

@ -151,19 +151,10 @@ impl Tool for MemoryRecallTool {
.and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory
.recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Knowledge),
None,
)
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
.await?
} else {
self.memory
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
.await?
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
};
if entries.is_empty() {
@ -177,11 +168,7 @@ impl Tool for MemoryRecallTool {
let formatted = entries
.iter()
.map(|e| {
let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
format!(
"- {} [{}]{} [importance: {:.1}]: {}",
e.key,
@ -277,19 +264,10 @@ impl Tool for TimelineRecallTool {
.and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory
.recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Timeline),
session_id,
)
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
.await?
} else {
self.memory
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
.await?
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
};
if entries.is_empty() {
@ -303,11 +281,7 @@ impl Tool for TimelineRecallTool {
let formatted = entries
.iter()
.map(|e| {
let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
format!(
"- {} [{}]{} [importance: {:.1}]: {}",
e.key,

View File

@ -4,7 +4,6 @@ pub mod calculator;
pub mod chat_manager;
pub mod content_search;
pub mod cron;
pub mod delegate;
pub mod file_edit;
pub mod file_read;
pub mod file_search;
@ -24,7 +23,6 @@ pub use browser::BrowserTool;
pub use calculator::CalculatorTool;
pub use chat_manager::ChatManagerTool;
pub use content_search::ContentSearchTool;
pub use delegate::DelegateTool;
pub use file_edit::FileEditTool;
pub use file_read::FileReadTool;
pub use file_search::FileSearchTool;
@ -37,11 +35,10 @@ pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult};
pub use web_fetch::WebFetchTool;
use crate::agent::SubAgentManager;
use std::sync::Arc;
use crate::config::BrowserConfig;
use crate::memory::MemoryManager;
use crate::skills::SkillsLoader;
use std::sync::Arc;
/// Create the base tool registry (without send_message).
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
@ -49,7 +46,6 @@ use std::sync::Arc;
pub fn create_default_tools(
skills_loader: Arc<SkillsLoader>,
memory: Arc<MemoryManager>,
sub_agent_manager: Option<Arc<SubAgentManager>>,
browser_config: Option<&BrowserConfig>,
) -> ToolRegistry {
let registry = ToolRegistry::new();
@ -80,9 +76,5 @@ pub fn create_default_tools(
}
}
if let Some(mgr) = sub_agent_manager {
registry.register(DelegateTool::new(mgr));
}
registry
}

View File

@ -1,608 +0,0 @@
use std::collections::{HashMap, VecDeque};
use std::io::Write;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use async_trait::async_trait;
use serde_json::json;
use crate::tools::traits::{Tool, ToolResult};
const MAX_OUTPUT_LINES: usize = 50_000;
const MAX_CHARS_PER_LINE: usize = 2_000;
const MAX_SESSIONS: usize = 10;
fn guard_command(command: &str) -> Option<String> {
let deny_patterns: &[&str] = &[
r"\brm\s+-[rf]{1,2}\b",
r"\bdel\s+/[fq]\b",
r"\brmdir\s+/s\b",
r":\(\)\s*\{.*\};\s*:",
];
let lower = command.to_lowercase();
for pattern in deny_patterns {
if regex::Regex::new(pattern)
.ok()
.map(|re| re.is_match(&lower))
.unwrap_or(false)
{
return Some(format!(
"Command blocked by safety guard (dangerous pattern: {})",
pattern
));
}
}
None
}
fn unescape_control_chars(s: &str) -> String {
let mut result = String::with_capacity(s.len());
let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
if c == '\\' {
match chars.next() {
Some('n') => result.push('\n'),
Some('r') => result.push('\r'),
Some('t') => result.push('\t'),
Some('x') => {
let hex: String = chars.by_ref().take(2).collect();
if hex.len() == 2 {
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
result.push(byte as char);
} else {
result.push_str(&format!("\\x{}", hex));
}
} else {
result.push_str(&format!("\\x{}", hex));
}
}
Some('\\') => result.push('\\'),
Some(other) => {
result.push('\\');
result.push(other);
}
None => result.push('\\'),
}
} else {
result.push(c);
}
}
result
}
fn truncate_command(cmd: &str, max_len: usize) -> String {
let cmd = cmd.trim();
let first_arg = cmd.split_whitespace().next().unwrap_or(cmd);
if first_arg.len() > max_len {
format!("{}...", &first_arg[..max_len])
} else {
first_arg.to_string()
}
}
fn status_str(status: &SessionStatus) -> &str {
match status {
SessionStatus::Running => "running",
SessionStatus::Exited(_) => "exited",
SessionStatus::Killed => "killed",
}
}
#[derive(Debug, Clone, PartialEq)]
enum SessionStatus {
Running,
Exited(i32),
Killed,
}
struct PtySession {
#[allow(dead_code)]
id: String,
#[allow(dead_code)]
command: String,
#[allow(dead_code)]
started_at: Instant,
status: SessionStatus,
child: Arc<Mutex<Option<Box<dyn portable_pty::Child + Send + Sync>>>>,
writer: Arc<Mutex<Option<Box<dyn Write + Send>>>>,
output_buffer: VecDeque<String>,
output_total_lines: usize,
}
impl PtySession {
fn new(
id: String,
command: String,
child: Box<dyn portable_pty::Child + Send + Sync>,
writer: Box<dyn Write + Send>,
) -> Self {
Self {
id,
command,
started_at: Instant::now(),
status: SessionStatus::Running,
child: Arc::new(Mutex::new(Some(child))),
writer: Arc::new(Mutex::new(Some(writer))),
output_buffer: VecDeque::new(),
output_total_lines: 0,
}
}
fn push_line(&mut self, line: String) {
let line = if line.len() > MAX_CHARS_PER_LINE {
format!("{}...<truncated>", &line[..MAX_CHARS_PER_LINE])
} else {
line
};
self.output_total_lines += 1;
if self.output_buffer.len() >= MAX_OUTPUT_LINES {
self.output_buffer.pop_front();
}
self.output_buffer.push_back(line);
}
}
pub struct PtyManager {
sessions: Mutex<HashMap<String, Arc<Mutex<PtySession>>>>,
}
impl PtyManager {
pub fn new() -> Self {
Self {
sessions: Mutex::new(HashMap::new()),
}
}
pub fn cleanup_all(&self) {
let sessions: Vec<Arc<Mutex<PtySession>>> = {
let mut guard = self.sessions.lock().unwrap();
guard.drain().map(|(_, s)| s).collect()
};
for session in sessions {
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
}
*child_guard = None;
guard.status = SessionStatus::Killed;
}
}
fn spawn(&self, command: &str) -> Result<String, String> {
if let Some(reason) = guard_command(command) {
return Err(reason);
}
let mut sessions = self.sessions.lock().unwrap();
if sessions.len() >= MAX_SESSIONS {
return Err(format!(
"Max sessions ({}) reached, kill some sessions first",
MAX_SESSIONS
));
}
let session_id = format!("pty_{}", crate::util::short_id());
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
let pty_system = portable_pty::native_pty_system();
let pty_pair = pty_system
.openpty(portable_pty::PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|e| format!("Failed to open PTY: {}", e))?;
let mut cmd = portable_pty::CommandBuilder::new("bash");
cmd.args(&["-c", command]);
cmd.cwd(cwd);
let child = pty_pair
.slave
.spawn_command(cmd)
.map_err(|e| format!("Failed to spawn: {}", e))?;
let pid = child.process_id().unwrap_or(0);
let writer = pty_pair
.master
.take_writer()
.map_err(|e| format!("Failed to take writer: {}", e))?;
let reader = pty_pair
.master
.try_clone_reader()
.map_err(|e| format!("Failed to clone reader: {}", e))?;
let session_id_clone = session_id.clone();
let session = PtySession::new(session_id_clone, command.to_string(), child, writer);
let session = Arc::new(Mutex::new(session));
sessions.insert(session_id.clone(), session.clone());
let session_for_reader = session.clone();
let child_for_reader = session_for_reader.lock().unwrap().child.clone();
tokio::task::spawn_blocking(move || {
let mut reader = reader;
let mut buf = [0u8; 4096];
let mut partial = String::new();
loop {
use std::io::Read;
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let s = String::from_utf8_lossy(&buf[..n]);
partial.push_str(&s);
let lines: Vec<&str> = partial.split('\n').collect();
if lines.len() <= 1 {
continue;
}
let complete = lines.len() - 1;
let mut guard = session_for_reader.lock().unwrap();
for line in lines[..complete].iter() {
guard.push_line(line.to_string());
}
partial = lines[complete].to_string();
}
Err(_) => break,
}
}
let mut guard = session_for_reader.lock().unwrap();
if !partial.is_empty() {
guard.push_line(partial);
}
let exit_code = {
let mut cg = child_for_reader.lock().unwrap();
if let Some(ref mut c) = *cg {
c.wait().map(|s| s.exit_code() as i32).ok()
} else {
None
}
};
guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1));
});
Ok(format!("session_id: {}, pid: {}", session_id, pid))
}
fn write(&self, session_id: &str, data: &str) -> Result<String, String> {
let unescaped = unescape_control_chars(data);
let byte_count = unescaped.len();
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
if guard.status != SessionStatus::Running {
return Err("Session is not running".to_string());
}
let writer = guard.writer.clone();
drop(guard);
drop(sessions);
let mut writer_guard = writer.lock().unwrap();
match *writer_guard {
Some(ref mut w) => {
w.write_all(unescaped.as_bytes())
.map_err(|e| format!("Write error: {}", e))?;
w.flush().map_err(|e| format!("Flush error: {}", e))?;
}
None => return Err("Writer not available".to_string()),
}
Ok(format!("OK, wrote {} bytes", byte_count))
}
fn read(
&self,
session_id: &str,
offset: usize,
limit: usize,
) -> Result<String, String> {
let sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let guard = session.lock().unwrap();
let total = guard.output_total_lines;
let buffer_len = guard.output_buffer.len();
let start = 0_usize.max(offset);
let skip_old = total.saturating_sub(buffer_len);
let view_start = start.saturating_sub(skip_old);
let lines: Vec<String> = guard
.output_buffer
.iter()
.skip(view_start)
.take(limit)
.enumerate()
.map(|(i, line)| format!("{}: {}", start + i, line))
.collect();
let displayed = start + lines.len();
let has_more = displayed < total;
let mut output = format!(
"# Lines {}-{} (共 {} 行{})\n",
start,
displayed.saturating_sub(1),
total,
if has_more { ",还有更多" } else { "" }
);
output.push_str(&lines.join("\n"));
if has_more {
output.push_str(&format!(
"\n[还有 {} 行未显示,用 offset={} 继续读取]",
total.saturating_sub(displayed),
displayed
));
}
Ok(output)
}
fn kill(&self, session_id: &str) -> Result<String, String> {
let mut sessions = self.sessions.lock().unwrap();
let session = sessions
.get(session_id)
.ok_or_else(|| format!("Session not found: {}", session_id))?;
let mut guard = session.lock().unwrap();
let mut child_guard = guard.child.lock().unwrap();
if let Some(ref mut child) = *child_guard {
let _ = child.kill();
let _ = child.wait();
}
*child_guard = None;
guard.status = SessionStatus::Killed;
drop(child_guard);
drop(guard);
sessions.remove(session_id);
Ok(format!("Session {} killed", session_id))
}
fn list(&self) -> String {
let sessions = self.sessions.lock().unwrap();
if sessions.is_empty() {
return "No active PTY sessions".to_string();
}
let mut lines: Vec<String> = sessions
.iter()
.map(|(id, session)| {
let guard = session.lock().unwrap();
let age = Instant::now().duration_since(guard.started_at);
let age_str = if age.as_secs() < 60 {
format!("{}s ago", age.as_secs())
} else {
format!("{}m ago", age.as_secs() / 60)
};
format!(
"{:<14} {:<12} {:<10} {:<6} lines {}",
id,
truncate_command(&guard.command, 10),
status_str(&guard.status),
guard.output_total_lines,
age_str,
)
})
.collect();
lines.sort();
lines.join("\n")
}
}
impl Drop for PtyManager {
fn drop(&mut self) {
self.cleanup_all();
}
}
pub struct PtyTool {
pty_manager: Arc<PtyManager>,
}
impl PtyTool {
pub fn new(pty_manager: Arc<PtyManager>) -> Self {
Self { pty_manager }
}
}
#[async_trait]
impl Tool for PtyTool {
fn name(&self) -> &str {
"pty"
}
fn description(&self) -> &str {
"管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["spawn", "write", "read", "kill", "list"],
"description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话"
},
"session_id": {
"type": "string",
"description": "会话ID (write/read/kill 需要)"
},
"command": {
"type": "string",
"description": "要执行的命令 (spawn 需要)"
},
"data": {
"type": "string",
"description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)"
},
"offset": {
"type": "integer",
"description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)",
"minimum": 0
},
"limit": {
"type": "integer",
"description": "读取的最大行数 (read 可选,默认 500)",
"minimum": 1
}
},
"required": ["action"]
})
}
fn exclusive(&self) -> bool {
true
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = match args.get("action").and_then(|v| v.as_str()) {
Some(a) => a,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: action".to_string()),
});
}
};
match action {
"spawn" => {
let command = match args.get("command").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: command".to_string()),
});
}
};
match self.pty_manager.spawn(command) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"write" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
let data = match args.get("data").and_then(|v| v.as_str()) {
Some(d) => d,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: data".to_string()),
});
}
};
match self.pty_manager.write(session_id, data) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"read" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
let offset = args
.get("offset")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(500) as usize;
match self.pty_manager.read(session_id, offset, limit) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"kill" => {
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
Some(id) => id,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: session_id".to_string()),
});
}
};
match self.pty_manager.kill(session_id) {
Ok(output) => Ok(ToolResult {
success: true,
output,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
"list" => {
let output = self.pty_manager.list();
Ok(ToolResult {
success: true,
output,
error: None,
})
}
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {}", action)),
}),
}
}
}

View File

@ -17,15 +17,7 @@ impl ToolRegistry {
}
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
self.tools
.lock()
.unwrap()
.insert(tool.name().to_string(), Arc::new(tool));
}
/// Register an existing Arc-wrapped tool by name
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
self.tools.lock().unwrap().insert(name, tool);
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
}
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
@ -70,17 +62,6 @@ impl ToolRegistry {
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
/// 生成工具列表描述,用于子 Agent 系统提示词
pub fn describe_for_prompt(&self) -> String {
let mut entries: Vec<String> = self
.iter()
.into_iter()
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
.collect();
entries.sort();
entries.join("\n")
}
}
impl Default for ToolRegistry {

View File

@ -115,11 +115,9 @@ impl SchemaCleanr {
}
if let Some(Value::String(t)) = obj.get("type")
&& t == "object"
&& !obj.contains_key("properties")
{
tracing::warn!("Object schema without 'properties' field may cause issues");
}
&& t == "object" && !obj.contains_key("properties") {
tracing::warn!("Object schema without 'properties' field may cause issues");
}
Ok(())
}
@ -175,10 +173,9 @@ impl SchemaCleanr {
// Handle anyOf/oneOf simplification
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
{
return simplified;
}
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
return simplified;
}
// Build cleaned object
let mut cleaned = Map::new();
@ -246,13 +243,12 @@ impl SchemaCleanr {
}
if let Some(def_name) = Self::parse_local_ref(ref_value)
&& let Some(definition) = defs.get(def_name.as_str())
{
ref_stack.insert(ref_value.to_string());
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
ref_stack.remove(ref_value);
return Self::preserve_meta(obj, cleaned);
}
&& let Some(definition) = defs.get(def_name.as_str()) {
ref_stack.insert(ref_value.to_string());
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
ref_stack.remove(ref_value);
return Self::preserve_meta(obj, cleaned);
}
tracing::warn!("Cannot resolve $ref: {}", ref_value);
Self::preserve_meta(obj, Value::Object(Map::new()))
@ -344,16 +340,13 @@ impl SchemaCleanr {
return true;
}
if let Some(Value::Array(arr)) = obj.get("enum")
&& arr.len() == 1
&& matches!(arr[0], Value::Null)
{
return true;
}
&& arr.len() == 1 && matches!(arr[0], Value::Null) {
return true;
}
if let Some(Value::String(t)) = obj.get("type")
&& t == "null"
{
return true;
}
&& t == "null" {
return true;
}
}
false
}
@ -410,10 +403,7 @@ impl SchemaCleanr {
match non_null.len() {
0 => Value::String("null".to_string()),
1 => non_null
.into_iter()
.next()
.unwrap_or(Value::String("null".to_string())),
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
_ => Value::Array(non_null),
}
} else {

View File

@ -1,5 +1,5 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::collections::HashSet;
use async_trait::async_trait;
use mime_guess::mime;
@ -31,20 +31,14 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
match parts.len() {
2 => {
if parts[0].is_empty() || parts[1].is_empty() {
Err(format!(
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
raw
))
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
} else {
Ok((parts[0], parts[1], None))
}
}
3 => {
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
Err(format!(
"Invalid target_chat_id format '{}': all three parts must not be empty",
raw
))
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
} else {
Ok((parts[0], parts[1], Some(parts[2])))
}
@ -104,8 +98,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
// 1. Parse target_chat_id
let (channel, chat_id, dialog_id) =
parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
.map_err(|e| anyhow::anyhow!(e))?;
// 2. Validate channel
if !self.available_channels.contains(channel) {
@ -115,11 +109,7 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
error: Some(format!(
"Channel '{}' is not available. Available channels: {}",
channel,
self.available_channels
.iter()
.cloned()
.collect::<Vec<_>>()
.join(", ")
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
)),
});
}
@ -139,8 +129,7 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
let media = parse_files_arg(&args);
// 4. Send via messenger
match self
.messenger
match self.messenger
.send_message(channel, chat_id, dialog_id, content, source, media)
.await
{

View File

@ -1,5 +1,5 @@
use crate::bus::{MediaItem, MessageSource};
use async_trait::async_trait;
use crate::bus::{MediaItem, MessageSource};
#[derive(Debug, Clone)]
pub struct ToolResult {

View File

@ -239,11 +239,7 @@ fn is_private_host(host: &str) -> bool {
return true;
}
if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
if host.rsplit('.').next().is_some_and(|label| label == "local") {
return true;
}
@ -252,9 +248,7 @@ fn is_private_host(host: &str) -> bool {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
};
}

View File

@ -1,6 +1,6 @@
use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
use picobot::config::{Config, LLMProviderConfig};
fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -42,7 +42,8 @@ fn create_request(content: &str) -> ChatCompletionRequest {
#[tokio::test]
#[ignore]
async fn test_openai_simple_completion() {
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let config = load_config()
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
@ -56,7 +57,8 @@ async fn test_openai_simple_completion() {
#[tokio::test]
#[ignore]
async fn test_openai_conversation() {
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let config = load_config()
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -80,9 +82,7 @@ async fn test_openai_conversation() {
async fn test_config_load() {
// Test that config.json can be loaded and provider config created
let config = Config::load("config.json").expect("Failed to load config.json");
let provider_config = config
.get_provider_config("default")
.expect("Failed to get provider config");
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
assert_eq!(provider_config.provider_type, "openai");
assert_eq!(provider_config.name, "aliyun");

View File

@ -1,5 +1,5 @@
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
/// Test that message with special characters is properly escaped
#[test]
@ -19,9 +19,7 @@ fn test_message_special_characters() {
#[test]
fn test_multiline_system_prompt() {
let messages = vec![
Message::system(
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
),
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
Message::user("Hi"),
];
@ -35,7 +33,10 @@ fn test_multiline_system_prompt() {
#[test]
fn test_chat_request_serialization() {
let request = ChatCompletionRequest {
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
messages: vec![
Message::system("You are helpful"),
Message::user("Hello"),
],
temperature: Some(0.7),
max_tokens: Some(100),
tools: None,

View File

@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
/// Verify that next_run_for_schedule produces valid future timestamps.
#[test]
fn test_next_run_always_future() {
use picobot::scheduler::{Schedule, next_run_for_schedule};
use picobot::scheduler::{next_run_for_schedule, Schedule};
let now = 1700000000000_i64;
@ -56,10 +56,6 @@ fn test_next_run_always_future() {
for s in &schedules {
let next = next_run_for_schedule(s, now);
assert!(next.is_some(), "expected next run for {:?}", s);
assert!(
next.unwrap() > now,
"next run should be after now for {:?}",
s
);
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
}
}

View File

@ -1,6 +1,6 @@
use picobot::config::LLMProviderConfig;
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use picobot::config::LLMProviderConfig;
fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -53,7 +53,8 @@ fn make_weather_tool() -> Tool {
#[tokio::test]
#[ignore]
async fn test_openai_tool_call() {
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -67,11 +68,7 @@ async fn test_openai_tool_call() {
let response = provider.chat(request).await.unwrap();
// Should have tool calls
assert!(
!response.tool_calls.is_empty(),
"Expected tool call, got: {}",
response.content
);
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
let tool_call = &response.tool_calls[0];
assert_eq!(tool_call.name, "get_weather");
@ -81,7 +78,8 @@ async fn test_openai_tool_call() {
#[tokio::test]
#[ignore]
async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -94,7 +92,8 @@ async fn test_openai_tool_call_with_manual_execution() {
};
let response1 = provider.chat(request1).await.unwrap();
let tool_call = response1.tool_calls.first().expect("Expected tool call");
let tool_call = response1.tool_calls.first()
.expect("Expected tool call");
assert_eq!(tool_call.name, "get_weather");
// Second request with tool result
@ -117,7 +116,8 @@ async fn test_openai_tool_call_with_manual_execution() {
#[tokio::test]
#[ignore]
async fn test_openai_no_tool_when_not_provided() {
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");