Compare commits
10 Commits
a77c026826
...
8f4ee79d8d
| Author | SHA1 | Date | |
|---|---|---|---|
| 8f4ee79d8d | |||
| c6f4392e63 | |||
| 0d66536e90 | |||
| 0c3e740d15 | |||
| 014538eedc | |||
| 2a69021e27 | |||
| 3df628bd28 | |||
| ea1338c94f | |||
| b89dce013c | |||
| 41b4895ff0 |
28
.dockerignore
Normal file
28
.dockerignore
Normal file
@ -0,0 +1,28 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
|
||||
# Build artifacts
|
||||
target/
|
||||
!target/release/picobot
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# Docs and references
|
||||
docs/
|
||||
reference/
|
||||
|
||||
# Test files
|
||||
tests/
|
||||
|
||||
# Misc
|
||||
*.md
|
||||
*.txt
|
||||
.opencode/
|
||||
CLAUDE.md
|
||||
AGENTS.md
|
||||
ARCHITECTURE_REVIEW.md
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,6 +1,8 @@
|
||||
/target
|
||||
docker_build/
|
||||
reference/**
|
||||
.env
|
||||
*.env
|
||||
Cargo.lock
|
||||
.worktrees/
|
||||
design
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
# 架构审查报告
|
||||
|
||||
> 生成时间: 2026-04-26
|
||||
> 更新时间: 2026-04-26
|
||||
|
||||
## 审查摘要
|
||||
|
||||
本报告识别了当前代码库中的架构不合理、冗余和无效代码的问题。
|
||||
|
||||
---
|
||||
|
||||
## 问题清单
|
||||
|
||||
### 已修复
|
||||
|
||||
#### ✅ #1 OutboundDispatcher 重复维护 Channel 注册表
|
||||
|
||||
**修复方案**: `OutboundDispatcher` 现在从 `ChannelManager` 获取 channels,而不是自己维护一份注册表。
|
||||
|
||||
**修改文件**:
|
||||
- `src/bus/dispatcher.rs` - 移除 `channels` 字段,改用 `ChannelManager`
|
||||
- `src/channels/manager.rs` - 添加 `register_channel` 方法
|
||||
- `src/gateway/mod.rs` - 简化 dispatcher 初始化
|
||||
|
||||
---
|
||||
|
||||
#### ✅ #2 CliChatChannel 持有独立的 SessionStore
|
||||
|
||||
**修复方案**: `CliChatChannel` 的 `SessionStore` 通过依赖注入从 `ChannelManager` 获取,而不是独立持有。
|
||||
|
||||
**修改文件**:
|
||||
- `src/channels/cli_chat.rs` - 添加 `set_store()` 方法
|
||||
- `src/channels/manager.rs` - 添加 `cli_chat_channel` 字段
|
||||
- `src/gateway/mod.rs` - 重构 channel 初始化流程
|
||||
|
||||
---
|
||||
|
||||
#### ✅ #3 MessageBus 被创建两次引用
|
||||
|
||||
**修复方案**: 移除 `GatewayState.bus` 字段,直接使用 `channel_manager.bus()`。
|
||||
|
||||
**修改文件**:
|
||||
- `src/gateway/mod.rs` - 移除冗余的 `bus` 字段
|
||||
|
||||
---
|
||||
|
||||
#### ✅ #4 GatewayState 同时持有 channel_manager 和 cli_chat_channel
|
||||
|
||||
**修复方案**: `cli_chat_channel` 只通过 `ChannelManager` 管理,`GatewayState` 不再单独持有。
|
||||
|
||||
**修改文件**:
|
||||
- `src/gateway/mod.rs` - 移除 `cli_chat_channel` 字段,添加 `cli_chat_channel()` getter 方法
|
||||
|
||||
---
|
||||
|
||||
### 高优先级(待修复)
|
||||
|
||||
#### ❌ Session 每次重建都创建新的 LLM Provider
|
||||
|
||||
**文件**: `src/gateway/session.rs:349-361`
|
||||
|
||||
**问题**: 每当 session TTL 过期(默认4小时),就会销毁并重建 session,同时创建新的 LLM provider 连接。
|
||||
|
||||
**建议**: Provider 应该池化复用,不随 session 销毁而重建。
|
||||
|
||||
---
|
||||
|
||||
#### ❌ CliChatChannel::send 广播给所有客户端
|
||||
|
||||
**文件**: `src/channels/cli_chat.rs:279-289`
|
||||
|
||||
**问题**: `OutboundMessage` 有 `chat_id` 字段用于路由,但实现广播给所有客户端,而不是只发给对应 chat_id 的客户端。
|
||||
|
||||
**建议**: 根据 `chat_id` 过滤客户端,只发送给对应的客户端。
|
||||
|
||||
---
|
||||
|
||||
### 中优先级(待修复)
|
||||
|
||||
#### ❌ default_tools() 每次调用创建新 ToolRegistry
|
||||
|
||||
**文件**: `src/gateway/session.rs:212-227`
|
||||
|
||||
**建议**: 如果工具列表是只读的,直接 clone Arc;如果需要修改,需要澄清设计意图。
|
||||
|
||||
---
|
||||
|
||||
### 低优先级(待修复)
|
||||
|
||||
#### ❌ FeishuChannel::new 接收未使用的 provider_config
|
||||
|
||||
**文件**: `src/channels/feishu.rs:175-178`
|
||||
|
||||
---
|
||||
|
||||
#### ❌ OutboundDispatcher::send_with_retry 永不执行的 unreachable
|
||||
|
||||
**文件**: `src/bus/dispatcher.rs:81`
|
||||
|
||||
---
|
||||
|
||||
#### ❌ Channel trait 的 `is_running` 使用 std::sync::Mutex
|
||||
|
||||
**文件**: `src/channels/base.rs:38` vs `src/channels/cli_chat.rs:265-267`
|
||||
|
||||
---
|
||||
|
||||
#### ❌ LoopDetector 硬编码在 AgentLoop 中
|
||||
|
||||
**文件**: `src/agent/agent_loop.rs:88-172`
|
||||
|
||||
---
|
||||
|
||||
#### ❌ InboundMessage 和 OutboundMessage 结构重复
|
||||
|
||||
**文件**: `src/bus/message.rs`
|
||||
|
||||
---
|
||||
|
||||
## 问题统计
|
||||
|
||||
| 状态 | 优先级 | 数量 |
|
||||
|------|--------|------|
|
||||
| ✅ 已修复 | - | 4 |
|
||||
| ❌ 待修复 | 高 | 2 |
|
||||
| ❌ 待修复 | 中 | 1 |
|
||||
| ❌ 待修复 | 低 | 5 |
|
||||
| **总计** | - | **12** |
|
||||
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "picobot"
|
||||
version = "0.1.0"
|
||||
version = "1.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
@ -12,6 +12,8 @@ serde_json = "1.0"
|
||||
async-trait = "0.1"
|
||||
thiserror = "2.0.18"
|
||||
tokio = { version = "1.52", features = ["full"] }
|
||||
tokio-util = { version = "0.7", features = ["rt"] }
|
||||
dashmap = "6.1"
|
||||
uuid = { version = "1.23", features = ["v4"] }
|
||||
axum = { version = "0.8", features = ["ws"] }
|
||||
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
|
||||
@ -49,6 +51,7 @@ encoding_rs = "0.8"
|
||||
zstd = "0.13"
|
||||
tar = "0.4"
|
||||
fantoccini = { version = "0.22", default-features = false, features = ["rustls-tls"] }
|
||||
portable-pty = "0.9"
|
||||
|
||||
[build-dependencies]
|
||||
zstd = "0.13"
|
||||
|
||||
110
Dockerfile
Normal file
110
Dockerfile
Normal file
@ -0,0 +1,110 @@
|
||||
# =============================================================================
|
||||
# PicoBot Docker Image
|
||||
# =============================================================================
|
||||
# Build binary on host:
|
||||
# cargo build --release
|
||||
#
|
||||
# Build image:
|
||||
# docker build -t picobot .
|
||||
#
|
||||
# Run gateway: docker run -d -v ~/.picobot:/app/.picobot -p 19876:19876 picobot gateway
|
||||
# Run chat: docker run -it -v ~/.picobot:/app/.picobot picobot chat
|
||||
# =============================================================================
|
||||
|
||||
FROM debian:trixie-slim
|
||||
|
||||
LABEL org.opencontainers.image.title="PicoBot"
|
||||
LABEL org.opencontainers.image.description="AI agent gateway and chat client"
|
||||
LABEL org.opencontainers.image.source="https://github.com/your-repo/picobot"
|
||||
|
||||
# Avoid interactive prompts
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Configure domestic mirrors for pip, uv, npm (China)
|
||||
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
|
||||
ENV UV_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# Install base tools, Python, and uv in one layer to reduce duplication
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
tini \
|
||||
curl \
|
||||
gnupg \
|
||||
git \
|
||||
jq \
|
||||
tree \
|
||||
zip \
|
||||
unzip \
|
||||
sqlite3 \
|
||||
openssh-client \
|
||||
sshpass \
|
||||
dnsutils \
|
||||
poppler-utils \
|
||||
fonts-wqy-zenhei \
|
||||
fonts-wqy-microhei \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-venv \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& pip3 install --no-cache-dir --break-system-packages uv
|
||||
|
||||
# Install Node.js and npx
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
|
||||
&& apt-get install -y --no-install-recommends nodejs \
|
||||
&& npm config set registry https://registry.npmmirror.com \
|
||||
&& npm cache clean --force \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install himalaya (CLI email client) from local file
|
||||
COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
|
||||
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
|
||||
&& chmod +x /usr/local/bin/himalaya \
|
||||
&& rm -f /tmp/himalaya.tgz
|
||||
|
||||
# Install fd (alternative to find)
|
||||
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \
|
||||
tar -xz --strip-components=1 -C /usr/local/bin \
|
||||
&& chmod +x /usr/local/bin/fd
|
||||
|
||||
# Install ripgrep (rg)
|
||||
RUN curl -fsSL https://github.com/BurntSushi/ripgrep/releases/download/14.1.0/ripgrep-14.1.0-x86_64-unknown-linux-musl.tar.gz | \
|
||||
tar -xz --strip-components=1 -C /usr/local/bin \
|
||||
&& chmod +x /usr/local/bin/rg
|
||||
|
||||
# Install Chromium and chromedriver for browser automation
|
||||
# Debian's chromium package is real (not a snap shim like Ubuntu 24.04)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
chromium \
|
||||
chromium-driver \
|
||||
&& ln -sf /usr/bin/chromium /usr/local/bin/chrome \
|
||||
&& ln -sf /usr/bin/chromedriver /usr/local/bin/chromedriver \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m -s /bin/bash app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy pre-built binary from host
|
||||
COPY target/release/picobot /app/picobot
|
||||
|
||||
# Copy config template
|
||||
COPY resources/templates/config.example.json /app/config.json.example
|
||||
|
||||
# Create required directories
|
||||
RUN mkdir -p /app/.picobot/workspace /app/.picobot/media /app/.picobot/tmp && \
|
||||
chown -R app:app /app
|
||||
|
||||
USER app
|
||||
ENV HOME=/app
|
||||
|
||||
# Environment variables for Chromium in containers
|
||||
ENV CHROME_BIN=/usr/bin/chromium
|
||||
ENV TMPDIR=/app/.picobot/tmp
|
||||
|
||||
ENTRYPOINT ["/app/picobot"]
|
||||
CMD ["gateway"]
|
||||
|
||||
EXPOSE 19876
|
||||
|
||||
ENV RUST_LOG=info
|
||||
537
README.md
537
README.md
@ -1,143 +1,102 @@
|
||||
# PicoBot
|
||||
|
||||
A multi-channel AI agent framework with a WebSocket gateway and TUI client, supporting OpenAI-compatible and Anthropic LLM providers, tool calling, session persistence, and cron-based scheduling.
|
||||
PicoBot is a Rust-based personal AI assistant runtime. It runs a local gateway, connects chat channels such as the terminal TUI and Feishu/Lark, persists sessions in SQLite, and gives the agent a tool system for files, shell commands, web access, memory, scheduling, skills, MCP tools, and delegated sub-agents.
|
||||
|
||||
## System Architecture
|
||||
## What It Does
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph Clients
|
||||
TUI["🖥️ CLI Chat (TUI)"]
|
||||
FS["📱 Feishu/Lark"]
|
||||
end
|
||||
- Runs as a gateway server on `127.0.0.1:19876` by default.
|
||||
- Provides a Ratatui terminal client over WebSocket.
|
||||
- Supports Feishu/Lark messages, reactions, file upload/download, and media references.
|
||||
- Calls OpenAI-compatible providers and Anthropic Messages API providers.
|
||||
- Persists conversations, messages, memories, scheduled jobs, LLM call metadata, and background sub-agent tasks in SQLite.
|
||||
- Loads skills from workspace, user, and shared skill directories, with built-in skills installed on first use.
|
||||
- Compresses long contexts and stores timeline summaries for later recall.
|
||||
- Can register tools discovered from configured MCP servers.
|
||||
|
||||
subgraph Gateway["Gateway Server (127.0.0.1:19876)"]
|
||||
HTTP["HTTP Endpoints<br/>GET /health<br/>GET /ws (WebSocket upgrade)"]
|
||||
WS["WebSocket Handler"]
|
||||
CD["ChannelManager"]
|
||||
SP["SessionManager"]
|
||||
AL["AgentLoop"]
|
||||
end
|
||||
## Architecture
|
||||
|
||||
subgraph Bus["MessageBus"]
|
||||
IB["Inbound Channel"]
|
||||
OB["Outbound Channel"]
|
||||
CC["Control Channel"]
|
||||
end
|
||||
```text
|
||||
Channel -> MessageBus -> SessionManager -> AgentLoop -> LLM Provider
|
||||
| |
|
||||
| v
|
||||
| Tools
|
||||
v
|
||||
SQLite
|
||||
|
||||
subgraph Storage
|
||||
SQLite[("SQLite<br/>picobot.db")]
|
||||
end
|
||||
|
||||
subgraph AI["AI Providers"]
|
||||
OpenAI["OpenAI / DashScope"]
|
||||
Anthropic["Anthropic Claude"]
|
||||
end
|
||||
|
||||
TUI <-->|WebSocket| WS
|
||||
FS <-->|Webhook| HTTP
|
||||
|
||||
CD -->|InboundMessage| IB
|
||||
IB -->|DialogEvent| SP
|
||||
CC -->|ControlMessage| SP
|
||||
SP <--> AL
|
||||
AL -->|API Call| OpenAI
|
||||
AL -->|API Call| Anthropic
|
||||
AL -->|Tool Call| Tools
|
||||
SP -->|OutboundMessage| OB
|
||||
OB --> CD
|
||||
SP --> SQLite
|
||||
Tools --> SQLite
|
||||
|
||||
subgraph Tools
|
||||
Bash["Bash"]
|
||||
FileIO["File Read/Write/Edit"]
|
||||
Web["HTTP Request / Web Fetch"]
|
||||
Calc["Calculator"]
|
||||
Skill["Get Skill"]
|
||||
Msg["Send Message"]
|
||||
Cron["Cron Jobs"]
|
||||
end
|
||||
Control messages -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
|
||||
```
|
||||
|
||||
### Core Data Flow
|
||||
The main runtime boundary is:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Channel as Channel<br/>(CLI/Feishu)
|
||||
participant Bus as MessageBus
|
||||
participant SM as SessionManager
|
||||
participant AL as AgentLoop
|
||||
participant LLM as LLM Provider
|
||||
participant Tool as Tools
|
||||
|
||||
Channel->>Bus: InboundMessage (user input)
|
||||
Bus->>SM: DialogEvent
|
||||
SM->>SM: Load/Resolve Session
|
||||
SM->>AL: Process (session state)
|
||||
AL->>LLM: ChatCompletionRequest
|
||||
LLM-->>AL: response / tool_calls
|
||||
alt Tool Calls
|
||||
AL->>Tool: execute tool
|
||||
Tool-->>AL: result
|
||||
AL->>LLM: continue with tool result
|
||||
end
|
||||
AL-->>SM: AgentProcessResult (text + token count)
|
||||
SM->>SM: Persist to SQLite
|
||||
SM->>Bus: OutboundMessage
|
||||
Bus->>Channel: response to user
|
||||
```
|
||||
- `channels` only receive and send external messages.
|
||||
- `bus` is an async queue, not a router.
|
||||
- `session` owns dialog lifecycle, persistence, memory recall, prompt assembly, compression, and task cancellation.
|
||||
- `agent` runs the stateless LLM/tool loop.
|
||||
- `providers` are HTTP clients for model APIs.
|
||||
- `tools` execute agent actions and return string results.
|
||||
- `storage` owns SQLite schema and CRUD.
|
||||
- `scheduler` polls due jobs and feeds prompts back into sessions.
|
||||
|
||||
## Features
|
||||
|
||||
### Multi-Channel Support
|
||||
- **CLI Chat Client** — Full TUI with session management, Markdown rendering, slash commands
|
||||
- **Feishu (Lark)** — Webhook-based integration with typing indicators and media support
|
||||
### Channels
|
||||
|
||||
### Multi-Provider LLM
|
||||
- OpenAI-compatible API (GPT-4, DashScope, Volcengine, etc.)
|
||||
- Anthropic Messages API (Claude)
|
||||
- Cross-provider JSON Schema normalization for tool calling compatibility
|
||||
- `cli_chat`: terminal TUI client connected through `/ws`.
|
||||
- `feishu`: Feishu/Lark channel with configurable allow list, media directory, and reaction emoji.
|
||||
|
||||
### Session Management
|
||||
- Multi-session conversations per channel/chat
|
||||
- Create, switch, rename, archive, delete dialogs via slash commands or WebSocket
|
||||
- SQLite-persisted session history with automatic TTL-based cleanup
|
||||
- Context compression for long conversations approaching token limits
|
||||
### LLM Providers
|
||||
|
||||
### Tool System
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `bash` | Execute shell commands in workspace |
|
||||
| `file_read` | Read file contents |
|
||||
| `file_write` | Create/overwrite files |
|
||||
| `file_edit` | Precise string substitution in files |
|
||||
| `http_request` | Make HTTP API requests |
|
||||
| `web_fetch` | Fetch and parse web pages |
|
||||
| `calculator` | Evaluate mathematical expressions |
|
||||
| `get_skill` | Load agent skills from local skill files |
|
||||
| `send_message` | Send messages to other channels |
|
||||
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs |
|
||||
- OpenAI-compatible chat completions, including DashScope, Volcengine, and similar APIs.
|
||||
- Anthropic Messages API.
|
||||
- Model-specific `input_type` metadata for text/image capability checks.
|
||||
- JSON Schema cleanup for cross-provider tool compatibility.
|
||||
|
||||
### Scheduling
|
||||
- Cron-based recurring jobs with optional timezone support
|
||||
- One-shot (`at`) and interval (`every`) schedules
|
||||
- Jobs trigger agent processing via specified channel/chat
|
||||
### Sessions And Memory
|
||||
|
||||
### Skills System
|
||||
- Load Markdown skill files from `~/.picobot/skills` and `~/.agents/skills`
|
||||
- Skills inject specialized system prompts for specific tasks
|
||||
- Automatic hot-reload on file changes
|
||||
- Session IDs use `<channel>:<chat_id>:<dialog_id>`.
|
||||
- Each channel/chat can have multiple dialogs.
|
||||
- Dialog operations include create, list, switch, rename, delete, compact, dump, info, and stop.
|
||||
- Session history is persisted to SQLite and can be incrementally restored after compression.
|
||||
- Knowledge memories are recalled into the system prompt each turn.
|
||||
- Timeline memories are produced by context compression and can be searched later.
|
||||
|
||||
### Observability
|
||||
- Observer pattern for agent and tool telemetry
|
||||
- Events: `AgentStart`, `AgentEnd`, `ToolCallStart`, `ToolCall`
|
||||
- Structured JSON logging with file rotation
|
||||
### Tools
|
||||
|
||||
Base tools registered for the agent:
|
||||
|
||||
| Tool | Purpose |
|
||||
|------|---------|
|
||||
| `calculator` | Math expressions and statistics |
|
||||
| `file_read` / `file_write` / `file_edit` | Workspace file operations |
|
||||
| `file_search` / `content_search` | File and content search |
|
||||
| `bash` | Run shell commands in the workspace |
|
||||
| `http_request` | HTTP API requests |
|
||||
| `web_fetch` | Fetch and extract web page text |
|
||||
| `get_skill` | List or load local skills |
|
||||
| `memory_store` / `memory_recall` / `timeline_recall` / `memory_forget` | Long-term memory operations |
|
||||
| `delegate` | Run inline, background, or parallel sub-agents |
|
||||
| `send_message` | Send outbound messages to configured channels |
|
||||
| `chat_manager` | Inspect sessions, channels, and stored messages |
|
||||
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs when scheduler is enabled |
|
||||
| `browser` | Optional WebDriver browser automation when enabled |
|
||||
| MCP tools | Dynamically registered from configured MCP servers |
|
||||
|
||||
### Skills
|
||||
|
||||
Skills are directories containing `SKILL.md`. Load priority is:
|
||||
|
||||
1. `{workspace}/skills`
|
||||
2. `~/.picobot/skills`
|
||||
3. `~/.agents/skills`
|
||||
|
||||
Same-name skills in higher-priority locations override lower-priority ones. Built-in skills from `resources/skills` are embedded into the binary and installed into `~/.picobot/skills` if missing.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
- Rust nightly (edition 2024) — use `rustup` to install
|
||||
|
||||
- Rust toolchain with edition 2024 support.
|
||||
- A configured LLM provider API key.
|
||||
|
||||
### Build
|
||||
|
||||
@ -147,276 +106,186 @@ cargo build
|
||||
|
||||
### Configure
|
||||
|
||||
1. Create `config.json` (or `~/.picobot/config.json`):
|
||||
PicoBot loads `~/.picobot/config.json` first, then falls back to `./config.json`. On gateway startup, a template is released to `~/.picobot/config.example.json` if it does not exist. The source template is [resources/templates/config.example.json](/home/xiaoxixi/code/PicoBot/resources/templates/config.example.json).
|
||||
|
||||
Minimal example:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"openai": {
|
||||
"type": "openai",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "<OPENAI_API_KEY>"
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"gpt-4o": {
|
||||
"model_id": "gpt-4o",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
"max_tool_iterations": 99,
|
||||
"token_limit": 128000
|
||||
}
|
||||
"providers": {
|
||||
"openai": {
|
||||
"type": "openai",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "<OPENAI_API_KEY>",
|
||||
"extra_headers": {}
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"gpt-4o": {
|
||||
"model_id": "gpt-4o",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
"input_type": ["text", "image"]
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
"max_tool_iterations": 99,
|
||||
"token_limit": 128000
|
||||
}
|
||||
},
|
||||
"workspace_dir": "~/.picobot/workspace"
|
||||
}
|
||||
```
|
||||
|
||||
2. Set API keys via `.env` file (one `KEY=VALUE` per line):
|
||||
|
||||
```env
|
||||
OPENAI_API_KEY=sk-xxxxx
|
||||
```
|
||||
The `.env` file in the current directory is parsed by PicoBot itself. Values like `<OPENAI_API_KEY>` in JSON are replaced from the process environment after `.env` is loaded.
|
||||
|
||||
### Run
|
||||
|
||||
**Start gateway server:**
|
||||
|
||||
```bash
|
||||
cargo run -- gateway
|
||||
```
|
||||
|
||||
Binds `127.0.0.1:19876` by default. Override with `--host` and `--port`.
|
||||
The gateway switches the process working directory to `workspace_dir` and stores `picobot.db` there by default.
|
||||
|
||||
**Connect CLI client:**
|
||||
In another terminal:
|
||||
|
||||
```bash
|
||||
cargo run -- chat
|
||||
```
|
||||
|
||||
Connects to `ws://127.0.0.1:19876/ws`. Override with `--gateway-url`.
|
||||
The client connects to `ws://127.0.0.1:19876/ws` by default. Override with `--gateway-url`.
|
||||
|
||||
## Configuration Reference
|
||||
## Configuration
|
||||
|
||||
Config load order: `~/.picobot/config.json` → `./config.json` (fallback).
|
||||
Top-level config fields:
|
||||
|
||||
### Full Config Structure
|
||||
| Field | Purpose |
|
||||
|-------|---------|
|
||||
| `providers` | Named LLM provider configs |
|
||||
| `models` | Named model configs |
|
||||
| `agents` | Agent-to-provider/model binding |
|
||||
| `gateway` | Bind address, session DB path, cleanup, scheduler, background task limits |
|
||||
| `client` | Default WebSocket URL for the TUI client |
|
||||
| `channels` | Channel configs, currently Feishu/Lark |
|
||||
| `memory` | Recall and consolidation settings |
|
||||
| `mcp` | MCP server configs |
|
||||
| `browser` | Optional WebDriver browser tool config |
|
||||
| `workspace_dir` | Workspace used for file tools, shell commands, DB default, and workspace skills |
|
||||
|
||||
```mermaid
|
||||
graph LR
|
||||
Config["config.json"]
|
||||
Config --> Providers["providers<br/>ProviderConfig{}"]
|
||||
Config --> Models["models<br/>ModelConfig{}"]
|
||||
Config --> Agents["agents<br/>AgentConfig{}"]
|
||||
Config --> Gateway["gateway<br/>GatewayConfig"]
|
||||
Config --> Client["client<br/>ClientConfig"]
|
||||
Config --> Channels["channels<br/>ChannelConfig{}"]
|
||||
Config --> Workspace["workspace_dir"]
|
||||
Important defaults:
|
||||
|
||||
Providers --> PT["type (openai / anthropic)<br/>base_url<br/>api_key<br/>extra_headers"]
|
||||
Models --> MT["model_id<br/>temperature<br/>max_tokens"]
|
||||
Agents --> AT["provider (ref)<br/>model (ref)<br/>max_tool_iterations<br/>token_limit"]
|
||||
Gateway --> GT["host / port<br/>session_db_path<br/>scheduler"]
|
||||
Channels --> CT["feishu: app_id, app_secret<br/>allow_from, agent, media_dir"]
|
||||
```
|
||||
| Key | Default |
|
||||
|-----|---------|
|
||||
| `gateway.host` | `127.0.0.1` |
|
||||
| `gateway.port` | `19876` |
|
||||
| `gateway.max_concurrent_background_tasks` | `10` |
|
||||
| `gateway.scheduler.enabled` | `true` if `scheduler` is omitted and defaulted |
|
||||
| `client.gateway_url` | `ws://127.0.0.1:19876/ws` |
|
||||
| `memory.recall_limit` | `5` |
|
||||
| `memory.timeline_retention_days` | `90` |
|
||||
| `mcp.tool_timeout_secs` | `180` |
|
||||
| `browser.enabled` | `false` |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The `.env` file in the working directory is loaded manually (not via dotenv crate). Placeholders in `config.json` written as `<VAR_NAME>` are substituted at load time.
|
||||
|
||||
### Gateway Config
|
||||
|
||||
| Key | Type | Default | Description |
|
||||
|-----|------|---------|-------------|
|
||||
| `host` | string | `127.0.0.1` | Bind address |
|
||||
| `port` | u16 | `19876` | Listen port |
|
||||
| `session_db_path` | string | workspace `picobot.db` | SQLite database path |
|
||||
| `scheduler.enabled` | bool | `false` | Enable cron scheduler |
|
||||
|
||||
### Agent Config
|
||||
|
||||
| Key | Type | Default | Description |
|
||||
|-----|------|---------|-------------|
|
||||
| `provider` | string | — | Provider name (key in `providers`) |
|
||||
| `model` | string | — | Model name (key in `models`) |
|
||||
| `max_tool_iterations` | number | `99` | Max tool call iterations per turn |
|
||||
| `token_limit` | number | `128000` | Context window token limit |
|
||||
MCP servers support `stdio`, `sse`, and `streamable-http` transports. Browser automation requires a compatible Chrome/Chromium and chromedriver/WebDriver endpoint.
|
||||
|
||||
## Slash Commands
|
||||
|
||||
Available in CLI chat and Feishu:
|
||||
Available from CLI chat and channel text messages:
|
||||
|
||||
| Command | Alias | Description |
|
||||
|---------|-------|-------------|
|
||||
| `/new` | `/刷新` | Create a new dialog |
|
||||
| `/list` | `/对话列表` | List all dialogs |
|
||||
| `/switch <id>` | — | Switch to a dialog |
|
||||
| `/rename <title>` | — | Rename current dialog |
|
||||
| `/archive` | — | Archive current dialog |
|
||||
| `/delete` | — | Delete current dialog |
|
||||
| `/clear` | `/清空` | Clear current dialog history |
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/new` | Create a new dialog |
|
||||
| `/sessions` | List recent dialogs |
|
||||
| `/switch <dialog_id>` | Switch dialog |
|
||||
| `/rename <title>` | Rename current dialog |
|
||||
| `/delete` | Delete current dialog |
|
||||
| `/compact` | Manually trigger context compression |
|
||||
| `/info` | Show current dialog information |
|
||||
| `/dump` | Save current dialog as Markdown |
|
||||
| `/?`, `/help` | Show help |
|
||||
| `/mcp` | Show MCP server and tool status |
|
||||
| `/stop` | Stop active tasks and clear queued messages |
|
||||
|
||||
## WebSocket Protocol
|
||||
## WebSocket API
|
||||
|
||||
The gateway exposes a WebSocket endpoint at `/ws`. Messages use typed JSON with a `type` discriminator field.
|
||||
|
||||
### Client → Server (WsInbound)
|
||||
|
||||
| Type | Fields |
|
||||
|------|--------|
|
||||
| `user_input` | `content`, `channel?`, `chat_id?`, `sender_id?` |
|
||||
| `create_session` | `title?` |
|
||||
| `list_sessions` | `include_archived` |
|
||||
| `load_session` | `session_id` |
|
||||
| `rename_session` | `session_id?`, `title` |
|
||||
| `archive_session` | `session_id?` |
|
||||
| `delete_session` | `session_id?` |
|
||||
| `clear_history` | `chat_id?`, `session_id?` |
|
||||
| `get_slash_commands` | — |
|
||||
| `ping` | — |
|
||||
|
||||
### Server → Client (WsOutbound)
|
||||
|
||||
| Type | Fields |
|
||||
|------|--------|
|
||||
| `assistant_response` | `session_id`, `response`, `tokens_used?`, `tool_calls?` |
|
||||
| `session_list` | `sessions[]` |
|
||||
| `session_loaded` | `session_id`, `messages[]` |
|
||||
| `session_created` | `session_id`, `title` |
|
||||
| `session_renamed` | `session_id`, `title` |
|
||||
| `session_archived` | `session_id` |
|
||||
| `session_deleted` | `session_id` |
|
||||
| `slash_commands` | `commands[]` |
|
||||
| `error` | `message` |
|
||||
| `pong` | — |
|
||||
|
||||
## HTTP Endpoints
|
||||
The gateway exposes:
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `GET` | `/health` | Health check — returns `{"status":"ok","version":"x.y.z"}` |
|
||||
| `GET` | `/health` | Returns service health and version |
|
||||
| `GET` | `/ws` | WebSocket upgrade for chat clients |
|
||||
|
||||
Inbound WebSocket message types:
|
||||
|
||||
| Type | Main fields |
|
||||
|------|-------------|
|
||||
| `user_input` | `content`, optional `channel`, `chat_id`, `sender_id` |
|
||||
| `clear_history` | optional `chat_id`, `session_id` |
|
||||
| `create_session` | optional `title` |
|
||||
| `list_sessions` | `include_archived` |
|
||||
| `load_session` | `session_id` |
|
||||
| `rename_session` | optional `session_id`, `title` |
|
||||
| `archive_session` | optional `session_id` |
|
||||
| `delete_session` | optional `session_id` |
|
||||
| `get_slash_commands` | none |
|
||||
| `ping` | none |
|
||||
|
||||
Outbound WebSocket message types include `assistant_response`, `error`, `session_established`, `session_created`, `session_list`, `session_loaded`, `session_renamed`, `session_archived`, `session_deleted`, `history_cleared`, `slash_commands_list`, `pong`, `command_executed`, and `system_notification`.
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Unit tests (no external dependencies)
|
||||
# Unit tests
|
||||
cargo test --lib
|
||||
|
||||
# Integration tests (require API keys)
|
||||
# Integration tests require real API keys in tests/test.env
|
||||
cp tests/test.env.example tests/test.env
|
||||
# Fill in your API keys in tests/test.env
|
||||
cargo test --test test_integration -- --ignored
|
||||
cargo test --test test_tool_calling -- --ignored
|
||||
cargo test --test test_request_format -- --ignored
|
||||
|
||||
# Run all tests
|
||||
cargo test -- --ignored
|
||||
```
|
||||
|
||||
Integration tests are `#[ignore]` by default because they make real API calls.
|
||||
Integration tests are ignored by default because they make real provider calls.
|
||||
|
||||
## Project Structure
|
||||
## Project Layout
|
||||
|
||||
```
|
||||
├── src/
|
||||
│ ├── main.rs # CLI entrypoint (clap-based subcommands)
|
||||
│ ├── lib.rs # Module declarations
|
||||
│ ├── gateway/ # HTTP/WS server, GatewayState initialization
|
||||
│ │ ├── mod.rs
|
||||
│ │ ├── http.rs # Health endpoint
|
||||
│ │ └── ws.rs # WebSocket handler
|
||||
│ ├── client/ # TUI chat client
|
||||
│ │ ├── mod.rs
|
||||
│ │ └── tui/ # Ratatui-based terminal UI
|
||||
│ ├── channels/ # Channel integrations
|
||||
│ │ ├── base.rs # Channel trait
|
||||
│ │ ├── cli_chat.rs # CLI WebSocket channel
|
||||
│ │ ├── feishu.rs # Feishu/Lark webhook channel
|
||||
│ │ ├── manager.rs # ChannelManager
|
||||
│ │ └── slash_command.rs # Slash command parser
|
||||
│ ├── bus/ # Async message bus
|
||||
│ │ ├── mod.rs # MessageBus (tokio mpsc channels)
|
||||
│ │ ├── message.rs # Message types
|
||||
│ │ └── dispatcher.rs # OutboundDispatcher
|
||||
│ ├── session/ # Session & dialog management
|
||||
│ │ ├── mod.rs
|
||||
│ │ ├── session.rs # Session, SessionManager
|
||||
│ │ ├── session_id.rs # UnifiedSessionId
|
||||
│ │ ├── commands.rs # SessionCommand enum
|
||||
│ │ └── events.rs # SessionEvent, DialogInfo
|
||||
│ ├── agent/ # LLM interaction loop
|
||||
│ │ ├── mod.rs
|
||||
│ │ ├── agent_loop.rs # AgentLoop (stateless)
|
||||
│ │ ├── context_compressor.rs # Token estimation & summarization
|
||||
│ │ └── system_prompt.rs # System prompt builder
|
||||
│ ├── providers/ # LLM API clients
|
||||
│ │ ├── mod.rs # Factory: create_provider()
|
||||
│ │ ├── traits.rs # LLMProvider trait
|
||||
│ │ ├── openai.rs # OpenAI-compatible client
|
||||
│ │ └── anthropic.rs # Anthropic Messages API client
|
||||
│ ├── tools/ # Agent tools
|
||||
│ │ ├── mod.rs # create_default_tools()
|
||||
│ │ ├── registry.rs # ToolRegistry
|
||||
│ │ ├── traits.rs # Tool trait, ToolResult
|
||||
│ │ ├── schema.rs # Cross-provider JSON Schema cleaner
|
||||
│ │ ├── bash.rs # Shell command execution
|
||||
│ │ ├── calculator.rs # Math expression evaluator
|
||||
│ │ ├── chat_manager.rs # Session management tool
|
||||
│ │ ├── cron.rs # Cron job management tools
|
||||
│ │ ├── file_read.rs # File reader
|
||||
│ │ ├── file_write.rs # File writer
|
||||
│ │ ├── file_edit.rs # File editor (string substitution)
|
||||
│ │ ├── get_skill.rs # Skill loader tool
|
||||
│ │ ├── http_request.rs # HTTP request tool
|
||||
│ │ ├── send_message.rs # Cross-channel messaging
|
||||
│ │ └── web_fetch.rs # Web page fetcher
|
||||
│ ├── skills/ # Skills loading from markdown files
|
||||
│ │ └── mod.rs # SkillsLoader, Skill
|
||||
│ ├── storage/ # SQLite persistence
|
||||
│ │ ├── mod.rs # Storage, schema init
|
||||
│ │ ├── session.rs # Session CRUD operations
|
||||
│ │ ├── message.rs # Message persistence
|
||||
│ │ ├── scheduler.rs # ScheduledJob, JobRun storage
|
||||
│ │ └── error.rs # StorageError
|
||||
│ ├── scheduler/ # Cron scheduler runtime
|
||||
│ │ ├── mod.rs # Scheduler, next_run_for_schedule()
|
||||
│ │ └── types.rs # Schedule enum (At/Every/Cron)
|
||||
│ ├── observability/ # Telemetry observer pattern
|
||||
│ │ └── mod.rs # Observer trait, ObserverEvent, MultiObserver
|
||||
│ ├── protocol.rs # WebSocket message types (WsInbound/WsOutbound)
|
||||
│ ├── config/ # Config loading & env substitution
|
||||
│ │ └── mod.rs # Config, LLMProviderConfig, load_env_file()
|
||||
│ └── logging.rs # Tracing subscriber init with file rotation
|
||||
├── tests/
|
||||
│ ├── test_integration.rs # LLM provider integration tests
|
||||
│ ├── test_tool_calling.rs # Tool calling integration tests
|
||||
│ ├── test_request_format.rs # Request format tests
|
||||
│ ├── test_scheduler.rs # Scheduler unit tests
|
||||
│ ├── test.env.example # Test environment template
|
||||
│ └── test.env # Actual test keys (gitignored)
|
||||
├── reference/ # Third-party reference code (do not modify)
|
||||
├── resources/ # Assets embedded in binary
|
||||
│ └── templates/ # Templates released to ~/.picobot/ on first run
|
||||
├── config.example.json # Full config example
|
||||
└── Cargo.toml
|
||||
```text
|
||||
src/
|
||||
agent/ LLM loop, context compression, system prompts, media handling, sub-agents
|
||||
bus/ Inbound, outbound, and control message queues
|
||||
channels/ CLI chat and Feishu/Lark integrations
|
||||
client/ Ratatui terminal UI
|
||||
config/ Config loading, env substitution, path expansion
|
||||
gateway/ Axum HTTP/WebSocket server and GatewayState wiring
|
||||
mcp/ MCP client connections and tool wrappers
|
||||
memory/ Memory manager and memory types
|
||||
observability/ Agent/tool telemetry observer interfaces
|
||||
providers/ OpenAI-compatible and Anthropic clients
|
||||
scheduler/ Scheduled job runtime
|
||||
session/ Session lifecycle, dialog commands, persistence integration
|
||||
skills/ Skill loading and embedded built-in skill installation
|
||||
storage/ SQLite schema and CRUD
|
||||
tools/ Agent tool implementations
|
||||
resources/
|
||||
skills/ Built-in skills embedded at build time
|
||||
templates/ Config, AGENTS.md, and USER.md templates released on first run
|
||||
tests/ Unit and ignored integration tests
|
||||
reference/ Third-party reference code; do not modify as project source
|
||||
```
|
||||
|
||||
## Key Dependencies
|
||||
|
||||
| Crate | Purpose |
|
||||
|-------|---------|
|
||||
| `axum` + `tokio-tungstenite` | HTTP server & WebSocket |
|
||||
| `sqlx` (SQLite) | Session/Message/Job persistence |
|
||||
| `reqwest` (rustls) | LLM API & external HTTP calls |
|
||||
| `ratatui` + `crossterm` | Terminal UI |
|
||||
| `clap` | CLI argument parsing |
|
||||
| `tracing` + `tracing-subscriber` | Structured logging |
|
||||
| `cron` + `chrono-tz` | Cron schedule parsing |
|
||||
| `meval` | Mathematical expression evaluation |
|
||||
| `uuid` | Session/Dialog ID generation |
|
||||
| `dirs` | Platform config directory resolution |
|
||||
| `axum`, `tokio`, `tokio-tungstenite` | Gateway and WebSocket runtime |
|
||||
| `sqlx` | SQLite persistence |
|
||||
| `reqwest` | LLM and HTTP clients |
|
||||
| `ratatui`, `crossterm`, `termimad` | Terminal UI |
|
||||
| `rmcp` | MCP client support |
|
||||
| `fantoccini` | Optional browser automation |
|
||||
| `cron`, `chrono-tz` | Scheduling |
|
||||
| `jieba-rs` | Chinese tokenization for memory search |
|
||||
| `zstd`, `tar` | Embedded built-in skill packaging |
|
||||
|
||||
16
docker-compose.yml
Normal file
16
docker-compose.yml
Normal file
@ -0,0 +1,16 @@
|
||||
services:
|
||||
picobot:
|
||||
image: picobot:latest
|
||||
container_name: picobot
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "19876:19876"
|
||||
volumes:
|
||||
- ~/.picobot/config.json:/app/.picobot/config.json:ro
|
||||
- picobot_data:/app/.picobot
|
||||
environment:
|
||||
- RUST_LOG=info
|
||||
command: gateway
|
||||
|
||||
volumes:
|
||||
picobot_data:
|
||||
359
docs/CODE_QUALITY_ANALYSIS.md
Normal file
359
docs/CODE_QUALITY_ANALYSIS.md
Normal file
@ -0,0 +1,359 @@
|
||||
# PicoBot 代码质量分析报告
|
||||
|
||||
审查日期:2026-06-15
|
||||
|
||||
## 结论摘要
|
||||
|
||||
PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只做收发,MessageBus 解耦输入输出,SessionManager 管理会话,AgentLoop 保持无状态并执行工具,Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
|
||||
|
||||
当前主要质量风险集中在三类:
|
||||
|
||||
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
|
||||
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
|
||||
3. 工具和后台任务的资源边界偏弱,文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
|
||||
|
||||
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
|
||||
|
||||
## 修复状态
|
||||
|
||||
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
||||
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
||||
- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
|
||||
|
||||
## 主要发现
|
||||
|
||||
### 已修复:CLI 会话路由会破坏会话连续性和多客户端隔离
|
||||
|
||||
位置:
|
||||
|
||||
- `src/channels/cli_chat.rs:113-126`
|
||||
- `src/channels/cli_chat.rs:160-164`
|
||||
- `src/channels/cli_chat.rs:225-249`
|
||||
- `src/channels/cli_chat.rs:479-494`
|
||||
- `src/session/session.rs:1305-1310`
|
||||
|
||||
问题:
|
||||
|
||||
`Client.current_session_id` 存的是完整 session id,但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID,而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id,但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`。
|
||||
|
||||
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client,没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
|
||||
|
||||
影响:
|
||||
|
||||
- CLI 多轮对话可能落入不同 chat scope。
|
||||
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
|
||||
- 多个 CLI 客户端同时连接时存在串话。
|
||||
|
||||
建议:
|
||||
|
||||
- 将 client 状态拆成 `chat_id` 和 `current_session_id`,不要混用。
|
||||
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
|
||||
- `send()` 按 `OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
|
||||
- `LoadSession` 应直接切换到指定 session,或通过 `SwitchDialog` 使用其中的 `dialog_id`。
|
||||
- 为 CLI WebSocket 增加多客户端路由测试。
|
||||
|
||||
### 已修复:Dialog 控制接口与协议承诺不一致
|
||||
|
||||
位置:
|
||||
|
||||
- `src/session/session.rs:996-997`
|
||||
- `src/session/session.rs:1305-1310`
|
||||
- `src/session/session.rs:1329-1349`
|
||||
- `src/session/session.rs:1378-1384`
|
||||
- `src/channels/cli_chat.rs:128-158`
|
||||
|
||||
问题:
|
||||
|
||||
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
|
||||
|
||||
- `/delete` 只创建新 session,并没有删除当前 session。
|
||||
- `get_current_dialog()` 固定返回 `Ok(None)`。
|
||||
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`。
|
||||
- `archive_dialog()` 是空操作。
|
||||
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`。
|
||||
|
||||
影响:
|
||||
|
||||
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
|
||||
|
||||
建议:
|
||||
|
||||
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
|
||||
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
|
||||
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`。
|
||||
- `list_dialogs()` 返回真实 current dialog,并补上 archived 模型或移除 archived 参数。
|
||||
|
||||
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
|
||||
|
||||
位置:
|
||||
|
||||
- `src/tools/mod.rs:56-62`
|
||||
- `src/tools/path_utils.rs:3-23`
|
||||
- `src/tools/bash.rs:146-185`
|
||||
|
||||
问题:
|
||||
|
||||
文件工具默认通过 `FileReadTool::new()`、`FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist,也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize,不能防御 `..`、符号链接等路径逃逸。
|
||||
|
||||
`bash` 默认工作目录是 `"."`,Gateway 启动时切到 workspace,这对相对路径有效,但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
|
||||
|
||||
影响:
|
||||
|
||||
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP,风险会放大。
|
||||
|
||||
建议:
|
||||
|
||||
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
|
||||
- `resolve_path()` 使用 `std::fs::canonicalize` 或 `path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
|
||||
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
||||
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
||||
|
||||
### 中高优先级:Session 锁内执行过多异步操作
|
||||
|
||||
位置:
|
||||
|
||||
- `src/session/session.rs:1001-1018`
|
||||
- `src/session/session.rs:1604-1711`
|
||||
|
||||
问题:
|
||||
|
||||
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
|
||||
|
||||
影响:
|
||||
|
||||
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
|
||||
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
||||
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
||||
|
||||
建议:
|
||||
|
||||
- 锁内只做内存状态快照和必要的状态标记。
|
||||
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
|
||||
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
|
||||
|
||||
### 中优先级:Bash 超时不会显式终止子进程
|
||||
|
||||
位置:
|
||||
|
||||
- `src/tools/bash.rs:150-174`
|
||||
- `src/tools/bash.rs:180-207`
|
||||
|
||||
问题:
|
||||
|
||||
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
|
||||
|
||||
影响:
|
||||
|
||||
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
||||
|
||||
建议:
|
||||
|
||||
- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。
|
||||
- 超时分支显式 kill child 并 wait。
|
||||
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
|
||||
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
|
||||
|
||||
### 中优先级:文件读取对大二进制文件没有输出上限
|
||||
|
||||
位置:
|
||||
|
||||
- `src/tools/file_read.rs:121-131`
|
||||
- `src/tools/file_read.rs:214-229`
|
||||
|
||||
问题:
|
||||
|
||||
`file_read` 先 `std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
|
||||
|
||||
影响:
|
||||
|
||||
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
||||
|
||||
建议:
|
||||
|
||||
- 先检查 metadata size,超过阈值直接返回提示。
|
||||
- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。
|
||||
- 对文本读取也改成流式按行读取,而不是整文件读入。
|
||||
|
||||
### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
||||
|
||||
位置:
|
||||
|
||||
- `src/tools/http_request.rs:31-59`
|
||||
|
||||
问题:
|
||||
|
||||
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
|
||||
|
||||
影响:
|
||||
|
||||
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
||||
|
||||
建议:
|
||||
|
||||
- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。
|
||||
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
|
||||
- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。
|
||||
|
||||
### 中优先级:后台任务和主循环缺少监督与优雅关闭
|
||||
|
||||
位置:
|
||||
|
||||
- `src/bus/mod.rs:51-99`
|
||||
- `src/gateway/mod.rs:187-244`
|
||||
- `src/gateway/mod.rs:247-266`
|
||||
|
||||
问题:
|
||||
|
||||
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle,也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
|
||||
|
||||
影响:
|
||||
|
||||
- 某个后台 loop 异常退出后,Gateway 不一定能发现。
|
||||
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
||||
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
||||
|
||||
建议:
|
||||
|
||||
- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。
|
||||
- 用 `CancellationToken` 贯穿 Gateway 子任务。
|
||||
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
|
||||
|
||||
### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
||||
|
||||
位置:
|
||||
|
||||
- `src/scheduler/mod.rs:18-40`
|
||||
|
||||
问题:
|
||||
|
||||
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)` 或 `upcoming(tz)` 的当前时间。
|
||||
|
||||
影响:
|
||||
|
||||
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
||||
|
||||
建议:
|
||||
|
||||
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
|
||||
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
|
||||
- 增加固定时间输入的单元测试,避免受系统时间影响。
|
||||
|
||||
### 中低优先级:存在未接入或半接入代码,增加维护噪音
|
||||
|
||||
位置:
|
||||
|
||||
- `src/tools/pty.rs`
|
||||
- `src/tools/mod.rs:1-20`
|
||||
- `src/tools/mod.rs:49-88`
|
||||
|
||||
问题:
|
||||
|
||||
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty`,`create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
|
||||
|
||||
影响:
|
||||
|
||||
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
||||
|
||||
建议:
|
||||
|
||||
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
|
||||
- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。
|
||||
|
||||
## 架构评价
|
||||
|
||||
### 做得好的地方
|
||||
|
||||
- 模块分层方向清楚:Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
|
||||
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
|
||||
- Provider 抽象简单直接,OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
|
||||
- Storage 集中初始化 schema,便于部署单二进制应用。
|
||||
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
|
||||
|
||||
### 主要架构债务
|
||||
|
||||
- SessionManager 承担过多职责:会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
|
||||
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
|
||||
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
|
||||
- 后台任务生命周期分散:gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn,缺少统一管理。
|
||||
|
||||
## 模块级分析
|
||||
|
||||
### gateway
|
||||
|
||||
`GatewayState::new()` 是清晰的装配中心:配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
|
||||
|
||||
### channels
|
||||
|
||||
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel,核心问题是会话身份混用和广播投递。
|
||||
|
||||
### bus
|
||||
|
||||
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
|
||||
|
||||
### session
|
||||
|
||||
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
|
||||
|
||||
- `manager.rs`:SessionManager 状态和 dialog 生命周期
|
||||
- `worker.rs`:per-session agent worker 和 cancellation
|
||||
- `commands.rs`:slash command 执行
|
||||
- `outbound.rs`:OutboundMessenger 实现
|
||||
- `restore.rs`:storage 恢复与 tool call chain repair
|
||||
|
||||
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
|
||||
|
||||
### agent
|
||||
|
||||
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义:`read_only()` 目前是工具自己声明,副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard,不应替代工具层的资源限制。
|
||||
|
||||
### providers
|
||||
|
||||
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response,可能包含用户隐私、API 返回内容和文件内容。
|
||||
|
||||
### tools
|
||||
|
||||
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里,默认值容易失控。
|
||||
|
||||
### storage
|
||||
|
||||
Storage schema 初始化实用,但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”,适合早期迭代,不适合长期演进。建议引入 schema version 表或 sqlx migrations,至少把每次迁移记录下来。
|
||||
|
||||
### skills
|
||||
|
||||
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
|
||||
|
||||
## 建议修复路线
|
||||
|
||||
### P0:先修会话正确性
|
||||
|
||||
1. 修正 CLI `chat_id/current_session_id` 数据模型。
|
||||
2. 修正 CLI 出站按 client/chat_id 投递。
|
||||
3. 实现 `get_current_dialog()`、`list_dialogs()` current 返回。
|
||||
4. 修正 `/delete`、`clear_history`、`archive` 的真实行为或从协议移除。
|
||||
5. 增加 WebSocket session lifecycle 测试。
|
||||
|
||||
### P1:收紧工具和资源边界
|
||||
|
||||
1. 文件工具默认限制 workspace,路径 canonicalize。
|
||||
2. bash 超时杀进程,必要时引入进程组。
|
||||
3. file_read 增加文件大小上限和二进制输出上限。
|
||||
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
|
||||
5. 明确高危工具的配置开关。
|
||||
|
||||
### P2:降低架构复杂度
|
||||
|
||||
1. 拆分 `session.rs`、`feishu.rs`、`storage/mod.rs`、`browser.rs`。
|
||||
2. 引入任务 supervisor 和统一 shutdown token。
|
||||
3. 引入正式数据库迁移。
|
||||
4. 增加工具注册快照测试,避免死代码和文档漂移。
|
||||
|
||||
## 建议测试补充
|
||||
|
||||
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
|
||||
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
|
||||
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
|
||||
- `/delete` 后旧 session 软删除、新 session 成为 current。
|
||||
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
|
||||
- bash timeout 后检查子进程不存在。
|
||||
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
|
||||
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。
|
||||
@ -1,40 +0,0 @@
|
||||
# 客户端代码整合设计
|
||||
|
||||
## 目标
|
||||
|
||||
将分散在 `src/cli/` 和 `src/client/` 的客户端代码整合到 `src/client/` 目录。
|
||||
|
||||
## 变更
|
||||
|
||||
### 目录结构
|
||||
|
||||
```
|
||||
src/
|
||||
├── client/ # 整合后的客户端模块
|
||||
│ ├── mod.rs # 主程序入口 (run 函数)
|
||||
│ ├── input.rs # InputHandler + InputCommand (从 cli/input.rs 合并)
|
||||
│ └── channel.rs # CliChannel (从 cli/channel.rs 合并)
|
||||
├── cli/ # 删除
|
||||
└── protocol.rs # 保留
|
||||
```
|
||||
|
||||
### 关键变更
|
||||
|
||||
| 变更 | 说明 |
|
||||
|------|------|
|
||||
| `InputEvent::Message(String)` | 简化为只携带文本内容,不再使用 `ChatMessage` |
|
||||
| `cli` 模块删除 | 代码合并到 `client` |
|
||||
| 解耦 | `client` 不再依赖 `bus::ChatMessage` |
|
||||
|
||||
## 实施步骤
|
||||
|
||||
1. 创建 `src/client/input.rs` - 从 `cli/input.rs` 合并,修改 `InputEvent::Message` 为 `String`
|
||||
2. 创建 `src/client/channel.rs` - 从 `cli/channel.rs` 直接复制
|
||||
3. 更新 `src/client/mod.rs` - 更新 import
|
||||
4. 更新 `src/lib.rs` - 删除 `pub mod cli;`
|
||||
5. 删除 `src/cli/` 目录
|
||||
|
||||
## 验证
|
||||
|
||||
- `cargo build` 通过
|
||||
- 功能保持不变
|
||||
@ -1,877 +0,0 @@
|
||||
# Phase 1: Storage 基础 实现计划
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** 创建 `src/storage/` 模块,实现 SQLite 持久化层,为后续 Session 扩展提供 Storage 基础设施。
|
||||
|
||||
**Architecture:** 使用 `sqlx` + `sqlite`,通过 `SqlitePool` 实现异步连接池,所有 Storage 操作均为 async,在 `Storage` 内部管理连接池的生命周期。
|
||||
|
||||
**Tech Stack:** `sqlx` (sqlite, tokio), `serde`, `chrono` (时间戳), `tokio::time::sleep` (重试退避)
|
||||
|
||||
---
|
||||
|
||||
## Task 1: 添加依赖
|
||||
|
||||
**Files:**
|
||||
- Modify: `Cargo.toml:36` (在 `[dependencies]` 末尾添加)
|
||||
|
||||
**Step 1: 添加 sqlx + sqlite 依赖**
|
||||
|
||||
在 `Cargo.toml` 末尾添加:
|
||||
|
||||
```toml
|
||||
sqlx = { version = "0.8", features = ["sqlite", "tokio", "macros", "chrono"] }
|
||||
```
|
||||
|
||||
**Step 2: 运行 cargo check 验证依赖**
|
||||
|
||||
Run: `cargo check 2>&1`
|
||||
Expected: 无报错,依赖解析成功
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add Cargo.toml
|
||||
git commit -m "deps: 添加 sqlx + sqlite 依赖"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 2: 创建 Storage Error 类型
|
||||
|
||||
**Files:**
|
||||
- Create: `src/storage/error.rs`
|
||||
|
||||
**Step 1: 编写 StorageError 枚举和测试**
|
||||
|
||||
```rust
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum StorageError {
|
||||
#[error("session not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("session already exists: {0}")]
|
||||
AlreadyExists(String),
|
||||
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
|
||||
#[error("serialization error: {0}")]
|
||||
Serialization(String),
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1 | head -30`
|
||||
Expected: 报错 "cannot find module `storage`"(因为模块未创建),这是预期的
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/error.rs
|
||||
git commit -m "feat(storage): 添加 StorageError 类型"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 3: 创建 Storage 模块骨架
|
||||
|
||||
**Files:**
|
||||
- Create: `src/storage/mod.rs`
|
||||
- Create: `src/storage/session.rs`
|
||||
- Create: `src/storage/message.rs`
|
||||
|
||||
**Step 1: 创建 `src/storage/mod.rs`**
|
||||
|
||||
```rust
|
||||
pub mod error;
|
||||
pub mod session;
|
||||
pub mod message;
|
||||
|
||||
pub use error::StorageError;
|
||||
```
|
||||
|
||||
**Step 2: 创建 `src/storage/session.rs`(空壳)**
|
||||
|
||||
```rust
|
||||
// Session CRUD 操作占位符
|
||||
```
|
||||
|
||||
**Step 3: 创建 `src/storage/message.rs`(空壳)**
|
||||
|
||||
```rust
|
||||
// Message CRUD 操作占位符
|
||||
```
|
||||
|
||||
**Step 4: 在 `src/lib.rs` 中添加 storage 模块**
|
||||
|
||||
在 `src/lib.rs` 末尾添加:
|
||||
|
||||
```rust
|
||||
pub mod storage;
|
||||
```
|
||||
|
||||
**Step 5: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1`
|
||||
Expected: 编译成功(空壳模块)
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/ src/lib.rs
|
||||
git commit -m "feat(storage): 创建 storage 模块骨架"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 4: 实现 Storage 主结构
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/mod.rs`
|
||||
|
||||
**Step 1: 编写 Storage 结构和初始化逻辑**
|
||||
|
||||
```rust
|
||||
use sqlx::{Pool, Sqlite, SqlitePool};
|
||||
use std::path::Path;
|
||||
|
||||
pub struct Storage {
|
||||
pool: Pool<Sqlite>,
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
/// 打开或创建数据库
|
||||
pub async fn new(db_path: &Path) -> Result<Self, StorageError> {
|
||||
let database_url = format!("sqlite:{}?mode=rwc", db_path.display());
|
||||
let pool = SqlitePool::connect(&database_url).await?;
|
||||
|
||||
let storage = Self { pool };
|
||||
storage.init_schema().await?;
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
/// 初始化数据库 schema
|
||||
async fn init_schema(&self) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
channel TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
dialog_id TEXT NOT NULL,
|
||||
title TEXT NOT NULL DEFAULT '新对话',
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active_at INTEGER NOT NULL,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
routing_info TEXT,
|
||||
deleted_at INTEGER,
|
||||
UNIQUE(channel, chat_id, dialog_id)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_chat
|
||||
ON sessions(channel, chat_id, deleted_at)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
media_refs TEXT,
|
||||
tool_call_id TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
|
||||
ON messages(session_id, seq)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取连接池引用(供内部 CRUD 使用)
|
||||
pub(crate) fn pool(&self) -> &Pool<Sqlite> {
|
||||
&self.pool
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1`
|
||||
Expected: 编译成功
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/mod.rs
|
||||
git commit -m "feat(storage): 实现 Storage 主结构和初始化"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 5: 定义 SessionMeta 和 MessageMeta 数据结构
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/session.rs`
|
||||
- Modify: `src/storage/message.rs`
|
||||
|
||||
**Step 1: 在 `session.rs` 中定义 SessionMeta**
|
||||
|
||||
```rust
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionMeta {
|
||||
pub id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub dialog_id: String,
|
||||
pub title: String,
|
||||
pub created_at: i64,
|
||||
pub last_active_at: i64,
|
||||
pub message_count: i64,
|
||||
pub routing_info: Option<String>,
|
||||
pub deleted_at: Option<i64>,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 在 `message.rs` 中定义 MessageMeta**
|
||||
|
||||
```rust
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageMeta {
|
||||
pub id: String,
|
||||
pub session_id: String,
|
||||
pub seq: i64,
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub media_refs: Option<String>,
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_name: Option<String>,
|
||||
pub tool_calls: Option<String>,
|
||||
pub created_at: i64,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1`
|
||||
Expected: 编译成功
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/session.rs src/storage/message.rs
|
||||
git commit -m "feat(storage): 定义 SessionMeta 和 MessageMeta 数据结构"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 6: 实现 Session CRUD 操作
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/session.rs`
|
||||
|
||||
**Step 1: 编写 upsert_session**
|
||||
|
||||
```rust
|
||||
use sqlx::Row;
|
||||
use super::SessionMeta;
|
||||
|
||||
impl Storage {
|
||||
pub async fn upsert_session(&self, meta: &SessionMeta) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
title = excluded.title,
|
||||
last_active_at = excluded.last_active_at,
|
||||
message_count = excluded.message_count,
|
||||
routing_info = excluded.routing_info,
|
||||
deleted_at = excluded.deleted_at
|
||||
"#,
|
||||
)
|
||||
.bind(&meta.id)
|
||||
.bind(&meta.channel)
|
||||
.bind(&meta.chat_id)
|
||||
.bind(&meta.dialog_id)
|
||||
.bind(&meta.title)
|
||||
.bind(meta.created_at)
|
||||
.bind(meta.last_active_at)
|
||||
.bind(meta.message_count)
|
||||
.bind(&meta.routing_info)
|
||||
.bind(meta.deleted_at)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_session(&self, id: &str) -> Result<SessionMeta, StorageError> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(self.pool())
|
||||
.await?
|
||||
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
|
||||
|
||||
Ok(SessionMeta {
|
||||
id: row.get("id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
dialog_id: row.get("dialog_id"),
|
||||
title: row.get("title"),
|
||||
created_at: row.get("created_at"),
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_sessions(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
limit: i64,
|
||||
) -> Result<Vec<SessionMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||
FROM sessions
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
)
|
||||
.bind(channel)
|
||||
.bind(chat_id)
|
||||
.bind(limit)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|row| SessionMeta {
|
||||
id: row.get("id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
dialog_id: row.get("dialog_id"),
|
||||
title: row.get("title"),
|
||||
created_at: row.get("created_at"),
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn touch_session(
|
||||
&self,
|
||||
id: &str,
|
||||
message_count: i64,
|
||||
last_active_at: i64,
|
||||
) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE sessions SET message_count = ?, last_active_at = ?
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(message_count)
|
||||
.bind(last_active_at)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
sqlx::query(
|
||||
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
|
||||
)
|
||||
.bind(now)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 查找 channel:chat_id 下最近活跃且未过期的 session
|
||||
pub async fn find_active_session(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
ttl_millis: i64,
|
||||
) -> Result<Option<SessionMeta>, StorageError> {
|
||||
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||
FROM sessions
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
)
|
||||
.bind(channel)
|
||||
.bind(chat_id)
|
||||
.bind(cutoff)
|
||||
.fetch_optional(self.pool())
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some(row) => Ok(Some(SessionMeta {
|
||||
id: row.get("id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
dialog_id: row.get("dialog_id"),
|
||||
title: row.get("title"),
|
||||
created_at: row.get("created_at"),
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
})),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> 注意:`Storage` 的 CRUD 方法需要能访问 `pool()`,但目前 `pool()` 是 `pub(crate)`。在 `mod.rs` 中为 `session.rs` 实现 `Storage` 的 CRUD,所以同模块内可访问。
|
||||
|
||||
**Step 2: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1`
|
||||
Expected: 编译成功
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/session.rs
|
||||
git commit -m "feat(storage): 实现 Session CRUD 操作"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 7: 实现 Message CRUD 操作
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/message.rs`
|
||||
|
||||
**Step 1: 编写 Message CRUD**
|
||||
|
||||
```rust
|
||||
use sqlx::Row;
|
||||
use super::MessageMeta;
|
||||
|
||||
impl Storage {
|
||||
pub async fn append_message(&self, session_id: &str, msg: &MessageMeta) -> Result<i64, StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&msg.id)
|
||||
.bind(session_id)
|
||||
.bind(msg.seq)
|
||||
.bind(&msg.role)
|
||||
.bind(&msg.content)
|
||||
.bind(&msg.media_refs)
|
||||
.bind(&msg.tool_call_id)
|
||||
.bind(&msg.tool_name)
|
||||
.bind(&msg.tool_calls)
|
||||
.bind(msg.created_at)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(msg.seq)
|
||||
}
|
||||
|
||||
pub async fn append_messages(
|
||||
&self,
|
||||
session_id: &str,
|
||||
msgs: &[MessageMeta],
|
||||
) -> Result<Vec<i64>, StorageError> {
|
||||
let mut seqs = Vec::with_capacity(msgs.len());
|
||||
for msg in msgs {
|
||||
let seq = self.append_message(session_id, msg).await?;
|
||||
seqs.push(seq);
|
||||
}
|
||||
Ok(seqs)
|
||||
}
|
||||
|
||||
pub async fn load_messages(
|
||||
&self,
|
||||
session_id: &str,
|
||||
from_seq: i64,
|
||||
) -> Result<Vec<MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND seq >= ?
|
||||
ORDER BY seq ASC
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(from_seq)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|row| MessageMeta {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
tool_calls: row.get("tool_calls"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
||||
.bind(session_id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> 注意:同样在 `mod.rs` 中实现,这样 `Storage` 的方法对 `message.rs` 中的 impl 可见。
|
||||
|
||||
**Step 2: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1`
|
||||
Expected: 编译成功
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/message.rs
|
||||
git commit -m "feat(storage): 实现 Message CRUD 操作"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 8: 实现写入重试逻辑
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/mod.rs`
|
||||
|
||||
**Step 1: 在 Storage 中添加带重试的 append_message**
|
||||
|
||||
在 `mod.rs` 的 `Storage` impl 块中添加:
|
||||
|
||||
```rust
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
impl Storage {
|
||||
/// 追加消息,带重试逻辑
|
||||
/// 重试 3 次(100/200/300ms 退避),仍失败返回错误
|
||||
pub async fn append_message_with_retry(
|
||||
&self,
|
||||
session_id: &str,
|
||||
msg: &MessageMeta,
|
||||
) -> Result<i64, StorageError> {
|
||||
let delays = [100, 200, 300];
|
||||
|
||||
for (i, delay) in delays.iter().enumerate() {
|
||||
match self.append_message(session_id, msg).await {
|
||||
Ok(seq) => return Ok(seq),
|
||||
Err(e) if i < delays.len() - 1 => {
|
||||
sleep(Duration::from_millis(*delay)).await;
|
||||
tracing::warn!("Storage write failed, retrying: {}", e);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Storage write failed after retries: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> 注意:需要 `use sqlx::Row;` 在 `mod.rs` 中。
|
||||
|
||||
**Step 2: 验证编译**
|
||||
|
||||
Run: `cargo build --lib 2>&1`
|
||||
Expected: 编译成功
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/mod.rs
|
||||
git commit -m "feat(storage): 添加写入重试逻辑"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 9: 编写 Storage 单元测试
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/mod.rs`(添加测试模块)
|
||||
|
||||
**Step 1: 编写 Storage 集成测试**
|
||||
|
||||
在 `src/storage/mod.rs` 末尾添加:
|
||||
|
||||
```rust
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use std::path::Path;
|
||||
|
||||
async fn create_test_storage() -> (Storage, impl Fn()) {
|
||||
let dir = tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let storage = Storage::new(&db_path).await.unwrap();
|
||||
|
||||
let cleanup = || {
|
||||
drop(dir);
|
||||
};
|
||||
|
||||
(storage, cleanup)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_upsert_and_get_session() {
|
||||
let (storage, cleanup) = create_test_storage().await;
|
||||
defer { cleanup(); }
|
||||
|
||||
let meta = SessionMeta {
|
||||
id: "cli_chat:sid123:dialog1".to_string(),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid123".to_string(),
|
||||
dialog_id: "dialog1".to_string(),
|
||||
title: "测试会话".to_string(),
|
||||
created_at: 1000,
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
|
||||
let loaded = storage.get_session(&meta.id).await.unwrap();
|
||||
assert_eq!(loaded.title, "测试会话");
|
||||
assert_eq!(loaded.channel, "cli_chat");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_nonexistent_session() {
|
||||
let (storage, cleanup) = create_test_storage().await;
|
||||
defer { cleanup(); }
|
||||
|
||||
let result = storage.get_session("nonexistent").await;
|
||||
assert!(result.is_err());
|
||||
matches!(result.unwrap_err(), StorageError::NotFound(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_sessions() {
|
||||
let (storage, cleanup) = create_test_storage().await;
|
||||
defer { cleanup(); }
|
||||
|
||||
for i in 0..5 {
|
||||
let meta = SessionMeta {
|
||||
id: format!("cli_chat:sid123:dialog{}", i),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid123".to_string(),
|
||||
dialog_id: format!("dialog{}", i),
|
||||
title: format!("会话{}", i),
|
||||
created_at: i as i64 * 1000,
|
||||
last_active_at: i as i64 * 1000,
|
||||
message_count: i,
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
}
|
||||
|
||||
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
|
||||
assert_eq!(sessions.len(), 5);
|
||||
// 按 last_active_at DESC 排序
|
||||
assert_eq!(sessions[0].dialog_id, "dialog4");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_soft_delete() {
|
||||
let (storage, cleanup) = create_test_storage().await;
|
||||
defer { cleanup(); }
|
||||
|
||||
let meta = SessionMeta {
|
||||
id: "cli_chat:sid123:dialog1".to_string(),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid123".to_string(),
|
||||
dialog_id: "dialog1".to_string(),
|
||||
title: "测试".to_string(),
|
||||
created_at: 1000,
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
storage.soft_delete_session(&meta.id).await.unwrap();
|
||||
|
||||
let result = storage.get_session(&meta.id).await;
|
||||
assert!(result.is_err());
|
||||
matches!(result.unwrap_err(), StorageError::NotFound(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_append_and_load_messages() {
|
||||
let (storage, cleanup) = create_test_storage().await;
|
||||
defer { cleanup(); }
|
||||
|
||||
let session_meta = SessionMeta {
|
||||
id: "cli_chat:sid123:dialog1".to_string(),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid123".to_string(),
|
||||
dialog_id: "dialog1".to_string(),
|
||||
title: "测试".to_string(),
|
||||
created_at: 1000,
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
};
|
||||
storage.upsert_session(&session_meta).await.unwrap();
|
||||
|
||||
let msg = MessageMeta {
|
||||
id: "msg1".to_string(),
|
||||
session_id: session_meta.id.clone(),
|
||||
seq: 1,
|
||||
role: "user".to_string(),
|
||||
content: "你好".to_string(),
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
created_at: 1000,
|
||||
};
|
||||
|
||||
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
|
||||
assert_eq!(seq, 1);
|
||||
|
||||
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
||||
assert_eq!(loaded.len(), 1);
|
||||
assert_eq!(loaded[0].content, "你好");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_touch_session() {
|
||||
let (storage, cleanup) = create_test_storage().await;
|
||||
defer { cleanup(); }
|
||||
|
||||
let meta = SessionMeta {
|
||||
id: "cli_chat:sid123:dialog1".to_string(),
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "sid123".to_string(),
|
||||
dialog_id: "dialog1".to_string(),
|
||||
title: "测试".to_string(),
|
||||
created_at: 1000,
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: None,
|
||||
deleted_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
|
||||
storage.touch_session(&meta.id, 5, 2000).await.unwrap();
|
||||
|
||||
let loaded = storage.get_session(&meta.id).await.unwrap();
|
||||
assert_eq!(loaded.message_count, 5);
|
||||
assert_eq!(loaded.last_active_at, 2000);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> 需要在 `Cargo.toml` 中添加 `tempfile` 依赖(已存在)。`defer` 宏可自己实现一个简单的:`fn defer<F: FnOnce()>(f: F) { f() }`
|
||||
|
||||
**Step 2: 运行测试**
|
||||
|
||||
Run: `cargo test storage::tests --lib 2>&1`
|
||||
Expected: 所有 7 个测试 PASS
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/mod.rs
|
||||
git commit -m "test(storage): 编写 Storage 单元测试"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 汇总
|
||||
|
||||
| Task | 改动文件 | 关键交付物 |
|
||||
|------|----------|-----------|
|
||||
| 1 | `Cargo.toml` | sqlx 依赖 |
|
||||
| 2 | `src/storage/error.rs` | StorageError |
|
||||
| 3 | `src/storage/{mod.rs,session.rs,message.rs}`, `src/lib.rs` | 模块骨架 |
|
||||
| 4 | `src/storage/mod.rs` | Storage 结构 + init_schema |
|
||||
| 5 | `src/storage/session.rs`, `message.rs` | SessionMeta, MessageMeta |
|
||||
| 6 | `src/storage/session.rs` | Session CRUD |
|
||||
| 7 | `src/storage/message.rs` | Message CRUD |
|
||||
| 8 | `src/storage/mod.rs` | append_message_with_retry |
|
||||
| 9 | `src/storage/mod.rs` | 7 个单元测试 |
|
||||
|
||||
**Phase 1 完成后:** Storage 模块可独立使用,具备完整的持久化能力,可安全地集成到 Session 和 SessionManager 中。
|
||||
@ -1,278 +0,0 @@
|
||||
# Session 持久化设计方案
|
||||
|
||||
## 概述
|
||||
|
||||
为 PicoBot 添加 SQLite 持久化层,实现 Session 数据的持久化、完整 Dialog 生命周期管理、消息实时落盘、以及基于 TTL 的自动内存清理。
|
||||
|
||||
## 核心概念
|
||||
|
||||
```
|
||||
UnifiedSessionId = {channel}:{chat_id}:{dialog_id}
|
||||
Session = Dialog(两者等价,不再分层)
|
||||
```
|
||||
|
||||
每个 Session 独立管理自己的消息历史、LLM 配置和路由信息。
|
||||
|
||||
## 数据库 Schema
|
||||
|
||||
### sessions 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
channel TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
dialog_id TEXT NOT NULL,
|
||||
title TEXT NOT NULL DEFAULT '新对话',
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active_at INTEGER NOT NULL,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
routing_info TEXT,
|
||||
deleted_at INTEGER,
|
||||
UNIQUE(channel, chat_id, dialog_id)
|
||||
);
|
||||
CREATE INDEX idx_sessions_chat ON sessions(channel, chat_id, deleted_at);
|
||||
```
|
||||
|
||||
### messages 表
|
||||
|
||||
```sql
|
||||
CREATE TABLE messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
media_refs TEXT,
|
||||
tool_call_id TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_messages_session_seq ON messages(session_id, seq);
|
||||
```
|
||||
|
||||
## Storage API
|
||||
|
||||
### Session 操作
|
||||
|
||||
| 方法 | 说明 |
|
||||
|------|------|
|
||||
| `new(db_path) -> Storage` | 打开/创建数据库 |
|
||||
| `upsert_session(meta) -> Result<(), StorageError>` | 插入或更新 session 元数据 |
|
||||
| `get_session(id) -> Result<SessionMeta, StorageError>` | 获取单个 session |
|
||||
| `list_sessions(channel, chat_id, limit) -> Result<Vec<SessionMeta>>` | 最近 N 条 |
|
||||
| `touch_session(id, message_count, last_active_at)` | 更新计数和最后活跃时间 |
|
||||
| `soft_delete_session(id) -> Result<(), StorageError>` | 软删除 |
|
||||
|
||||
### Message 操作
|
||||
|
||||
| 方法 | 说明 |
|
||||
|------|------|
|
||||
| `append_message(session_id, msg) -> Result<i64, StorageError>` | 追加单条消息,返回 seq |
|
||||
| `append_messages(session_id, msgs) -> Result<Vec<i64>, StorageError>` | 批量追加 |
|
||||
| `load_messages(session_id, from_seq) -> Result<Vec<MessageMeta>>` | 从指定 seq 加载 |
|
||||
| `clear_messages(session_id) -> Result<(), StorageError>` | 清除消息(保留 session) |
|
||||
|
||||
### 写入失败处理
|
||||
|
||||
重试 3 次(100/200/300ms 退避),仍失败则发送系统通知告警。
|
||||
|
||||
## Session 结构
|
||||
|
||||
```rust
|
||||
pub struct Session {
|
||||
pub id: UnifiedSessionId,
|
||||
pub title: String,
|
||||
pub created_at: i64,
|
||||
pub last_active_at: i64,
|
||||
pub message_count: i64, // 用户消息计数
|
||||
pub total_message_count: i64, // 含系统消息
|
||||
|
||||
messages: Vec<ChatMessage>, // 内存消息历史
|
||||
seq_counter: i64, // 下一个消息的 seq
|
||||
|
||||
provider_config: LLMProviderConfig,
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
compressor: ContextCompressor,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
storage: Arc<Storage>, // 持久化 sink
|
||||
routing_info: String, // JSON 路由信息
|
||||
}
|
||||
```
|
||||
|
||||
### 初始化流程
|
||||
|
||||
```
|
||||
new() 或 from_storage()
|
||||
↓
|
||||
注入 storage 引用
|
||||
↓
|
||||
创建 provider, tools, compressor
|
||||
↓
|
||||
从 Storage 加载 messages(from_seq = 0)
|
||||
↓
|
||||
设置 seq_counter = messages.len() + 1
|
||||
↓
|
||||
返回 Session 实例
|
||||
```
|
||||
|
||||
## handle_message 流程
|
||||
|
||||
```
|
||||
handle_message(channel, chat_id, sender_id, content, media)
|
||||
│
|
||||
├── 1. 确定 dialog_id
|
||||
│ │
|
||||
│ ├── 显式传入 dialog_id → 使用
|
||||
│ └── 无 dialog_id
|
||||
│ ├── 查找 channel:chat_id 下最近活跃且未过期的 session
|
||||
│ ├── 找到 → 使用该 session
|
||||
│ └── 未找到 → 创建新 session(dialog_id = 新随机 ID)
|
||||
│
|
||||
├── 2. 获取或创建 Session
|
||||
│ 有 → 更新 session_timestamps
|
||||
│ 无 → 从 Storage 恢复 或 创建新 Session
|
||||
│
|
||||
├── 3. 追加用户消息并持久化
|
||||
│ seq = seq_counter; seq_counter += 1
|
||||
│ Storage.append_message()(失败重试 → 告警)
|
||||
│ messages.push(user_msg)
|
||||
│ message_count += 1
|
||||
│
|
||||
├── 4. 检查 title 自动生成
|
||||
│ message_count == 10 且 title == 默认值 → LLM 生成 → 更新 title → 写回 Storage
|
||||
│
|
||||
├── 5. 注入 skills_prompt
|
||||
│
|
||||
├── 6. 新 session 注入欢迎消息(系统消息,不计入 message_count)
|
||||
│
|
||||
├── 7. 上下文压缩(如需要)
|
||||
│
|
||||
├── 8. 调用 AgentLoop
|
||||
│
|
||||
├── 9. 持久化 Agent 响应
|
||||
│
|
||||
└── 10. 返回响应
|
||||
```
|
||||
|
||||
## Dialog 生命周期命令
|
||||
|
||||
| 命令 | 行为 |
|
||||
|------|------|
|
||||
| `/new [标题]` | 创建新 dialog(新随机 dialog_id),新建 Session |
|
||||
| `/sessions` | 列出 channel:chat_id 下最近 10 条 session(按 last_active_at 倒序) |
|
||||
| `/switch <dialog_id>` | 切换到指定 session(从 Storage 恢复或内存命中) |
|
||||
| `/rename <新标题>` | 重命名当前 session |
|
||||
| `/delete` | 软删除当前 session(内存移除 + Storage 标记 deleted_at) |
|
||||
| `/info` | 显示当前 session 信息 |
|
||||
| `/compact` | 手动触发上下文压缩 |
|
||||
|
||||
## 路由信息
|
||||
|
||||
每种 Channel 在创建 Session 时注入路由信息:
|
||||
|
||||
```rust
|
||||
// CLI
|
||||
routing_info = json!({"type": "cli", "ws_sender_id": "xxx"})
|
||||
|
||||
// Feishu
|
||||
routing_info = json!({"type": "feishu", "open_conversation_id": "oc_xxx", "tenant_key": "xxx"})
|
||||
```
|
||||
|
||||
## Title 自动生成
|
||||
|
||||
调用时机:
|
||||
1. Session 首次创建时(初始 title = "新对话")
|
||||
2. `message_count` 达到 10 且 title 仍为默认值时,自动更新
|
||||
|
||||
生成 Prompt:
|
||||
```
|
||||
给定以下对话历史,生成一个简短的会话标题(5-15 个中文字符),
|
||||
概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
|
||||
|
||||
历史:
|
||||
{messages}
|
||||
```
|
||||
|
||||
## TTL 清理
|
||||
|
||||
- 内存 session 超时 → 释放内存,Storage 记录保留
|
||||
- 用户切换回该 session → 从 Storage 重新加载到内存
|
||||
- Storage 中的 session 记录通过 `deleted_at` 软删除,不会物理删除
|
||||
|
||||
## 文件结构
|
||||
|
||||
```
|
||||
src/
|
||||
├── storage/
|
||||
│ ├── mod.rs # Storage 主模块
|
||||
│ ├── session.rs # Session CRUD
|
||||
│ ├── message.rs # Message CRUD
|
||||
│ └── error.rs # StorageError
|
||||
│
|
||||
└── session/
|
||||
├── mod.rs # 导出 Session, SessionManager
|
||||
├── session.rs # Session, SessionManager 实现
|
||||
├── session_id.rs # UnifiedSessionId
|
||||
├── commands.rs # SessionCommand
|
||||
├── events.rs # SessionEvent, DialogInfo
|
||||
└── error.rs # SessionError
|
||||
```
|
||||
|
||||
## 实现顺序
|
||||
|
||||
### Phase 1: Storage 基础
|
||||
1. 添加 `sqlx` + `sqlite` 依赖
|
||||
2. 实现 `Storage` 结构(连接池、初始化)
|
||||
3. Session CRUD + Message CRUD
|
||||
4. 写入重试逻辑
|
||||
5. 单元测试
|
||||
|
||||
### Phase 2: Session 扩展
|
||||
1. 扩展 `Session` 结构(添加 storage、routing_info、计数字段、seq_counter)
|
||||
2. `from_storage()` 恢复逻辑
|
||||
3. `add_message` 持久化集成
|
||||
4. `send_system_notification` 接口
|
||||
5. Title 自动生成
|
||||
|
||||
### Phase 3: SessionManager 完善
|
||||
1. 注入 `Arc<Storage>`
|
||||
2. 实现 `list_dialogs()`
|
||||
3. 实现 `switch_dialog()`
|
||||
4. 实现 `delete_dialog()` / `rename_dialog()`
|
||||
5. 后台 TTL 清理任务
|
||||
6. 集成测试
|
||||
|
||||
### Phase 4: 斜杠命令
|
||||
1. 实现 `/sessions`
|
||||
2. 实现 `/switch`
|
||||
3. 实现 `/rename`
|
||||
4. 实现 `/delete`
|
||||
5. 端到端测试
|
||||
|
||||
## 配置项
|
||||
|
||||
```json
|
||||
{
|
||||
"session": {
|
||||
"ttl_hours": 24,
|
||||
"cleanup_interval_minutes": 60,
|
||||
"auto_title_after_n_messages": 10,
|
||||
"storage_retry_delays_ms": [100, 200, 300]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 与现有代码的冲突点
|
||||
|
||||
| 冲突 | 处理方式 |
|
||||
|------|----------|
|
||||
| `DialogInfo` 有 `archived_at` | 删除该字段,改用 `deleted_at` |
|
||||
| `SessionCommand::ArchiveDialog` | 删除 |
|
||||
| `/new` 现有行为 | 改为创建新 session(新 dialog_id) |
|
||||
| 现有 `Session` 无 storage/routing_info | 扩展结构,新增 `from_storage()` |
|
||||
| `SessionManager` 需注入 `Arc<Storage>` | 扩展构造方法 |
|
||||
| stub 方法 | 实现 |
|
||||
@ -1,226 +0,0 @@
|
||||
# PicoBot Memory System Design
|
||||
|
||||
Date: 2026-05-07
|
||||
|
||||
## 1. Overview
|
||||
|
||||
Introduce a memory system that allows PicoBot agents to remember user preferences, project context, facts, and conversation history across sessions. The memory system is **unified with the existing context compression pipeline**: compression automatically produces `timeline` memory entries and advances a `last_consolidated_at` pointer to avoid redundant reprocessing.
|
||||
|
||||
### Design Principles
|
||||
|
||||
- **Compression is memory** (inspired by nanobot): when old messages are compressed, the summary is persisted — not discarded
|
||||
- **FTS5 only** (no vector embeddings): keyword search via SQLite FTS5, sufficient for current scale
|
||||
- **Extend existing infrastructure**: reuse `Storage` connection pool, `ContextCompressor`, `SystemPromptBuilder`
|
||||
- **YAGNI**: no knowledge graph, no response cache, no namespace isolation, no audit trail
|
||||
|
||||
## 2. Core Architecture
|
||||
|
||||
```
|
||||
ContextCompressor (existing) MemoryManager (new)
|
||||
│ │
|
||||
│ compress_if_needed() │ store / recall / forget
|
||||
│ ├─ LLM summary → inject │
|
||||
│ └─ store(timeline entry) ──────┘
|
||||
│ └─ advance last_consolidated_at
|
||||
│
|
||||
SystemPromptBuilder ── recall(knowledge, limit=5) ──→ inject into system prompt
|
||||
AgentLoop ── after_turn ──→ memory_store / memory_recall / memory_forget tools
|
||||
```
|
||||
|
||||
## 3. Memory Categories
|
||||
|
||||
| Category | Purpose | Written By | Retrieved By |
|
||||
|----------|---------|-----------|--------------|
|
||||
| `knowledge` | Long-term facts, preferences, patterns, insights | Agent via `memory_store` tool | FTS5 → injected into system prompt every turn |
|
||||
| `timeline` | Compressed conversation summaries | ContextCompressor automatically | FTS5 + time-range queries |
|
||||
|
||||
## 4. Storage Schema
|
||||
|
||||
### New table: `memories`
|
||||
|
||||
Added to the existing `Storage` initialization in `src/storage/mod.rs`:
|
||||
|
||||
```sql
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
key TEXT NOT NULL UNIQUE,
|
||||
content TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'knowledge',
|
||||
importance REAL NOT NULL DEFAULT 0.5,
|
||||
session_id TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
|
||||
key,
|
||||
content,
|
||||
content=memories,
|
||||
content_rowid=rowid
|
||||
);
|
||||
```
|
||||
|
||||
### Modified table: `sessions`
|
||||
|
||||
```sql
|
||||
ALTER TABLE sessions ADD COLUMN last_consolidated_at INTEGER;
|
||||
```
|
||||
|
||||
## 5. Unified Compression-Memory Pipeline
|
||||
|
||||
### Trigger Conditions
|
||||
|
||||
Compression/consolidation fires when **any** of these conditions is met:
|
||||
|
||||
| Condition | Value | Rationale |
|
||||
|-----------|-------|-----------|
|
||||
| Token budget exceeds 50% threshold | `context_window / 2` | Primary trigger — context is getting full |
|
||||
| Accumulated N turns without consolidation | 3 (configurable) | Catch-up for short messages that don't hit token threshold |
|
||||
| Session idle | 10 minutes (configurable) | Important for async channels like Feishu |
|
||||
|
||||
### Flow
|
||||
|
||||
```
|
||||
compress_if_needed(history, session_id):
|
||||
1. Read last_consolidated_at from session
|
||||
→ Only compress messages after that timestamp
|
||||
2. If no messages to compress → return history unchanged
|
||||
3. FTS5 recall(user_input, limit=recall_limit, category=knowledge)
|
||||
→ Inject relevant facts into system prompt
|
||||
4. LLM summarization of old messages → [Context Summary]
|
||||
→ Inject into current conversation
|
||||
5. Store summary as timeline entry:
|
||||
key: "ctx_{session_id}_{uuid}"
|
||||
content: "[YYYY-MM-DD HH:MM] summary text..."
|
||||
category: timeline
|
||||
6. UPDATE sessions.last_consolidated_at = now()
|
||||
7. Return compressed history
|
||||
```
|
||||
|
||||
### timeline Entry Format
|
||||
|
||||
Each timeline entry follows nanobot's convention:
|
||||
```
|
||||
[2026-05-07 14:30] User asked about Rust async patterns. Discussed tokio::select!,
|
||||
semaphore-based rate limiting, and backpressure strategies. No code was written.
|
||||
```
|
||||
|
||||
This format is grep-friendly and human-readable.
|
||||
|
||||
## 6. Retrieval Strategy
|
||||
|
||||
### Automatic Retrieval (every turn)
|
||||
|
||||
`SystemPromptBuilder.build_system_prompt()` calls:
|
||||
```rust
|
||||
memory.recall(query=user_message, limit=recall_limit, category=knowledge)
|
||||
```
|
||||
|
||||
Results sorted by FTS5 BM25 score, injected as:
|
||||
```
|
||||
## Memory Context
|
||||
|
||||
- user_prefers_rust: User prefers Rust for all backend projects
|
||||
- project_picobot_stack: PicoBot uses Rust, axum, sqlx, ratatui, tokio
|
||||
- user_workflow: User prefers TDD workflow with cargo test --lib
|
||||
```
|
||||
|
||||
### Agent-Initiated Retrieval
|
||||
|
||||
Agent uses `memory_recall` tool with optional `category`, `since`, `until` parameters.
|
||||
|
||||
### Fallback
|
||||
|
||||
If FTS5 returns empty results, fallback to `LIKE '%keyword%'` on `key` and `content` columns.
|
||||
|
||||
## 7. Agent Tools
|
||||
|
||||
| Tool | Parameters | Description |
|
||||
|------|-----------|-------------|
|
||||
| `memory_store` | `key: str`, `content: str`, `category: str`, `importance?: f64` | Write or update a memory entry. Key is semantic identifier (e.g., "user_language_pref") |
|
||||
| `memory_recall` | `query: str`, `category?: str`, `since?: i64`, `until?: i64`, `limit?: usize` | Search memories by keyword and optional filters |
|
||||
| `memory_forget` | `key: str` | Delete a memory entry by key |
|
||||
|
||||
## 8. Error Handling & Degradation
|
||||
|
||||
| Scenario | Strategy |
|
||||
|----------|----------|
|
||||
| Consolidation LLM call fails | Log warning, increment failure counter, do NOT block main flow |
|
||||
| Consecutive failures >= 3 | Degrade: append raw message dump to timeline with `[RAW]` prefix, reset counter |
|
||||
| FTS5 recall returns empty | Fallback to `LIKE '%keyword%'` query |
|
||||
| `memory.enabled = false` | ContextCompressor works normally, no memory writes |
|
||||
| MemoryManager uninitialized | ContextCompressor works with feature-gated memory write path |
|
||||
|
||||
## 9. Configuration
|
||||
|
||||
```json
|
||||
{
|
||||
"memory": {
|
||||
"enabled": true,
|
||||
"consolidation_provider": "openai",
|
||||
"consolidation_model": "gpt-4o-mini",
|
||||
"recall_limit": 5,
|
||||
"consolidation_turn_threshold": 3,
|
||||
"idle_consolidation_minutes": 10,
|
||||
"timeline_retention_days": 90,
|
||||
"max_failures_before_degrade": 3
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Key | Type | Default | Description |
|
||||
|-----|------|---------|-------------|
|
||||
| `enabled` | bool | `false` | Master switch for memory system |
|
||||
| `consolidation_provider` | string | — | Provider name for consolidation LLM calls |
|
||||
| `consolidation_model` | string | — | Model name for consolidation |
|
||||
| `recall_limit` | usize | `5` | Max knowledge entries injected into system prompt |
|
||||
| `consolidation_turn_threshold` | usize | `3` | Turns before forced consolidation |
|
||||
| `idle_consolidation_minutes` | u64 | `10` | Idle time before consolidation trigger |
|
||||
| `timeline_retention_days` | u64 | `90` | Auto-cleanup age for timeline entries |
|
||||
| `max_failures_before_degrade` | usize | `3` | Consecutive failures before raw archive fallback |
|
||||
|
||||
## 10. New Module Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── memory/
|
||||
│ ├── mod.rs # MemoryManager, MemoryConfig
|
||||
│ ├── types.rs # MemoryEntry, MemoryCategory, ConsolidationResult
|
||||
│ └── consolidation.rs # Consolidation prompt + LLM call logic
|
||||
├── storage/
|
||||
│ └── memory.rs # SQLite CRUD for memories table + FTS5
|
||||
├── tools/
|
||||
│ ├── memory_store.rs # memory_store tool
|
||||
│ ├── memory_recall.rs # memory_recall tool
|
||||
│ └── memory_forget.rs # memory_forget tool
|
||||
```
|
||||
|
||||
## 11. Integration Points (Existing Files Modified)
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `src/lib.rs` | Add `pub mod memory;` |
|
||||
| `src/config/mod.rs` | Add `MemoryConfig` struct and deserialization |
|
||||
| `src/storage/mod.rs` | Add `pub mod memory;`, init `memories` table and FTS5 in `init_schema()` |
|
||||
| `src/storage/session.rs` | Add `last_consolidated_at` column read/write |
|
||||
| `src/session/session.rs` | Add `last_consolidated_at: Option<i64>` field to Session |
|
||||
| `src/agent/context_compressor.rs` | Add `memory: Option<Arc<MemoryManager>>` field, write timeline on compress |
|
||||
| `src/agent/system_prompt.rs` | Add `memory_context` section via `MemoryManager::recall()` |
|
||||
| `src/agent/agent_loop.rs` | No changes (tools registered via ToolRegistry) |
|
||||
| `src/tools/mod.rs` | Register `memory_store`, `memory_recall`, `memory_forget` in `create_default_tools()` |
|
||||
| `src/gateway/mod.rs` | Initialize `MemoryManager` in `GatewayState::new()`, pass to ContextCompressor |
|
||||
|
||||
## 12. Implementation Order
|
||||
|
||||
| # | Task | Dependencies |
|
||||
|---|------|-------------|
|
||||
| 1 | Types: `MemoryEntry`, `MemoryCategory`, `ConsolidationResult` | — |
|
||||
| 2 | Config: `MemoryConfig` + deserialization | — |
|
||||
| 3 | Storage: `memories` table + FTS5 + CRUD + search | #1 |
|
||||
| 4 | `MemoryManager` API | #1, #2, #3 |
|
||||
| 5 | Session: `last_consolidated_at` field | — |
|
||||
| 6 | `ContextCompressor` memory integration | #4, #5 |
|
||||
| 7 | `SystemPromptBuilder` memory context injection | #4 |
|
||||
| 8 | Agent tools: `memory_store`, `memory_recall`, `memory_forget` | #4 |
|
||||
| 9 | `GatewayState` initialization wiring | #4, #5, #6 |
|
||||
| 10 | Unit tests | #1-#9 |
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,90 +0,0 @@
|
||||
# 启动增量恢复设计
|
||||
|
||||
## 问题
|
||||
|
||||
PicoBot 重启后,`Session::from_storage()` 全量加载 `messages` 表,恢复的 history 可能直接超出上下文窗口,首次 LLM 调用即触发压缩,浪费 token。
|
||||
|
||||
## 设计
|
||||
|
||||
### 核心思路
|
||||
|
||||
用 `last_compressed_message_at` 标记最后压缩时刻。恢复时:
|
||||
- 加载该标记之后的原始消息
|
||||
- 用该 session 的 Timeline 条目替代已压缩部分
|
||||
- `seq_counter` 统一从 SQLite 查 `MAX(seq) + 1`
|
||||
|
||||
```
|
||||
messages 表 memories(timeline)
|
||||
┌──────────────────────────┐ ┌───────────────────────────┐
|
||||
│ created_at = T1..T5 │ ← 跳过 │ session = feishu:oc:dialog │
|
||||
│ (压缩已覆盖,用Timeline替代)│ │ created_at 降序 │
|
||||
├──────────────────────────┤ ├───────────────────────────┤
|
||||
│ created_at > T6 │ ← 加载 │ 只取最近 3 条 │
|
||||
└──────────────────────────┘ └───────────────────────────┘
|
||||
```
|
||||
|
||||
### 数据变更
|
||||
|
||||
**`sessions` 表加列:**
|
||||
```sql
|
||||
last_compressed_message_at INTEGER
|
||||
```
|
||||
|
||||
**`SessionMeta` / `Session` 加字段:** `last_compressed_message_at: Option<i64>`
|
||||
|
||||
### Storage 层新增方法
|
||||
|
||||
| 方法 | SQL |
|
||||
|------|-----|
|
||||
| `get_max_message_seq(session_id)` | `SELECT MAX(seq) FROM messages WHERE session_id = ?` |
|
||||
| `load_messages_after_timestamp(session_id, after_ts)` | `WHERE created_at > ?` |
|
||||
| `load_session_timelines(session_id, limit)` | `WHERE session_id = ? AND category = 'timeline' ORDER BY created_at DESC LIMIT ?` |
|
||||
|
||||
### 压缩跟踪
|
||||
|
||||
`compress_if_needed()` 返回值改为 `CompressionResult { history, created_timelines: bool }`。
|
||||
`compress_once()` 中 LLM 摘要路径才置 `true`(Tier 2),Tier 1/3 不产生 Timeline。
|
||||
|
||||
**记录时机**(`handle_message` 正常流、溢出重试流、`/compact` 统一):
|
||||
```rust
|
||||
if result.created_timelines {
|
||||
session.last_compressed_message_at = Some(now());
|
||||
session.persist_session_meta().await;
|
||||
}
|
||||
```
|
||||
|
||||
### Session::from_storage() 恢复逻辑
|
||||
|
||||
有压缩标记时:
|
||||
1. `load_session_timelines(limit=4)` → 取 3 条给 LLM,第 4 条判"有更多"
|
||||
2. 有更多 → 插入提示 user 消息
|
||||
3. 逐条插入 Timeline 为 `[Previous Context]` user 消息
|
||||
4. `load_messages_after_timestamp(after_ts)` → 原始尾消息
|
||||
5. `repair_tool_call_chains`
|
||||
|
||||
无压缩标记 → 全量加载(现有行为)。
|
||||
|
||||
统一:`seq_counter = MAX(seq) + 1`
|
||||
|
||||
### 系统提示词
|
||||
|
||||
`Session.last_compressed_message_at` 非空时追加:
|
||||
```
|
||||
## 历史会话
|
||||
之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。
|
||||
```
|
||||
|
||||
## 改动清单
|
||||
|
||||
| # | 文件 | 改动 |
|
||||
|---|------|------|
|
||||
| 1 | `storage/session.rs` | `SessionMeta` 加 `last_compressed_message_at` |
|
||||
| 2 | `storage/mod.rs` | DDL migration + upsert/get_session 加列 |
|
||||
| 3 | `storage/mod.rs` | 新增 `get_max_message_seq`, `load_messages_after_timestamp` |
|
||||
| 4 | `storage/memory.rs` | 新增 `load_session_timelines` |
|
||||
| 5 | `agent/context_compressor.rs` | 返回值改为 `CompressionResult` 含 `created_timelines` |
|
||||
| 6 | `session/session.rs` | `Session` 加字段,`persist_session_meta` 加字段 |
|
||||
| 7 | `session/session.rs` | `from_storage()` 重写恢复逻辑 |
|
||||
| 8 | `session/session.rs` | `handle_message()` 压缩后记录标记 |
|
||||
| 9 | `session/session.rs` | `/compact` 命令压缩后记录标记 |
|
||||
| 10 | `session/session.rs` | `build_system_prompt()` 注入 `last_compressed_message_at` |
|
||||
@ -1,674 +0,0 @@
|
||||
# 启动增量恢复 Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** PicoBot 重启后不再全量加载 messages 表,而是基于 `last_compressed_message_at` 标记增量恢复,用 Timeline 替代已压缩部分。
|
||||
|
||||
**Architecture:** 在 `sessions` 表加 `last_compressed_message_at` 列,`compress_if_needed` 返回值增加 `created_timelines` 标志,恢复时按时间戳增量加载消息并用近 3 条 Timeline 前置,`seq_counter` 统一从 SQLite 查 MAX(seq)。
|
||||
|
||||
**Tech Stack:** Rust, sqlx (SQLite), tokio
|
||||
|
||||
---
|
||||
|
||||
### Task 1: SessionMeta 和数据库 DDL 加列
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/session.rs:15`
|
||||
- Modify: `src/storage/mod.rs:44-45` (DDL), `:172-180` (migration)
|
||||
- Modify: `src/storage/mod.rs:317-326` (upsert_session SQL + ON CONFLICT)
|
||||
- Modify: `src/storage/mod.rs:345-369` (get_session SELECT + struct)
|
||||
- Modify: `src/storage/mod.rs:380-406`, `:454-479`, `:564-588`, `:728`, `:754` (list_sessions 及测试 mock)
|
||||
|
||||
**Step 1: 在 `src/storage/session.rs` SessionMeta 加字段**
|
||||
|
||||
在 `last_consolidated_at: Option<i64>` 后加一行:
|
||||
```rust
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
```
|
||||
|
||||
**Step 2: DDL schema 加列**
|
||||
|
||||
在 `src/storage/mod.rs` 的 CREATE TABLE sessions 中 (line 44),`last_consolidated_at INTEGER` 后加逗号和:
|
||||
```sql
|
||||
last_compressed_message_at INTEGER
|
||||
```
|
||||
|
||||
**Step 3: migration 加列**
|
||||
|
||||
在 `src/storage/mod.rs` line 182 之后(现有 migration 的 `); .ok();` 之后),添加新 migration:
|
||||
```rust
|
||||
// Migration: add last_compressed_message_at column if not exists
|
||||
sqlx::query(
|
||||
r#"ALTER TABLE sessions ADD COLUMN last_compressed_message_at INTEGER"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
```
|
||||
|
||||
**Step 4: upsert_session SQL 加列**
|
||||
|
||||
`src/storage/mod.rs` line 317: INSERT 列列表加 `last_compressed_message_at`,VALUES 加 `?`,ON CONFLICT DO UPDATE SET 加 `last_compressed_message_at = excluded.last_compressed_message_at`。line 338 后加 `.bind(meta.last_compressed_message_at)`。
|
||||
|
||||
**Step 5: get_session SELECT 加列**
|
||||
|
||||
`src/storage/mod.rs` line 348: SELECT 列加 `last_compressed_message_at`。line 368 后加:
|
||||
```rust
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
```
|
||||
|
||||
**Step 6: 其他 SELECT sessions 的地方(list_sessions 多个变体)**
|
||||
|
||||
同样补 `last_compressed_message_at` 到 SELECT 列和 struct 构造。以及测试中的 mock SessionMeta 构造(line 728, 754 等)。
|
||||
|
||||
**Step 7: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 8: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/session.rs src/storage/mod.rs
|
||||
git commit -m "feat(storage): add last_compressed_message_at column to sessions"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Storage 新增加载方法
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/storage/mod.rs` (在 load_messages 之后)
|
||||
- Modify: `src/storage/memory.rs` (在 cleanup_old_timelines 之后)
|
||||
|
||||
**Step 1: `get_max_message_seq`**
|
||||
|
||||
在 `src/storage/mod.rs` 中 `load_messages` 函数后面添加:
|
||||
```rust
|
||||
pub async fn get_max_message_seq(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||
let row = sqlx::query(
|
||||
"SELECT COALESCE(MAX(seq), 0) as max_seq FROM messages WHERE session_id = ?",
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_one(self.pool())
|
||||
.await?;
|
||||
Ok(row.get::<i64, _>("max_seq"))
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: `load_messages_after_timestamp`**
|
||||
|
||||
```rust
|
||||
pub async fn load_messages_after_timestamp(
|
||||
&self,
|
||||
session_id: &str,
|
||||
after_ts: i64,
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY seq ASC
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(after_ts)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|row| crate::storage::message::MessageMeta {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
tool_calls: row.get("tool_calls"),
|
||||
source: row.get("source"),
|
||||
created_at: row.get("created_at"),
|
||||
}).collect())
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: `load_session_timelines`**
|
||||
|
||||
在 `src/storage/memory.rs` 的 `cleanup_old_timelines` 之后(line 252 的 `}` 之前)添加:
|
||||
```rust
|
||||
pub async fn load_session_timelines(
|
||||
&self,
|
||||
session_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, key, content, category, importance,
|
||||
session_id, created_at, updated_at
|
||||
FROM memories
|
||||
WHERE session_id = ? AND category = 'timeline'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.bind(limit as i64)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
parse_memory_rows(&rows)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add src/storage/mod.rs src/storage/memory.rs
|
||||
git commit -m "feat(storage): add load_messages_after_timestamp, load_session_timelines, get_max_message_seq"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: ContextCompressor 引入 CompressionResult
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agent/context_compressor.rs:172-274` (compress_if_needed)
|
||||
- Modify: `src/agent/context_compressor.rs:320-402` (compress_once)
|
||||
|
||||
**Step 1: 定义 CompressionResult**
|
||||
|
||||
在 context_compressor.rs 中 `ContextCompressor` struct 定义之后添加:
|
||||
```rust
|
||||
pub struct CompressionResult {
|
||||
pub history: Vec<ChatMessage>,
|
||||
pub created_timelines: bool,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 修改 compress_if_needed 签名和返回**
|
||||
|
||||
将 `pub async fn compress_if_needed(&self, mut history: Vec<ChatMessage>) -> Result<Vec<ChatMessage>, AgentError>` 改为:
|
||||
```rust
|
||||
pub async fn compress_if_needed(
|
||||
&self,
|
||||
mut history: Vec<ChatMessage>,
|
||||
) -> Result<CompressionResult, AgentError> {
|
||||
```
|
||||
|
||||
内部的 `return Ok(history)` 改为 `return Ok(CompressionResult { history, created_timelines: false })`(Tier 1 fast trim 和不需要压缩时)。
|
||||
|
||||
**Step 3: 修改 LLM summarization pass 部分**
|
||||
|
||||
在压缩循环中维护一个 `created_timelines` 标志:
|
||||
```rust
|
||||
let mut created_timelines = false;
|
||||
for pass in 0..self.config.max_passes {
|
||||
// ...
|
||||
match self.compress_once(...).await {
|
||||
Ok(Some(compressed)) => {
|
||||
current_history = compressed;
|
||||
created_timelines = true;
|
||||
}
|
||||
// ...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
最后返回:
|
||||
```rust
|
||||
Ok(CompressionResult { history: current_history, created_timelines })
|
||||
```
|
||||
|
||||
**Step 4: 更新所有 compress_if_needed 调用方**
|
||||
|
||||
所有 `compress_if_needed(history)` 改为 `compress_if_needed(history).await?.history`。在 `handle_message` 和 `/compact` 中还需要用到 `created_timelines`。
|
||||
|
||||
**Step 5: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add src/agent/context_compressor.rs src/session/session.rs
|
||||
git commit -m "feat(compressor): return CompressionResult with created_timelines flag"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Session 结构体和持久化
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/session/session.rs:52-74` (Session struct)
|
||||
- Modify: `src/session/session.rs:76-121` (Session::new)
|
||||
- Modify: `src/session/session.rs:298-320` (persist_session_meta)
|
||||
|
||||
**Step 1: Session struct 加字段**
|
||||
|
||||
在 `pub last_consolidated_at: Option<i64>` 后加:
|
||||
```rust
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
```
|
||||
|
||||
**Step 2: Session::new 初始化**
|
||||
|
||||
在 `last_consolidated_at: None` 后加:
|
||||
```rust
|
||||
last_compressed_message_at: None,
|
||||
```
|
||||
|
||||
**Step 3: persist_session_meta 加字段**
|
||||
|
||||
在 `last_consolidated_at: self.last_consolidated_at` 后加:
|
||||
```rust
|
||||
last_compressed_message_at: self.last_compressed_message_at,
|
||||
```
|
||||
|
||||
**Step 4: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add src/session/session.rs
|
||||
git commit -m "feat(session): add last_compressed_message_at field to Session and persist"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Session::from_storage() 增量恢复
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/session/session.rs:124-189` (from_storage)
|
||||
|
||||
**Step 1: 重写 from_storage**
|
||||
|
||||
替换现有实现为:
|
||||
|
||||
```rust
|
||||
pub async fn from_storage(
|
||||
id: UnifiedSessionId,
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
storage: StdArc<Storage>,
|
||||
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let session_meta = storage.get_session(&id.to_string()).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||
|
||||
let mut provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
provider_box.set_storage(storage.clone());
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||
|
||||
let compressor_config = ContextCompressionConfig {
|
||||
protect_first_n: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||
compressor.set_session_id(Some(id.to_string()));
|
||||
|
||||
// Determine recovery strategy
|
||||
let mut chat_messages: Vec<ChatMessage> = Vec::new();
|
||||
|
||||
if let Some(after_ts) = session_meta.last_compressed_message_at {
|
||||
// Load last 4 timelines to determine if there are > 3
|
||||
let timelines = storage
|
||||
.load_session_timelines(&id.to_string(), 4)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let has_more_timelines = timelines.len() > 3;
|
||||
|
||||
// Insert hint if more timelines exist
|
||||
if has_more_timelines {
|
||||
chat_messages.push(ChatMessage::user(
|
||||
"[Earlier conversation summaries exist. \
|
||||
Use `timeline_recall` to search if needed.]"
|
||||
));
|
||||
}
|
||||
|
||||
// Insert latest 3 timelines as context (reversed: oldest first)
|
||||
for tl in timelines.iter().take(3).rev() {
|
||||
chat_messages.push(ChatMessage::user(format!(
|
||||
"[Previous Context]\n{}", tl.content
|
||||
)));
|
||||
}
|
||||
|
||||
// Load raw messages after compressed timestamp
|
||||
let tail = storage
|
||||
.load_messages_after_timestamp(&id.to_string(), after_ts)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut tail_msgs: Vec<ChatMessage> = tail.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut tail_msgs);
|
||||
chat_messages.extend(tail_msgs);
|
||||
} else {
|
||||
// No prior compression — load all messages (existing behavior)
|
||||
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||
|
||||
chat_messages = messages.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls
|
||||
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||
.filter(|v| !v.is_empty()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
repair_tool_call_chains(&mut chat_messages);
|
||||
}
|
||||
|
||||
// seq_counter from actual DB max
|
||||
let max_seq = storage
|
||||
.get_max_message_seq(&id.to_string())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let seq_counter = max_seq + 1;
|
||||
let total_message_count = session_meta.message_count;
|
||||
|
||||
Ok(Self {
|
||||
id: id.clone(),
|
||||
title: session_meta.title,
|
||||
created_at: session_meta.created_at,
|
||||
last_active_at: session_meta.last_active_at,
|
||||
message_count: session_meta.message_count,
|
||||
total_message_count,
|
||||
messages: chat_messages,
|
||||
seq_counter,
|
||||
provider_config: provider_config.clone(),
|
||||
provider: provider.clone(),
|
||||
tools,
|
||||
compressor,
|
||||
storage: Some(storage),
|
||||
routing_info: session_meta.routing_info.unwrap_or_default(),
|
||||
last_consolidated_at: session_meta.last_consolidated_at,
|
||||
last_compressed_message_at: session_meta.last_compressed_message_at,
|
||||
memory_manager,
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add src/session/session.rs
|
||||
git commit -m "feat(session): incremental recovery from storage using compressed timeline"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: 系统提示词加历史会话提示
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agent/system_prompt.rs:289-304` (MemorySection)
|
||||
- Modify: `src/agent/system_prompt.rs:16-23` (PromptContext)
|
||||
- Modify: `src/agent/system_prompt.rs:343-358` (build_system_prompt free function)
|
||||
- Modify: `src/session/session.rs:411-426` (build_system_prompt)
|
||||
|
||||
**Step 1: PromptContext 加 has_compressed_history 字段**
|
||||
|
||||
```rust
|
||||
pub struct PromptContext<'a> {
|
||||
pub workspace_dir: &'a Path,
|
||||
pub model_name: &'a str,
|
||||
pub tools: &'a ToolRegistry,
|
||||
pub session_id: Option<&'a str>,
|
||||
pub memory_context: Option<&'a str>,
|
||||
pub has_compressed_history: bool,
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: 加 HistorySection**
|
||||
|
||||
在 MemorySection 后面添加:
|
||||
```rust
|
||||
pub struct HistorySection;
|
||||
|
||||
impl PromptSection for HistorySection {
|
||||
fn name(&self) -> &str {
|
||||
"history"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||
if ctx.has_compressed_history {
|
||||
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。".to_string()
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: 注册到 SystemPromptBuilder::with_defaults**
|
||||
|
||||
在 `with_defaults()` 的 sections vec 中 `Box::new(MemorySection)` 后加 `Box::new(HistorySection)`。
|
||||
|
||||
**Step 4: 更新 build_system_prompt 签名和调用**
|
||||
|
||||
```rust
|
||||
pub fn build_system_prompt(
|
||||
workspace_dir: &Path,
|
||||
model_name: &str,
|
||||
tools: &ToolRegistry,
|
||||
session_id: Option<&str>,
|
||||
memory_context: Option<&str>,
|
||||
has_compressed_history: bool,
|
||||
) -> String {
|
||||
let ctx = PromptContext {
|
||||
workspace_dir,
|
||||
model_name,
|
||||
tools,
|
||||
session_id,
|
||||
memory_context,
|
||||
has_compressed_history,
|
||||
};
|
||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: 更新 Session::build_system_prompt**
|
||||
|
||||
```rust
|
||||
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
|
||||
let base_prompt = build_system_prompt(
|
||||
&self.provider_config.workspace_dir,
|
||||
&self.provider_config.model_id,
|
||||
&self.tools,
|
||||
Some(&self.id.to_string()),
|
||||
memory_context,
|
||||
self.last_compressed_message_at.is_some(),
|
||||
);
|
||||
|
||||
if skills_prompt.trim().is_empty() {
|
||||
base_prompt
|
||||
} else {
|
||||
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 6: 更新所有其他 build_system_prompt 调用方**
|
||||
|
||||
搜索 `build_system_prompt(` 的所有调用位置,每个都要加 `false` 参数。主要有 `agent/agent_loop.rs` 中的两个调用。
|
||||
|
||||
**Step 7: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 8: Commit**
|
||||
|
||||
```bash
|
||||
git add src/agent/system_prompt.rs src/session/session.rs src/agent/agent_loop.rs
|
||||
git commit -m "feat(system-prompt): add history section for archived conversation context"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: handle_message 和 /compact 记录压缩标记
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/session/session.rs:1339-1355` (handle_message 压缩后)
|
||||
- Modify: `src/session/session.rs:1372-1376` (handle_message 溢出重试)
|
||||
- Modify: `src/session/session.rs:851-872` (/compact 命令)
|
||||
|
||||
**Step 1: handle_message 正常流**
|
||||
|
||||
在 `compress_if_needed(history).await?` 之后(line 1346),改为:
|
||||
```rust
|
||||
let result = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
if let Err(e) = session_guard.persist_session_meta().await {
|
||||
tracing::warn!(error = %e, "Failed to persist compressed message marker");
|
||||
}
|
||||
}
|
||||
let mut history = result.history;
|
||||
```
|
||||
|
||||
同时删除后面(line 1350-1355)单独的 `persist_session_meta` 调用(现在已合入上面的逻辑)。
|
||||
|
||||
**Step 2: handle_message 溢出重试流**
|
||||
|
||||
```rust
|
||||
let raw = session_guard.get_history().to_vec();
|
||||
let result = session_guard.compressor.compress_if_needed(raw).await?;
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
let _ = session_guard.persist_session_meta().await;
|
||||
}
|
||||
let mut retry = result.history;
|
||||
retry.insert(0, ChatMessage::system(system_prompt));
|
||||
agent.process(retry).await?
|
||||
```
|
||||
|
||||
**Step 3: /compact 命令**
|
||||
|
||||
```rust
|
||||
let result = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
let compressed_count = result.history.len();
|
||||
if result.created_timelines {
|
||||
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||
let _ = session_guard.persist_session_meta().await;
|
||||
}
|
||||
session_guard.clear_history();
|
||||
for msg in result.history {
|
||||
session_guard.add_message(msg, false).await
|
||||
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||
}
|
||||
```
|
||||
|
||||
同时确认 `compress_if_needed` 的 import 正常(已在 scope 中)。
|
||||
|
||||
**Step 4: 编译检查**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add src/session/session.rs
|
||||
git commit -m "feat(session): record last_compressed_message_at after compression"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 8: 全局编译和测试
|
||||
|
||||
**Step 1: 全局编译**
|
||||
|
||||
```bash
|
||||
cargo check 2>&1
|
||||
```
|
||||
|
||||
修复所有编译错误,确保全部文件一致。
|
||||
|
||||
**Step 2: 运行单元测试**
|
||||
|
||||
```bash
|
||||
cargo test --lib 2>&1
|
||||
```
|
||||
|
||||
**Step 3: 测试通过后 commit**
|
||||
|
||||
```bash
|
||||
git add -A
|
||||
git commit -m "chore: fix remaining compilation and test issues for incremental recovery"
|
||||
```
|
||||
|
||||
**Step 4: 运行 lint**
|
||||
|
||||
```bash
|
||||
cargo clippy --lib 2>&1 | head -50
|
||||
```
|
||||
|
||||
修复任何 warning。
|
||||
|
||||
---
|
||||
|
||||
### Task 9: 验证 & 提交设计文档
|
||||
|
||||
**Step 1: 最终验证**
|
||||
|
||||
```bash
|
||||
cargo test --lib 2>&1
|
||||
```
|
||||
|
||||
**Step 2: Commit 设计文档**
|
||||
|
||||
```bash
|
||||
git add docs/plans/2026-05-10-incremental-session-recovery-design.md
|
||||
git commit -m "docs: add incremental session recovery design doc"
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@ -5,7 +5,7 @@ always: true
|
||||
---
|
||||
# About PicoBot
|
||||
|
||||
PicoBot 是一个基于 Rust 的个人 AI 助手,支持多渠道(飞书、CLI)、长记忆、定时任务、Skill 系统等。
|
||||
PicoBot 是一个基于 Rust 的个人 AI 助手运行时,包含本地 Gateway、CLI TUI 客户端、飞书渠道、SQLite 会话持久化、长期记忆、定时任务、Skill 系统、MCP 工具接入和子 Agent 委托能力。
|
||||
|
||||
## 目录索引
|
||||
|
||||
@ -13,10 +13,10 @@ PicoBot 是一个基于 Rust 的个人 AI 助手,支持多渠道(飞书、CL
|
||||
|
||||
| 文件 | 内容 |
|
||||
|------|------|
|
||||
| `references/config.md` | 配置字段详解:providers、models、agents、gateway、memory、channels、mcp |
|
||||
| `references/db-schema.md` | 数据库表结构:sessions、messages、memories、scheduled_jobs、llm_calls |
|
||||
| `references/architecture.md` | 核心架构:数据流、会话系统、上下文压缩、记忆系统、Skill 优先级机制 |
|
||||
| `references/faq.md` | 常见问题:模型切换、渠道添加、Skill 安装、历史查询、定时任务等 |
|
||||
| `references/config.md` | 配置字段详解:providers、models、agents、gateway、client、channels、memory、mcp、browser |
|
||||
| `references/db-schema.md` | 数据库表结构:sessions、messages、memories、scheduled_jobs、llm_calls、background_tasks |
|
||||
| `references/architecture.md` | 核心架构:数据流、会话系统、上下文压缩、记忆系统、Skill 优先级、MCP、子 Agent |
|
||||
| `references/faq.md` | 常见问题:模型切换、渠道添加、Skill 安装、历史查询、定时任务、MCP 等 |
|
||||
| `references/commands.md` | 常用命令:编译、启动网关、启动客户端、运行测试 |
|
||||
| `assets/config.example.json` | config.json 完整示例 |
|
||||
|
||||
|
||||
@ -72,5 +72,15 @@
|
||||
"timeline_retention_days": 90,
|
||||
"max_failures_before_degrade": 3
|
||||
},
|
||||
"mcp": {
|
||||
"servers": [],
|
||||
"tool_timeout_secs": 180
|
||||
},
|
||||
"browser": {
|
||||
"enabled": false,
|
||||
"webdriver_url": "http://127.0.0.1:9515",
|
||||
"headless": true,
|
||||
"chrome_path": null
|
||||
},
|
||||
"workspace_dir": "~/.picobot/workspace"
|
||||
}
|
||||
|
||||
@ -17,9 +17,9 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
||||
| `channels` | 外部集成(飞书、CLI),仅收发消息 |
|
||||
| `bus` | 异步消息队列,纯队列不路由 |
|
||||
| `session` | 会话生命周期管理、dialog 操作 |
|
||||
| `agent` | LLM 调用循环、工具执行、上下文压缩 |
|
||||
| `agent` | LLM 调用循环、工具执行、上下文压缩、媒体处理、子 Agent |
|
||||
| `providers` | LLM API 客户端(OpenAI 兼容、Anthropic) |
|
||||
| `tools` | Agent 工具(bash、文件操作、HTTP、web、get_skill 等) |
|
||||
| `tools` | Agent 工具(bash、文件操作、搜索、HTTP、web、browser、memory、delegate 等) |
|
||||
| `skills` | Skill 加载、管理和 prompt 构建 |
|
||||
| `storage` | SQLite 持久化 |
|
||||
| `scheduler` | Cron 作业调度 |
|
||||
@ -37,6 +37,8 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
||||
- AgentLoop 无状态,接收 dialog 事件调用 LLM、执行工具
|
||||
- Providers 是纯 HTTP 客户端,无 bus/session/channel 感知
|
||||
- Tools 接收原始参数,返回字符串结果
|
||||
- MCP 工具在 Gateway 初始化时连接服务器、发现工具,并包装成普通 Tool 注册到 ToolRegistry
|
||||
- 子 Agent 由 `delegate` 工具创建,复用 provider 配置和按需过滤后的工具集;后台任务结果通过 MessageBus 发回原会话
|
||||
|
||||
## 关键约束
|
||||
|
||||
@ -45,6 +47,7 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
||||
- ChannelManager 持有 MessageBus 和所有 channel
|
||||
- OutboundDispatcher 通过 ChannelManager 路由出站消息
|
||||
- Config `.env` 加载使用 `unsafe { env::set_var(...) }`
|
||||
- `browser` 工具只有在 `browser.enabled=true` 时注册,依赖 Chrome/Chromium 与 WebDriver
|
||||
|
||||
## 上下文压缩
|
||||
|
||||
@ -192,3 +195,48 @@ LLM 对话上下文接近 token 限制 (默认 128K × 70%) 时自动触发压
|
||||
| 有压缩历史时 | `HistorySection` 提示 LLM 使用 `timeline_recall` |
|
||||
| 压缩完成后 | 摘要自动存储为 Timeline 记忆 |
|
||||
| 空闲时 | 可配置自动 consolidation(`idle_consolidation_minutes`) |
|
||||
|
||||
---
|
||||
|
||||
## MCP 工具集成
|
||||
|
||||
Gateway 初始化时读取 `config.mcp.servers`:
|
||||
|
||||
1. 按服务器配置连接 `stdio`、`sse` 或 `streamable-http` 传输
|
||||
2. 调用 MCP `list_tools`
|
||||
3. 将每个 MCP tool 包装为 `McpToolWrapper`
|
||||
4. 注册到当前 session 的 `ToolRegistry`
|
||||
|
||||
`/mcp` 斜杠命令会显示 MCP 服务器连接状态和工具列表。
|
||||
|
||||
---
|
||||
|
||||
## 子 Agent / delegate
|
||||
|
||||
`delegate` 工具用于把独立任务交给子 Agent:
|
||||
|
||||
| 模式 | 行为 |
|
||||
|------|------|
|
||||
| `inline` | 当前轮阻塞等待子 Agent 返回 |
|
||||
| `background` | 后台运行,完成后通过原 channel/chat 通知 |
|
||||
| `parallel` | 多个子 Agent 并发执行并聚合结果 |
|
||||
|
||||
默认工具集是只读工具:`file_read`、`file_search`、`content_search`、`web_fetch`、`http_request`、`calculator`。调用时可通过 `allowed_tools` 显式放开其他工具。后台任务会写入 `background_tasks` 表,默认 24 小时后清理。
|
||||
|
||||
---
|
||||
|
||||
## 当前斜杠命令
|
||||
|
||||
| 命令 | 说明 |
|
||||
|------|------|
|
||||
| `/new` | 创建新对话 |
|
||||
| `/sessions` | 列出最近对话 |
|
||||
| `/switch <dialog_id>` | 切换到指定对话 |
|
||||
| `/rename <title>` | 重命名当前对话 |
|
||||
| `/delete` | 删除当前对话 |
|
||||
| `/compact` | 手动触发上下文压缩 |
|
||||
| `/info` | 显示当前对话信息 |
|
||||
| `/dump` | 保存当前对话为 markdown |
|
||||
| `/?`, `/help` | 显示帮助 |
|
||||
| `/mcp` | 显示 MCP 状态 |
|
||||
| `/stop` | 停止当前任务并清空消息队列 |
|
||||
|
||||
@ -14,8 +14,9 @@
|
||||
"client": {}, // 客户端配置
|
||||
"channels": {}, // 渠道配置
|
||||
"memory": {}, // 记忆系统配置
|
||||
"workspace_dir": // 工作目录,默认 ~/.picobot/workspace
|
||||
"mcp": {} // MCP 服务器配置
|
||||
"workspace_dir": "", // 工作目录,默认 ~/.picobot/workspace
|
||||
"mcp": {}, // MCP 服务器配置
|
||||
"browser": {} // 可选浏览器自动化配置
|
||||
}
|
||||
```
|
||||
|
||||
@ -57,8 +58,17 @@
|
||||
| `session_ttl_hours` | int | - | 会话过期小时数 |
|
||||
| `session_db_path` | string | - | SQLite 数据库路径,默认在 workspace 下 |
|
||||
| `cleanup_interval_minutes` | int | - | 清理间隔 |
|
||||
| `max_concurrent_background_tasks` | int | 10 | delegate 后台子任务最大并发数 |
|
||||
| `scheduler` | object | - | 调度器配置 |
|
||||
|
||||
### gateway.scheduler 字段
|
||||
|
||||
| 字段 | 类型 | 默认 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `enabled` | bool | true | 是否启动调度器并注册 cron 工具 |
|
||||
| `poll_interval_secs` | int | 60 | 检查到期任务的轮询间隔 |
|
||||
| `max_concurrent` | int | 1 | 最大并发任务数,当前实现预留 |
|
||||
|
||||
## memory 字段
|
||||
|
||||
| 字段 | 类型 | 默认 | 说明 |
|
||||
@ -94,8 +104,21 @@ MCP 服务器单条配置:
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `name` | 服务器名称 |
|
||||
| `transport` | 传输方式: `Stdio`、`Sse`、`streamable-http` |
|
||||
| `command` | 启动命令(Stdio 模式) |
|
||||
| `transport` | 传输方式: `stdio`、`sse`、`streamable-http` |
|
||||
| `command` | 启动命令(stdio 模式) |
|
||||
| `args` | 命令参数 |
|
||||
| `url` | URL(Sse / streamable-http 模式) |
|
||||
| `env` | 子进程环境变量 |
|
||||
| `url` | URL(sse / streamable-http 模式) |
|
||||
| `headers` | HTTP 传输额外请求头 |
|
||||
| `tool_timeout_secs` | 单独的超时设置 |
|
||||
|
||||
## browser 字段
|
||||
|
||||
浏览器工具默认关闭,开启后注册 `browser` 工具。依赖 Chrome/Chromium 与 chromedriver/WebDriver。
|
||||
|
||||
| 字段 | 类型 | 默认 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `enabled` | bool | false | 是否启用浏览器工具 |
|
||||
| `webdriver_url` | string | http://127.0.0.1:9515 | WebDriver 服务地址 |
|
||||
| `headless` | bool | true | 是否无头运行 |
|
||||
| `chrome_path` | string | - | 自定义 Chrome/Chromium 路径 |
|
||||
|
||||
@ -36,6 +36,28 @@
|
||||
| `tool_calls` | TEXT | 工具调用参数 JSON |
|
||||
| `source` | TEXT | 消息来源(跨会话消息时标记来源 session_id) |
|
||||
| `created_at` | INTEGER | 创建时间(unix 秒) |
|
||||
| `reasoning_content` | TEXT | provider 返回的推理内容(如有) |
|
||||
|
||||
## background_tasks 表
|
||||
|
||||
delegate 后台子任务表。`session_id` 不使用数据库外键,因为 session 使用软删除,关联关系由应用层维护。
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `id` | TEXT PK | 后台任务 ID |
|
||||
| `session_id` | TEXT | 所属会话 |
|
||||
| `channel` | TEXT | 回传渠道 |
|
||||
| `chat_id` | TEXT | 回传目标对话 |
|
||||
| `prompt` | TEXT | 子任务提示 |
|
||||
| `allowed_tools` | TEXT | 允许工具 JSON |
|
||||
| `status` | TEXT | pending / running / completed / failed / cancelled |
|
||||
| `result` | TEXT | 执行结果 |
|
||||
| `error` | TEXT | 错误信息 |
|
||||
| `tool_calls_count` | INTEGER | 工具调用次数 |
|
||||
| `iterations` | INTEGER | Agent 迭代次数 |
|
||||
| `started_at` | INTEGER | 开始时间 |
|
||||
| `finished_at` | INTEGER | 结束时间 |
|
||||
| `created_at` | INTEGER | 创建时间 |
|
||||
|
||||
## memories 表
|
||||
|
||||
|
||||
@ -124,9 +124,51 @@
|
||||
|
||||
---
|
||||
|
||||
## file_read / file_write / file_edit / file_search — 文件操作
|
||||
## delegate — 子 Agent 委托
|
||||
|
||||
工作目录内的文件读写编辑和搜索。详细的参数定义见各工具的 parameters_schema。
|
||||
创建子 Agent 处理独立任务。
|
||||
|
||||
| 参数 | 必填 | 说明 |
|
||||
|------|------|------|
|
||||
| `action` | 是 | `run`, `check_task`, `cancel_task`, `list_tasks` |
|
||||
| `prompt` | run 必填 | 子任务描述 |
|
||||
| `mode` | 否 | `inline`, `background`, `parallel`,默认 `inline` |
|
||||
| `allowed_tools` | 否 | 子 Agent 可用工具列表;默认只读工具集 |
|
||||
| `max_iterations` | 否 | 最大迭代次数,默认 99 |
|
||||
| `timeout_secs` | 否 | 超时秒数,默认 3600 |
|
||||
| `tasks` | parallel 必填 | 并行子任务数组 |
|
||||
| `task_id` | 查询/取消必填 | 后台任务 ID |
|
||||
|
||||
默认只读工具集:`file_read`、`file_search`、`content_search`、`web_fetch`、`http_request`、`calculator`。
|
||||
|
||||
---
|
||||
|
||||
## browser — 浏览器自动化
|
||||
|
||||
仅在 `browser.enabled=true` 时注册。底层使用 WebDriver/Chrome。
|
||||
|
||||
| action | 说明 |
|
||||
|--------|------|
|
||||
| `open` | 打开 URL |
|
||||
| `snapshot` | 获取页面结构快照 |
|
||||
| `click`, `click_at` | 点击元素或坐标 |
|
||||
| `fill`, `type`, `press` | 输入文本或按键 |
|
||||
| `get_text`, `get_title`, `get_url` | 读取页面信息 |
|
||||
| `screenshot` | 截图,可写入文件或返回 base64 |
|
||||
| `focus`, `hover`, `scroll`, `wait` | 常见交互和等待 |
|
||||
| `close` | 关闭浏览器会话 |
|
||||
|
||||
---
|
||||
|
||||
## MCP 工具
|
||||
|
||||
如果 `config.mcp.servers` 配置了 MCP 服务器,Gateway 启动时会连接服务器、发现工具,并把 MCP 工具包装后注册到 ToolRegistry。使用 `/mcp` 查看当前连接状态和工具列表。
|
||||
|
||||
---
|
||||
|
||||
## file_read / file_write / file_edit / file_search / content_search — 文件操作和搜索
|
||||
|
||||
工作目录内的文件读写编辑、文件名搜索和内容搜索。详细的参数定义见各工具的 parameters_schema。
|
||||
|
||||
## bash — 执行命令
|
||||
|
||||
|
||||
@ -72,5 +72,15 @@
|
||||
"timeline_retention_days": 90,
|
||||
"max_failures_before_degrade": 3
|
||||
},
|
||||
"mcp": {
|
||||
"servers": [],
|
||||
"tool_timeout_secs": 180
|
||||
},
|
||||
"browser": {
|
||||
"enabled": false,
|
||||
"webdriver_url": "http://127.0.0.1:9515",
|
||||
"headless": true,
|
||||
"chrome_path": null
|
||||
},
|
||||
"workspace_dir": "~/.picobot/workspace"
|
||||
}
|
||||
|
||||
@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt;
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::bus::{ChatMessage, MediaRef};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::observability::{
|
||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
||||
};
|
||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::collections::VecDeque;
|
||||
use std::hash::{Hash, Hasher};
|
||||
@ -228,6 +226,7 @@ pub struct AgentLoop {
|
||||
pub struct AgentProcessResult {
|
||||
pub final_response: ChatMessage,
|
||||
pub emitted_messages: Vec<ChatMessage>,
|
||||
pub total_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
@ -255,7 +254,10 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with provider created from config and given tools.
|
||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||
pub fn with_tools(
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
let model_name = provider_config.model_id.clone();
|
||||
let workspace_dir = provider_config.workspace_dir.clone();
|
||||
@ -278,7 +280,13 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with an existing shared provider.
|
||||
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
|
||||
pub fn with_provider(
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
max_iterations: usize,
|
||||
model_name: String,
|
||||
workspace_dir: PathBuf,
|
||||
input_types: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
@ -340,8 +348,9 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
/// Preemptive trim: truncate old tool results in-place when history is
|
||||
/// approaching the context window limit. Only trims tool messages with
|
||||
/// content > TRIM_CHARS, preserving the most recent KEEP messages.
|
||||
/// approaching the context window limit. Old results (outside of `keep_recent`
|
||||
/// zone) are replaced with a short placeholder; recent results are truncated
|
||||
/// to `max_chars`.
|
||||
fn preemptive_trim_old_tool_results(
|
||||
&self,
|
||||
messages: &mut [ChatMessage],
|
||||
@ -358,11 +367,11 @@ impl AgentLoop {
|
||||
if messages[i].content.len() <= max_chars {
|
||||
continue;
|
||||
}
|
||||
let removed = messages[i].content.len() - max_chars;
|
||||
let tool_name = messages[i].tool_name.as_deref().unwrap_or("unknown");
|
||||
let chars = messages[i].content.len();
|
||||
messages[i].content = format!(
|
||||
"{}...\n\n[Output truncated - {} characters removed]",
|
||||
&messages[i].content[..messages[i].content.ceil_char_boundary(max_chars)],
|
||||
removed
|
||||
"[Tool output ({}) — {} chars, omitted from context]",
|
||||
tool_name, chars
|
||||
);
|
||||
modified += 1;
|
||||
}
|
||||
@ -377,7 +386,12 @@ impl AgentLoop {
|
||||
let content = if m.media_refs.is_empty() {
|
||||
vec![ContentBlock::text(&m.content)]
|
||||
} else {
|
||||
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
|
||||
build_content_blocks(
|
||||
&m.content,
|
||||
&m.media_refs,
|
||||
&self.input_types,
|
||||
&self.media_registry,
|
||||
)
|
||||
};
|
||||
|
||||
Message {
|
||||
@ -397,14 +411,28 @@ impl AgentLoop {
|
||||
/// it loops back to the LLM with the tool results until either:
|
||||
/// - The LLM returns no more tool calls (final response)
|
||||
/// - Maximum iterations are reached
|
||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
||||
pub async fn process(
|
||||
&self,
|
||||
mut messages: Vec<ChatMessage>,
|
||||
) -> Result<AgentProcessResult, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||
tracing::debug!(
|
||||
history_len = messages.len(),
|
||||
max_iterations = self.max_iterations,
|
||||
"Starting agent process"
|
||||
);
|
||||
|
||||
// Build and inject system prompt if not present
|
||||
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
||||
if !has_system {
|
||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
|
||||
let system_prompt = build_system_prompt(
|
||||
&self.workspace_dir,
|
||||
&self.model_name,
|
||||
&self.tools,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||
messages.insert(0, ChatMessage::system(system_prompt));
|
||||
@ -413,6 +441,7 @@ impl AgentLoop {
|
||||
// Track tool calls for loop detection
|
||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||
let mut emitted_messages = Vec::new();
|
||||
let mut accumulated_tokens: u32 = 0;
|
||||
|
||||
for iteration in 0..self.max_iterations {
|
||||
#[cfg(debug_assertions)]
|
||||
@ -424,9 +453,7 @@ impl AgentLoop {
|
||||
let estimated = estimate_tokens(&messages);
|
||||
let danger = (self.context_window as f64 * 0.8) as usize;
|
||||
if estimated > danger {
|
||||
let trimmed = self.preemptive_trim_old_tool_results(
|
||||
&mut messages, 2000, 4,
|
||||
);
|
||||
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
|
||||
if trimmed > 0 {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
@ -460,11 +487,12 @@ impl AgentLoop {
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = (*self.provider).chat(request).await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "LLM request failed");
|
||||
AgentError::LlmError(e.to_string())
|
||||
})?;
|
||||
let response = (*self.provider).chat(request).await.map_err(|e| {
|
||||
tracing::error!(error = %e, "LLM request failed");
|
||||
AgentError::LlmError(e.to_string())
|
||||
})?;
|
||||
|
||||
accumulated_tokens += response.usage.total_tokens;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
@ -482,12 +510,15 @@ impl AgentLoop {
|
||||
return Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
emitted_messages,
|
||||
total_tokens: Some(accumulated_tokens),
|
||||
});
|
||||
}
|
||||
|
||||
// Execute tool calls — log and notify immediately
|
||||
{
|
||||
let tools_info: Vec<String> = response.tool_calls.iter()
|
||||
let tools_info: Vec<String> = response
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||
let s = format!("{}:{}", tc.name, args);
|
||||
@ -516,7 +547,9 @@ impl AgentLoop {
|
||||
// Log function call with name and arguments
|
||||
let args_str = match &tool_call.arguments {
|
||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||
other => {
|
||||
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
||||
}
|
||||
};
|
||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||
|
||||
@ -556,7 +589,11 @@ impl AgentLoop {
|
||||
|
||||
// Loop continues to next iteration with updated messages
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
||||
tracing::debug!(
|
||||
iteration,
|
||||
message_count = messages.len(),
|
||||
"Tool execution complete, continuing to next iteration"
|
||||
);
|
||||
}
|
||||
|
||||
// Max iterations reached - ask LLM for a summary based on completed work
|
||||
@ -565,7 +602,7 @@ impl AgentLoop {
|
||||
// Add a message asking for summary
|
||||
let summary_request = ChatMessage::user(
|
||||
"You have reached the maximum number of tool call iterations. \
|
||||
Please provide your best answer based on the work completed so far."
|
||||
Please provide your best answer based on the work completed so far.",
|
||||
);
|
||||
messages.push(summary_request);
|
||||
|
||||
@ -584,24 +621,32 @@ impl AgentLoop {
|
||||
|
||||
match (*self.provider).chat(request).await {
|
||||
Ok(response) => {
|
||||
accumulated_tokens += response.usage.total_tokens;
|
||||
let mut assistant_message = ChatMessage::assistant(response.content);
|
||||
assistant_message.reasoning_content = response.reasoning_content;
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
emitted_messages,
|
||||
total_tokens: Some(accumulated_tokens),
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
// Fallback if summary call fails
|
||||
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||
let final_message = ChatMessage::assistant(
|
||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||
);
|
||||
let final_message = ChatMessage::assistant(format!(
|
||||
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
|
||||
self.max_iterations
|
||||
));
|
||||
emitted_messages.push(final_message.clone());
|
||||
Ok(AgentProcessResult {
|
||||
final_response: final_message,
|
||||
emitted_messages,
|
||||
total_tokens: if accumulated_tokens > 0 {
|
||||
Some(accumulated_tokens)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -689,10 +734,7 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
// Apply duration
|
||||
ToolExecutionOutcome {
|
||||
duration,
|
||||
..result
|
||||
}
|
||||
ToolExecutionOutcome { duration, ..result }
|
||||
}
|
||||
|
||||
/// Internal tool execution without event tracking.
|
||||
@ -714,18 +756,12 @@ impl AgentLoop {
|
||||
ToolExecutionOutcome::success(result.output)
|
||||
} else {
|
||||
let error = result.error.unwrap_or_default();
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", error),
|
||||
Some(error),
|
||||
)
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", e),
|
||||
Some(e.to_string()),
|
||||
)
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -813,8 +849,14 @@ mod tests {
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||
assert_eq!(
|
||||
provider_message.tool_calls.as_ref().unwrap()[0].id,
|
||||
"call_1"
|
||||
);
|
||||
assert_eq!(
|
||||
provider_message.tool_calls.as_ref().unwrap()[0].name,
|
||||
"calculator"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -68,6 +68,10 @@ pub struct ContextCompressor {
|
||||
memory: Arc<MemoryManager>,
|
||||
/// Current session ID for timeline memory writes.
|
||||
session_id: Option<String>,
|
||||
/// Message count sent in the last LLM call (used to split known/new history).
|
||||
last_sent_message_count: Option<usize>,
|
||||
/// Real total_tokens from the last API response.
|
||||
last_api_total_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
/// Result of context compression.
|
||||
@ -76,6 +80,15 @@ pub struct CompressionResult {
|
||||
pub created_timelines: bool,
|
||||
}
|
||||
|
||||
/// Token budget state snapshot for diagnostics.
|
||||
pub struct TokenInfo {
|
||||
pub context_window: usize,
|
||||
pub threshold: usize,
|
||||
pub estimated_tokens: usize,
|
||||
pub last_api_tokens: Option<u32>,
|
||||
pub cache_active: bool,
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
/// Create a new compressor with the given provider, context window size, and memory manager.
|
||||
pub fn new(
|
||||
@ -90,6 +103,8 @@ impl ContextCompressor {
|
||||
provider,
|
||||
memory,
|
||||
session_id: None,
|
||||
last_sent_message_count: None,
|
||||
last_api_total_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,6 +122,8 @@ impl ContextCompressor {
|
||||
provider,
|
||||
memory,
|
||||
session_id: None,
|
||||
last_sent_message_count: None,
|
||||
last_api_total_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,39 +137,91 @@ impl ContextCompressor {
|
||||
self.context_window = window;
|
||||
}
|
||||
|
||||
/// Record the API's reported token usage from the last completed turn.
|
||||
/// `msg_count`: number of messages sent to LLM in that call.
|
||||
/// `tokens`: `total_tokens` from the API response.
|
||||
pub fn set_last_api_info(&mut self, msg_count: usize, tokens: Option<u32>) {
|
||||
self.last_sent_message_count = Some(msg_count);
|
||||
self.last_api_total_tokens = tokens;
|
||||
}
|
||||
|
||||
/// Invalidate the cached API token info — called after compression modifies messages.
|
||||
fn invalidate_token_cache(&mut self) {
|
||||
self.last_sent_message_count = None;
|
||||
self.last_api_total_tokens = None;
|
||||
}
|
||||
|
||||
/// Hybrid token estimation: API-reported tokens for known history +
|
||||
/// char/4 estimate for new messages since last API call.
|
||||
fn token_estimate_with_history(&self, messages: &[ChatMessage]) -> usize {
|
||||
match (self.last_api_total_tokens, self.last_sent_message_count) {
|
||||
(Some(known), Some(known_count)) if messages.len() > known_count => {
|
||||
let delta = &messages[known_count..];
|
||||
known as usize + estimate_tokens(delta)
|
||||
}
|
||||
(Some(known), _) => known as usize,
|
||||
_ => estimate_tokens(messages),
|
||||
}
|
||||
}
|
||||
|
||||
/// Always true — memory is always available (memory system is always on).
|
||||
pub fn has_memory(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Get a snapshot of the current token budget state for diagnostics.
|
||||
pub fn token_info(&self, messages: &[ChatMessage]) -> TokenInfo {
|
||||
TokenInfo {
|
||||
context_window: self.context_window,
|
||||
threshold: self.threshold(),
|
||||
estimated_tokens: self.token_estimate_with_history(messages),
|
||||
last_api_tokens: self.last_api_total_tokens,
|
||||
cache_active: self.last_api_total_tokens.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the compression threshold in tokens.
|
||||
pub fn threshold(&self) -> usize {
|
||||
(self.context_window as f64 * self.threshold_ratio) as usize
|
||||
}
|
||||
|
||||
/// Fast-path: trim oversized tool results without LLM call.
|
||||
/// Old tool results (outside of `protect_tail` zone) are replaced with a
|
||||
/// concise placeholder; recent results are truncated to `tool_result_trim_chars`.
|
||||
/// Returns the number of messages modified.
|
||||
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
|
||||
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage], protect_tail: usize) -> usize {
|
||||
let limit = self.config.tool_result_trim_chars;
|
||||
let tail_start = messages.len().saturating_sub(protect_tail);
|
||||
let mut modified = 0;
|
||||
|
||||
for msg in messages.iter_mut() {
|
||||
if msg.role == "tool" && msg.content.len() > limit {
|
||||
for (i, msg) in messages.iter_mut().enumerate() {
|
||||
if msg.role != "tool" || msg.content.len() <= limit {
|
||||
continue;
|
||||
}
|
||||
if i < tail_start {
|
||||
let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
|
||||
let chars = msg.content.len();
|
||||
msg.content = format!(
|
||||
"[Tool output ({}) — {} chars, omitted from context]",
|
||||
tool_name, chars
|
||||
);
|
||||
} else {
|
||||
let removed = msg.content.len() - limit;
|
||||
msg.content = format!(
|
||||
"{}...\n\n[Output truncated - {} characters removed]",
|
||||
&msg.content[..msg.content.ceil_char_boundary(limit)],
|
||||
removed
|
||||
);
|
||||
modified += 1;
|
||||
}
|
||||
modified += 1;
|
||||
}
|
||||
|
||||
modified
|
||||
}
|
||||
|
||||
/// Remove orphan tool results whose declaring tool_calls have been compressed away.
|
||||
/// Scans for tool messages with no preceding assistant tool_call, and removes them.
|
||||
/// Repair tool call chains after compression.
|
||||
/// Phase 1: remove orphan tool results whose declaring tool_calls are missing.
|
||||
/// Phase 2: strip tool_calls from assistants whose results are missing.
|
||||
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
|
||||
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||
let mut i = 0;
|
||||
@ -165,23 +234,58 @@ impl ContextCompressor {
|
||||
}
|
||||
} else if messages[i].role == "tool"
|
||||
&& let Some(ref tid) = messages[i].tool_call_id
|
||||
&& !declared.contains(tid.as_str()) {
|
||||
messages.remove(i);
|
||||
continue;
|
||||
}
|
||||
&& !declared.contains(tid.as_str())
|
||||
{
|
||||
messages.remove(i);
|
||||
continue;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let broken: Vec<usize> = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, msg)| {
|
||||
if msg.role == "assistant"
|
||||
&& let Some(ref tcs) = msg.tool_calls
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
let all_present = tcs.iter().all(|tc| {
|
||||
messages.iter().any(|m| {
|
||||
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
||||
})
|
||||
});
|
||||
if !all_present { Some(idx) } else { None }
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for idx in broken {
|
||||
let msg = &mut messages[idx];
|
||||
let tcs = msg.tool_calls.take().unwrap_or_default();
|
||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
||||
msg.content = format!(
|
||||
"{}\n\n[Tool calls ({}) — results are no longer available]",
|
||||
msg.content,
|
||||
names.join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Main entry point - compresses history if over threshold.
|
||||
pub async fn compress_if_needed(
|
||||
&self,
|
||||
&mut self,
|
||||
mut history: Vec<ChatMessage>,
|
||||
) -> Result<CompressionResult, AgentError> {
|
||||
// Check if compression is needed
|
||||
let tokens = estimate_tokens(&history);
|
||||
let tokens = self.token_estimate_with_history(&history);
|
||||
if tokens <= self.threshold() {
|
||||
return Ok(CompressionResult { history, created_timelines: false });
|
||||
return Ok(CompressionResult {
|
||||
history,
|
||||
created_timelines: false,
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@ -193,8 +297,8 @@ impl ContextCompressor {
|
||||
);
|
||||
|
||||
// Fast trim pass first — modify history in place
|
||||
let trimmed = self.fast_trim_tool_results(&mut history);
|
||||
let tokens_after = estimate_tokens(&history);
|
||||
let trimmed = self.fast_trim_tool_results(&mut history, self.config.protect_last_n);
|
||||
let tokens_after = self.token_estimate_with_history(&history);
|
||||
if trimmed > 0 {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
@ -204,24 +308,24 @@ impl ContextCompressor {
|
||||
);
|
||||
}
|
||||
if tokens_after <= self.threshold() {
|
||||
return Ok(CompressionResult { history, created_timelines: false });
|
||||
self.invalidate_token_cache();
|
||||
return Ok(CompressionResult {
|
||||
history,
|
||||
created_timelines: false,
|
||||
});
|
||||
}
|
||||
|
||||
// LLM summarization pass
|
||||
let mut current_history = history;
|
||||
let mut created_timelines = false;
|
||||
for pass in 0..self.config.max_passes {
|
||||
let tokens = estimate_tokens(¤t_history);
|
||||
let tokens = self.token_estimate_with_history(¤t_history);
|
||||
if tokens <= self.threshold() {
|
||||
break;
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
pass = pass + 1,
|
||||
tokens = tokens,
|
||||
"Compression pass"
|
||||
);
|
||||
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
|
||||
|
||||
match self.compress_once(¤t_history).await {
|
||||
Ok(Some(compressed)) => {
|
||||
@ -241,15 +345,52 @@ impl ContextCompressor {
|
||||
|
||||
// Hard safety net: if still dangerously high after all passes,
|
||||
// fall back to head+tail truncation so the LLM call doesn't overflow.
|
||||
let final_tokens = estimate_tokens(¤t_history);
|
||||
let final_tokens = self.token_estimate_with_history(¤t_history);
|
||||
let danger_threshold = (self.context_window as f64 * 0.9) as usize;
|
||||
if final_tokens > danger_threshold
|
||||
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n
|
||||
{
|
||||
let mut tail_start = current_history.len() - self.config.protect_last_n;
|
||||
|
||||
// Align tail_start backwards to preserve tool chain boundaries:
|
||||
// if an assistant with tool_calls has results spanning the cut,
|
||||
// include the assistant in the tail.
|
||||
if tail_start > 0 && tail_start < current_history.len() {
|
||||
let mut scan = tail_start.saturating_sub(1);
|
||||
loop {
|
||||
let m = ¤t_history[scan];
|
||||
if m.role == "assistant" {
|
||||
if let Some(tcs) = &m.tool_calls
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
let has_post = current_history[scan + 1..]
|
||||
.iter()
|
||||
.filter(|r| r.role == "tool")
|
||||
.any(|r| {
|
||||
tcs.iter()
|
||||
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
|
||||
});
|
||||
if has_post {
|
||||
tail_start = scan;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
if scan == 0 {
|
||||
break;
|
||||
}
|
||||
scan -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Skip orphan tool messages at the new head-tail boundary
|
||||
while tail_start < current_history.len() && current_history[tail_start].role == "tool" {
|
||||
tail_start += 1;
|
||||
}
|
||||
|
||||
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
|
||||
let tail_start = current_history.len() - self.config.protect_last_n;
|
||||
let tail: Vec<_> = current_history[tail_start..].to_vec();
|
||||
let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n;
|
||||
let dropped = current_history.len() - self.config.protect_first_n - tail.len();
|
||||
|
||||
let mut truncated = head;
|
||||
truncated.push(ChatMessage::user(format!(
|
||||
@ -259,6 +400,26 @@ impl ContextCompressor {
|
||||
)));
|
||||
truncated.extend(tail);
|
||||
|
||||
// Strip tool_calls from any assistant in the head whose results
|
||||
// were dropped (previously in the middle section).
|
||||
for msg in &mut truncated[..self.config.protect_first_n] {
|
||||
if msg.role == "assistant" {
|
||||
if let Some(ref tcs) = msg.tool_calls
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
||||
msg.content = format!(
|
||||
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
|
||||
msg.content,
|
||||
names.join(", ")
|
||||
);
|
||||
msg.tool_calls = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self::repair_tool_pairs(&mut truncated);
|
||||
|
||||
tracing::warn!(
|
||||
final_tokens = final_tokens,
|
||||
danger = danger_threshold,
|
||||
@ -269,14 +430,21 @@ impl ContextCompressor {
|
||||
current_history = truncated;
|
||||
}
|
||||
|
||||
if created_timelines {
|
||||
self.invalidate_token_cache();
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
final_tokens = estimate_tokens(¤t_history),
|
||||
final_tokens = self.token_estimate_with_history(¤t_history),
|
||||
final_msg_count = current_history.len(),
|
||||
"Context compression completed"
|
||||
);
|
||||
|
||||
Ok(CompressionResult { history: current_history, created_timelines })
|
||||
Ok(CompressionResult {
|
||||
history: current_history,
|
||||
created_timelines,
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to extract the actual context token limit from an LLM error message.
|
||||
@ -299,20 +467,21 @@ impl ContextCompressor {
|
||||
// Look for a number in the vicinity (up to 10 chars after marker)
|
||||
if let Some(num_str) = find_number_nearby(after, 50)
|
||||
&& let Ok(n) = num_str.parse::<usize>()
|
||||
&& (1024..=10_000_000).contains(&n) {
|
||||
return Some(n);
|
||||
}
|
||||
&& (1024..=10_000_000).contains(&n)
|
||||
{
|
||||
return Some(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also try: "XXXX token context" or "XXXX limit"
|
||||
if let Some(num_str) = find_number_nearby(&lower, lower.len())
|
||||
&& let Ok(n) = num_str.parse::<usize>()
|
||||
&& (1024..=10_000_000).contains(&n)
|
||||
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
|
||||
{
|
||||
return Some(n);
|
||||
}
|
||||
&& (1024..=10_000_000).contains(&n)
|
||||
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
|
||||
{
|
||||
return Some(n);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
@ -361,19 +530,26 @@ impl ContextCompressor {
|
||||
|
||||
// Persist compressed summary as timeline memory entry
|
||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||
ts, between.len(), summary);
|
||||
let timeline_content = format!(
|
||||
"[{}] Compressed {} conversation segments:\n{}",
|
||||
ts,
|
||||
between.len(),
|
||||
summary
|
||||
);
|
||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||
let mm = self.memory.clone();
|
||||
let sid = self.session_id.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = mm.store(
|
||||
&key,
|
||||
&timeline_content,
|
||||
crate::memory::MemoryCategory::Timeline,
|
||||
sid.as_deref(),
|
||||
Some(0.3),
|
||||
).await {
|
||||
if let Err(e) = mm
|
||||
.store(
|
||||
&key,
|
||||
&timeline_content,
|
||||
crate::memory::MemoryCategory::Timeline,
|
||||
sid.as_deref(),
|
||||
Some(0.3),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||
}
|
||||
});
|
||||
@ -404,10 +580,7 @@ impl ContextCompressor {
|
||||
}
|
||||
|
||||
/// Summarize a segment of messages using LLM.
|
||||
async fn summarize_segment(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
) -> Result<String, AgentError> {
|
||||
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
|
||||
if messages.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
@ -421,7 +594,8 @@ impl ContextCompressor {
|
||||
"tool" => "Tool",
|
||||
_ => m.role.as_str(),
|
||||
};
|
||||
let name = m.tool_name
|
||||
let name = m
|
||||
.tool_name
|
||||
.as_ref()
|
||||
.map(|n| format!(" ({})", n))
|
||||
.unwrap_or_default();
|
||||
@ -466,7 +640,10 @@ Be concise, aim for {} characters or less.
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
||||
messages: vec![
|
||||
Message::system("You are a helpful assistant."),
|
||||
Message::user(&prompt),
|
||||
],
|
||||
temperature: Some(0.3),
|
||||
max_tokens: Some(1000),
|
||||
tools: None,
|
||||
@ -538,13 +715,23 @@ mod tests {
|
||||
content: "[summarized]".into(),
|
||||
reasoning_content: None,
|
||||
tool_calls: vec![],
|
||||
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
||||
usage: Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str { "mock" }
|
||||
fn name(&self) -> &str { "mock" }
|
||||
fn model_id(&self) -> &str { "mock" }
|
||||
fn ptype(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
fn model_id(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
||||
@ -556,11 +743,13 @@ mod tests {
|
||||
MM.get_or_init(|| {
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
rt.block_on(async {
|
||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||
let tmp = std::env::temp_dir()
|
||||
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
||||
})
|
||||
}).clone()
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -576,7 +765,11 @@ mod tests {
|
||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||
// raw = 19, with 1.2x = ~23
|
||||
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
||||
assert!(
|
||||
tokens > 18 && tokens < 30,
|
||||
"Expected ~23 tokens, got {}",
|
||||
tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -585,14 +778,15 @@ mod tests {
|
||||
tool_result_trim_chars: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||
let compressor =
|
||||
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
|
||||
];
|
||||
|
||||
let modified = compressor.fast_trim_tool_results(&mut messages);
|
||||
let modified = compressor.fast_trim_tool_results(&mut messages, 2);
|
||||
assert_eq!(modified, 1);
|
||||
assert!(messages[1].content.len() < 100);
|
||||
}
|
||||
@ -619,14 +813,18 @@ mod tests {
|
||||
max_passes: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
|
||||
let mut compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hi"),
|
||||
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
let result = compressor
|
||||
.compress_if_needed(messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.history;
|
||||
|
||||
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
||||
assert!(
|
||||
@ -650,18 +848,19 @@ mod tests {
|
||||
// - B2B (L275): last user message lost when it is the final history message
|
||||
//
|
||||
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||
let tmp =
|
||||
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||
|
||||
let config = ContextCompressionConfig {
|
||||
tool_result_trim_chars: 2000,
|
||||
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
|
||||
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
|
||||
protect_last_n: 2,
|
||||
max_passes: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
|
||||
let mut compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
|
||||
|
||||
// History: 9 messages, last message is user Q4.
|
||||
// user_indices (skip 1) = [1, 3, 6, 8]
|
||||
@ -670,25 +869,43 @@ mod tests {
|
||||
let big = "x".repeat(3000);
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helper."), // 0: protected
|
||||
ChatMessage::user("Q1"), // 1: first user
|
||||
ChatMessage::tool("t1", "bash", &big), // 2
|
||||
ChatMessage::user("Q2"), // 3
|
||||
ChatMessage::assistant("thinking"), // 4
|
||||
ChatMessage::tool("t2", "bash", &big), // 5
|
||||
ChatMessage::user("Q3"), // 6
|
||||
ChatMessage::assistant("thinking"), // 7
|
||||
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||
ChatMessage::user("Q1"), // 1: first user
|
||||
ChatMessage::tool("t1", "bash", &big), // 2
|
||||
ChatMessage::user("Q2"), // 3
|
||||
ChatMessage::assistant("thinking"), // 4
|
||||
ChatMessage::tool("t2", "bash", &big), // 5
|
||||
ChatMessage::user("Q3"), // 6
|
||||
ChatMessage::assistant("thinking"), // 7
|
||||
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
let result = compressor
|
||||
.compress_if_needed(messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.history;
|
||||
|
||||
// B2A: "Q1" must appear exactly once
|
||||
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
|
||||
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
|
||||
let q1_count = result
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" && m.content == "Q1")
|
||||
.count();
|
||||
assert_eq!(
|
||||
q1_count, 1,
|
||||
"Q1 should appear exactly once, got {}",
|
||||
q1_count
|
||||
);
|
||||
|
||||
// B2B: "Q4" must NOT be lost
|
||||
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
|
||||
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
|
||||
let q4_count = result
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" && m.content == "Q4")
|
||||
.count();
|
||||
assert_eq!(
|
||||
q4_count, 1,
|
||||
"Q4 should appear exactly once (not lost), got {}",
|
||||
q4_count
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_file(&tmp);
|
||||
}
|
||||
@ -702,16 +919,16 @@ mod tests {
|
||||
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||
|
||||
let config = ContextCompressionConfig {
|
||||
tool_result_trim_chars: 500, // trim reduces but not enough
|
||||
tool_result_trim_chars: 500, // trim reduces but not enough
|
||||
protect_first_n: 1,
|
||||
protect_last_n: 2,
|
||||
max_passes: 0, // no LLM summarization → will exceed danger
|
||||
max_passes: 0, // no LLM summarization → will exceed danger
|
||||
..Default::default()
|
||||
};
|
||||
// context_window=100, danger_threshold=90.
|
||||
// Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
|
||||
// Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
|
||||
let mut compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
|
||||
|
||||
let big = "x".repeat(3000);
|
||||
let messages = vec![
|
||||
@ -724,13 +941,23 @@ mod tests {
|
||||
ChatMessage::tool("t3", "bash", &big),
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
let result = compressor
|
||||
.compress_if_needed(messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.history;
|
||||
|
||||
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
||||
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
|
||||
assert!(
|
||||
result.len() < 7,
|
||||
"expected truncation reduction, got {} messages",
|
||||
result.len()
|
||||
);
|
||||
|
||||
// Truncation notice should be present
|
||||
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
|
||||
let has_notice = result
|
||||
.iter()
|
||||
.any(|m| m.content.contains("Context truncation"));
|
||||
assert!(has_notice, "hard truncation notice missing");
|
||||
|
||||
let _ = std::fs::remove_file(&tmp);
|
||||
@ -745,9 +972,9 @@ mod tests {
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Q1"),
|
||||
ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
|
||||
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
|
||||
ChatMessage::assistant("done"), // declares tc2
|
||||
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
|
||||
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
|
||||
ChatMessage::assistant("done"), // declares tc2
|
||||
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
|
||||
];
|
||||
// Set tool_call_id on tool messages and tool_calls on assistant
|
||||
messages[2].tool_call_id = Some("tc1".into());
|
||||
@ -762,8 +989,16 @@ mod tests {
|
||||
|
||||
// orphan should be removed; legitimate should stay
|
||||
assert_eq!(messages.len(), 4);
|
||||
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
|
||||
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.all(|m| m.tool_call_id != Some("tc1".into()))
|
||||
);
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.any(|m| m.tool_call_id == Some("tc2".into()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
|
||||
}
|
||||
|
||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
@ -1,8 +1,16 @@
|
||||
pub mod agent_loop;
|
||||
pub mod context_compressor;
|
||||
pub mod media_handler;
|
||||
pub mod sub_agent;
|
||||
pub mod system_prompt;
|
||||
|
||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
|
||||
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
||||
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};
|
||||
pub use sub_agent::{
|
||||
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
|
||||
TaskNotification, TaskStatus,
|
||||
};
|
||||
pub use system_prompt::{
|
||||
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
|
||||
build_system_prompt,
|
||||
};
|
||||
|
||||
623
src/agent/sub_agent.rs
Normal file
623
src/agent/sub_agent.rs
Normal file
@ -0,0 +1,623 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::agent::system_prompt::build_sub_agent_system_prompt;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{LLMProvider, create_provider};
|
||||
use crate::skills::SkillsLoader;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
tokio::task_local! {
|
||||
pub(crate) static DELEGATE_CONTEXT: DelegateContext;
|
||||
}
|
||||
|
||||
/// Read the delegate context from the current task. Returns an error if not set.
|
||||
pub fn get_delegate_context() -> Result<DelegateContext, String> {
|
||||
DELEGATE_CONTEXT
|
||||
.try_with(|ctx| ctx.clone())
|
||||
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
|
||||
}
|
||||
|
||||
const DEFAULT_MAX_ITERATIONS: usize = 99;
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 3600;
|
||||
const MAX_INLINE_RESULT_CHARS: usize = 8000;
|
||||
|
||||
const DEFAULT_READONLY_TOOLS: &[&str] = &[
|
||||
"file_read",
|
||||
"file_search",
|
||||
"content_search",
|
||||
"web_fetch",
|
||||
"http_request",
|
||||
"calculator",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubAgentConfig {
|
||||
pub prompt: String,
|
||||
pub mode: ExecutionMode,
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
pub max_iterations: Option<usize>,
|
||||
pub timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ExecutionMode {
|
||||
Inline,
|
||||
Background,
|
||||
Parallel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubAgentResult {
|
||||
pub task_id: String,
|
||||
pub content: String,
|
||||
pub content_truncated: bool,
|
||||
pub status: TaskStatus,
|
||||
pub tool_calls_count: usize,
|
||||
pub iterations: usize,
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TaskStatus {
|
||||
Completed,
|
||||
Failed(String),
|
||||
Cancelled,
|
||||
TimedOut,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskNotification {
|
||||
pub task_id: String,
|
||||
pub session_id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub status: TaskStatus,
|
||||
pub result_summary: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DelegateContext {
|
||||
pub session_id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum SubAgentError {
|
||||
TooManyTasks(usize),
|
||||
ProviderCreation(String),
|
||||
Storage(String),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SubAgentError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
|
||||
Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
|
||||
Self::Storage(e) => write!(f, "storage error: {}", e),
|
||||
Self::Other(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SubAgentError {}
|
||||
|
||||
pub struct SubAgentManager {
|
||||
provider_config: LLMProviderConfig,
|
||||
full_tools: Arc<ToolRegistry>,
|
||||
storage: Option<Arc<crate::storage::Storage>>,
|
||||
active_tasks: Arc<DashMap<String, CancellationToken>>,
|
||||
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
|
||||
max_concurrent_background_tasks: usize,
|
||||
skills_loader: Option<Arc<SkillsLoader>>,
|
||||
}
|
||||
|
||||
impl SubAgentManager {
|
||||
pub fn new(
|
||||
provider_config: LLMProviderConfig,
|
||||
full_tools: Arc<ToolRegistry>,
|
||||
storage: Option<Arc<crate::storage::Storage>>,
|
||||
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
|
||||
max_concurrent_background_tasks: usize,
|
||||
skills_loader: Option<Arc<SkillsLoader>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_config,
|
||||
full_tools,
|
||||
storage,
|
||||
active_tasks: Arc::new(DashMap::new()),
|
||||
notify_tx,
|
||||
max_concurrent_background_tasks,
|
||||
skills_loader,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn filter_tools(&self, allowed: &Option<Vec<String>>) -> Arc<ToolRegistry> {
|
||||
let allowed_set: HashSet<&str> = match allowed {
|
||||
Some(list) => list.iter().map(|s| s.as_str()).collect(),
|
||||
None => DEFAULT_READONLY_TOOLS.iter().copied().collect(),
|
||||
};
|
||||
let filtered = ToolRegistry::new();
|
||||
for (name, tool) in self.full_tools.iter() {
|
||||
if allowed_set.contains(name.as_str()) && name != "delegate" {
|
||||
filtered.register_raw(name, tool);
|
||||
}
|
||||
}
|
||||
Arc::new(filtered)
|
||||
}
|
||||
|
||||
fn get_skills_prompt(&self, tools: &ToolRegistry) -> Option<String> {
|
||||
let has_get_skill = tools.iter().iter().any(|(name, _)| name == "get_skill");
|
||||
if has_get_skill {
|
||||
if let Some(ref loader) = self.skills_loader {
|
||||
let prompt = loader.build_skills_prompt();
|
||||
if !prompt.is_empty() {
|
||||
return Some(prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn build_sub_agent(
|
||||
&self,
|
||||
config: &SubAgentConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
let mut provider = create_provider(self.provider_config.clone())
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
if let Some(ref s) = self.storage {
|
||||
provider.set_storage(s.clone());
|
||||
}
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider);
|
||||
|
||||
let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
|
||||
let workspace_dir = self.provider_config.workspace_dir.clone();
|
||||
let model_name = self.provider_config.model_id.clone();
|
||||
let input_types = self.provider_config.input_types.clone();
|
||||
|
||||
let agent = AgentLoop::with_provider_and_tools(
|
||||
provider,
|
||||
tools,
|
||||
max_iterations,
|
||||
model_name,
|
||||
workspace_dir,
|
||||
input_types,
|
||||
)
|
||||
.with_context_window(self.provider_config.token_limit);
|
||||
|
||||
Ok(agent)
|
||||
}
|
||||
|
||||
pub async fn run_inline(
|
||||
&self,
|
||||
config: SubAgentConfig,
|
||||
) -> Result<SubAgentResult, SubAgentError> {
|
||||
let task_id = generate_task_id();
|
||||
let tools = self.filter_tools(&config.allowed_tools);
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||
let timeout_human = format_duration(timeout_secs);
|
||||
let http_get_only = config.allowed_tools.is_none()
|
||||
|| config
|
||||
.allowed_tools
|
||||
.as_ref()
|
||||
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||
let skills_prompt = self.get_skills_prompt(&tools);
|
||||
let system_prompt = build_sub_agent_system_prompt(
|
||||
&config.prompt,
|
||||
&timeout_human,
|
||||
&tools,
|
||||
&self.provider_config.workspace_dir,
|
||||
&self.provider_config.model_id,
|
||||
skills_prompt,
|
||||
http_get_only,
|
||||
);
|
||||
|
||||
let agent = self
|
||||
.build_sub_agent(&config, tools)
|
||||
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
let history = vec![
|
||||
ChatMessage::system(system_prompt),
|
||||
ChatMessage::user(&config.prompt),
|
||||
];
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
agent.process(history),
|
||||
)
|
||||
.await;
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
match result {
|
||||
Ok(Ok(agent_result)) => {
|
||||
let (content, truncated) =
|
||||
truncate_sub_agent_result(&agent_result.final_response.content);
|
||||
let tool_calls_count = agent_result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|m| m.tool_calls.is_some())
|
||||
.count();
|
||||
let iterations = agent_result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
|
||||
.count();
|
||||
Ok(SubAgentResult {
|
||||
task_id,
|
||||
content,
|
||||
content_truncated: truncated,
|
||||
status: TaskStatus::Completed,
|
||||
tool_calls_count,
|
||||
iterations,
|
||||
duration_ms,
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Ok(SubAgentResult {
|
||||
task_id,
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Failed(e.to_string()),
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms,
|
||||
}),
|
||||
Err(_elapsed) => Ok(SubAgentResult {
|
||||
task_id,
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::TimedOut,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_parallel(
|
||||
&self,
|
||||
configs: Vec<SubAgentConfig>,
|
||||
) -> Result<Vec<SubAgentResult>, SubAgentError> {
|
||||
let futures: Vec<_> = configs
|
||||
.into_iter()
|
||||
.map(|config| {
|
||||
let mgr = self; // &self borrow, all tasks share the same manager
|
||||
async move { mgr.run_inline(config).await }
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = futures_util::future::join_all(futures).await;
|
||||
Ok(results.into_iter().collect::<Result<Vec<_>, _>>()?)
|
||||
}
|
||||
|
||||
pub async fn run_background(
|
||||
&self,
|
||||
config: SubAgentConfig,
|
||||
ctx: DelegateContext,
|
||||
) -> Result<String, SubAgentError> {
|
||||
if self.active_tasks.len() >= self.max_concurrent_background_tasks {
|
||||
return Err(SubAgentError::TooManyTasks(
|
||||
self.max_concurrent_background_tasks,
|
||||
));
|
||||
}
|
||||
|
||||
let task_id = generate_task_id();
|
||||
let cancel_token = CancellationToken::new();
|
||||
|
||||
self.active_tasks
|
||||
.insert(task_id.clone(), cancel_token.clone());
|
||||
|
||||
// Write DB: pending
|
||||
if let Some(ref storage) = self.storage {
|
||||
let allowed_tools_json = config
|
||||
.allowed_tools
|
||||
.as_ref()
|
||||
.and_then(|v| serde_json::to_string(v).ok());
|
||||
let record = crate::storage::BackgroundTask {
|
||||
id: task_id.clone(),
|
||||
session_id: ctx.session_id.clone(),
|
||||
channel: ctx.channel.clone(),
|
||||
chat_id: ctx.chat_id.clone(),
|
||||
prompt: config.prompt.clone(),
|
||||
allowed_tools: allowed_tools_json,
|
||||
status: "pending".to_string(),
|
||||
result: None,
|
||||
error: None,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
started_at: None,
|
||||
finished_at: None,
|
||||
created_at: chrono::Utc::now().timestamp_millis(),
|
||||
};
|
||||
storage
|
||||
.create_background_task(&record)
|
||||
.await
|
||||
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
|
||||
}
|
||||
|
||||
let tools = self.filter_tools(&config.allowed_tools);
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||
let timeout_human = format_duration(timeout_secs);
|
||||
let http_get_only = config.allowed_tools.is_none()
|
||||
|| config
|
||||
.allowed_tools
|
||||
.as_ref()
|
||||
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||
let skills_prompt = self.get_skills_prompt(&tools);
|
||||
let system_prompt = build_sub_agent_system_prompt(
|
||||
&config.prompt,
|
||||
&timeout_human,
|
||||
&tools,
|
||||
&self.provider_config.workspace_dir,
|
||||
&self.provider_config.model_id,
|
||||
skills_prompt,
|
||||
http_get_only,
|
||||
);
|
||||
let provider_config = self.provider_config.clone();
|
||||
let storage = self.storage.clone();
|
||||
let notify_tx = self.notify_tx.clone();
|
||||
let active_tasks = Arc::clone(&self.active_tasks);
|
||||
|
||||
let tid = task_id.clone();
|
||||
let sess_id = ctx.session_id.clone();
|
||||
let ch = ctx.channel.clone();
|
||||
let cid = ctx.chat_id.clone();
|
||||
let prompt = config.prompt.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let started_at = chrono::Utc::now().timestamp_millis();
|
||||
|
||||
// Update DB: running
|
||||
if let Some(ref s) = storage {
|
||||
let _ = s
|
||||
.update_background_task_status(
|
||||
&tid,
|
||||
"running",
|
||||
None,
|
||||
None,
|
||||
Some(started_at),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut provider = create_provider(provider_config.clone()).ok();
|
||||
if let Some(ref mut p) = provider {
|
||||
if let Some(ref s) = storage {
|
||||
p.set_storage(s.clone());
|
||||
}
|
||||
}
|
||||
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
|
||||
|
||||
let result = match provider_result {
|
||||
Some(provider) => {
|
||||
let agent = AgentLoop::with_provider_and_tools(
|
||||
provider,
|
||||
tools,
|
||||
DEFAULT_MAX_ITERATIONS,
|
||||
provider_config.model_id.clone(),
|
||||
provider_config.workspace_dir.clone(),
|
||||
provider_config.input_types.clone(),
|
||||
)
|
||||
.with_context_window(provider_config.token_limit);
|
||||
|
||||
let history = vec![
|
||||
ChatMessage::system(system_prompt),
|
||||
ChatMessage::user(&prompt),
|
||||
];
|
||||
|
||||
tokio::select! {
|
||||
r = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
agent.process(history),
|
||||
) => {
|
||||
match r {
|
||||
Ok(Ok(agent_result)) => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: agent_result.final_response.content,
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Completed,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
Ok(Err(e)) => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Failed(e.to_string()),
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
Err(_) => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::TimedOut,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = cancel_token.cancelled() => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Cancelled,
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
None => SubAgentResult {
|
||||
task_id: tid.clone(),
|
||||
content: String::new(),
|
||||
content_truncated: false,
|
||||
status: TaskStatus::Failed("provider creation failed".into()),
|
||||
tool_calls_count: 0,
|
||||
iterations: 0,
|
||||
duration_ms: 0,
|
||||
},
|
||||
};
|
||||
|
||||
let finished_at = chrono::Utc::now().timestamp_millis();
|
||||
let duration_ms = (finished_at - started_at) as u64;
|
||||
|
||||
let (status_str, error_val) = match &result.status {
|
||||
TaskStatus::Completed => ("completed".to_string(), None),
|
||||
TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())),
|
||||
TaskStatus::Cancelled => ("cancelled".to_string(), None),
|
||||
TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())),
|
||||
};
|
||||
|
||||
if let Some(ref s) = storage {
|
||||
let _ = s
|
||||
.update_background_task_status(
|
||||
&tid,
|
||||
&status_str,
|
||||
Some(&result.content),
|
||||
error_val.as_deref(),
|
||||
Some(started_at),
|
||||
Some(finished_at),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let _ = notify_tx.send(TaskNotification {
|
||||
task_id: tid.clone(),
|
||||
session_id: sess_id,
|
||||
channel: ch,
|
||||
chat_id: cid,
|
||||
status: result.status,
|
||||
result_summary: summarize_for_notification(&result.content, duration_ms),
|
||||
});
|
||||
|
||||
active_tasks.remove(&tid);
|
||||
});
|
||||
|
||||
Ok(task_id)
|
||||
}
|
||||
|
||||
pub async fn cancel_task(&self, task_id: &str) -> Result<bool, SubAgentError> {
|
||||
if let Some((_, token)) = self.active_tasks.remove(task_id) {
|
||||
token.cancel();
|
||||
if let Some(ref s) = self.storage {
|
||||
s.update_background_task_status(
|
||||
task_id,
|
||||
"cancelled",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(chrono::Utc::now().timestamp_millis()),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
|
||||
}
|
||||
Ok(true)
|
||||
} else if let Some(ref s) = self.storage {
|
||||
match s.get_background_task(task_id).await {
|
||||
Ok(task) => match task.status.as_str() {
|
||||
"pending" | "running" => {
|
||||
tracing::warn!(task_id, "task in DB but not in active_tasks");
|
||||
Ok(false)
|
||||
}
|
||||
_ => Ok(false),
|
||||
},
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
|
||||
if let Some(ref s) = self.storage {
|
||||
s.get_background_task(task_id).await.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
|
||||
if let Some(ref s) = self.storage {
|
||||
s.list_background_tasks(session_id)
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cancel_by_session(&self, session_id: &str) {
|
||||
// Cancel all running tasks for a session by checking DB
|
||||
if let Some(ref s) = self.storage {
|
||||
if let Ok(tasks) = s.list_background_tasks(session_id).await {
|
||||
for task in &tasks {
|
||||
if task.status == "pending" || task.status == "running" {
|
||||
let _ = self.cancel_task(&task.id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_task_count(&self) -> usize {
|
||||
self.active_tasks.len()
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_task_id() -> String {
|
||||
Uuid::new_v4().to_string()[..8].to_string()
|
||||
}
|
||||
|
||||
fn format_duration(seconds: u64) -> String {
|
||||
if seconds < 60 {
|
||||
format!("{}s", seconds)
|
||||
} else if seconds < 3600 {
|
||||
format!("{}m", seconds / 60)
|
||||
} else {
|
||||
format!("{}h", seconds / 3600)
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_sub_agent_result(content: &str) -> (String, bool) {
|
||||
if content.len() <= MAX_INLINE_RESULT_CHARS {
|
||||
(content.to_string(), false)
|
||||
} else {
|
||||
let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS);
|
||||
(
|
||||
format!(
|
||||
"{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]",
|
||||
&content[..truncate_at],
|
||||
content.len()
|
||||
),
|
||||
true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn summarize_for_notification(content: &str, _duration_ms: u64) -> String {
|
||||
const MAX_SUMMARY_BYTES: usize = 500;
|
||||
if content.len() <= MAX_SUMMARY_BYTES {
|
||||
content.to_string()
|
||||
} else {
|
||||
let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES);
|
||||
format!("{}...", &content[..truncate_at])
|
||||
}
|
||||
}
|
||||
@ -3,11 +3,7 @@
|
||||
//! This module provides a modular framework for building system prompts
|
||||
//! using the SystemPromptBuilder pattern.
|
||||
//!
|
||||
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic
|
||||
//!
|
||||
//! Configuration files loaded from ~/.picobot/:
|
||||
//! - AGENTS.md — agent identity and behavior
|
||||
//! - USER.md — user preferences and profile
|
||||
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation
|
||||
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::path::Path;
|
||||
@ -55,10 +51,35 @@ impl SystemPromptBuilder {
|
||||
Box::new(CrossChannelSection),
|
||||
Box::new(MemorySection),
|
||||
Box::new(HistorySection),
|
||||
Box::new(DelegationSection),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a builder with sub-agent specific sections.
|
||||
pub fn with_sub_agent_defaults(
|
||||
task: &str,
|
||||
timeout: &str,
|
||||
skills_prompt: Option<String>,
|
||||
http_get_only: bool,
|
||||
) -> Self {
|
||||
let mut sections: Vec<Box<dyn PromptSection>> = vec![
|
||||
Box::new(SubAgentIdentitySection {
|
||||
task: task.to_string(),
|
||||
timeout: timeout.to_string(),
|
||||
}),
|
||||
Box::new(ToolHonestySection),
|
||||
Box::new(SafetySection),
|
||||
Box::new(SubAgentToolsSection { http_get_only }),
|
||||
Box::new(WorkspaceSection),
|
||||
Box::new(DateTimeSection),
|
||||
];
|
||||
if let Some(sp) = skills_prompt {
|
||||
sections.push(Box::new(SubAgentSkillsSection { skills_prompt: sp }));
|
||||
}
|
||||
Self { sections }
|
||||
}
|
||||
|
||||
/// Add a custom section to the builder.
|
||||
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
|
||||
self.sections.push(section);
|
||||
@ -175,10 +196,10 @@ impl PromptSection for UserProfileSection {
|
||||
if let Some(user_config_dir) = get_user_config_dir()
|
||||
&& let Some(content) =
|
||||
load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS)
|
||||
{
|
||||
output.push_str(&content);
|
||||
return output;
|
||||
}
|
||||
{
|
||||
output.push_str(&content);
|
||||
return output;
|
||||
}
|
||||
|
||||
// No USER.md found, return empty
|
||||
String::new()
|
||||
@ -199,10 +220,10 @@ impl PromptSection for AgentProfileSection {
|
||||
if let Some(user_config_dir) = get_user_config_dir()
|
||||
&& let Some(content) =
|
||||
load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS)
|
||||
{
|
||||
output.push_str(&content);
|
||||
return output;
|
||||
}
|
||||
{
|
||||
output.push_str(&content);
|
||||
return output;
|
||||
}
|
||||
|
||||
String::new()
|
||||
}
|
||||
@ -353,6 +374,120 @@ impl PromptSection for HistorySection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sub-agent delegation principles.
|
||||
pub struct DelegationSection;
|
||||
|
||||
impl PromptSection for DelegationSection {
|
||||
fn name(&self) -> &str {
|
||||
"delegation"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
"## 子 Agent 委托原则\n\n\
|
||||
当任务复杂需要拆解时,使用 delegate 工具创建子 Agent:\n\
|
||||
\n\
|
||||
### 何时委托\n\
|
||||
- 多个独立子任务可以并行处理时(使用 mode=\"parallel\")\n\
|
||||
- 长时间运行的任务需要后台执行时(使用 mode=\"background\")\n\
|
||||
- 需要以不同权限(受限工具集)执行时\n\
|
||||
\n\
|
||||
### 工具分配原则\n\
|
||||
- **最小权限**:只给子 Agent 完成其任务所需的最少工具\n\
|
||||
- **只读优先**:如果可以只用 file_read、file_search、web_fetch 完成,不要给写权限(bash、file_write、file_edit)\n\
|
||||
- **禁止递归**:永远不要把 delegate 工具分配给子 Agent\n\
|
||||
- **明确边界**:每个子 Agent 只负责一个清晰、独立的子任务\n\
|
||||
\n\
|
||||
### Skill 分配原则\n\
|
||||
- 如果子任务的领域有对应的 skill,在 allowed_tools 中加入 get_skill\n\
|
||||
- 在任务 prompt 中明确告诉子 Agent 使用 get_skill 加载哪个技能\n\
|
||||
- 例如:\"使用 get_skill action='get' skill_name='pdf' 加载 PDF 处理技能后完成任务\"\n\
|
||||
\n\
|
||||
### 任务描述\n\
|
||||
- 任务 prompt 要清晰、具体、有明确输出要求\n\
|
||||
- 如需额外约束,直接写在 prompt 中(例如:\"跳过 .tmp 文件\")\n\
|
||||
- 明确说明期望的输出格式\n\
|
||||
\n\
|
||||
### 并行模式\n\
|
||||
- 多个无依赖的子任务使用 mode=\"parallel\",任务定义在 tasks 数组中\n\
|
||||
- 并行任务之间不应有数据依赖\n\
|
||||
- 并行任务数建议不超过 5 个\n\
|
||||
\n\
|
||||
### 后台模式\n\
|
||||
- 预计执行时间超过 30s 的任务使用 mode=\"background\"\n\
|
||||
- 后台任务有全局并发上限,如果失败提示用户稍后重试".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
// === Sub-Agent Prompt Sections ===
|
||||
|
||||
/// Sub-agent identity and task instructions.
|
||||
pub struct SubAgentIdentitySection {
|
||||
pub task: String,
|
||||
pub timeout: String,
|
||||
}
|
||||
|
||||
impl PromptSection for SubAgentIdentitySection {
|
||||
fn name(&self) -> &str {
|
||||
"sub_agent_identity"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
format!(
|
||||
"## 子 Agent\n\n\
|
||||
你是主 Agent 派出的子 Agent,负责完成一个具体任务。你的最终回复将汇报给主 Agent。\n\
|
||||
\n\
|
||||
## 任务\n\n\
|
||||
{}\n\
|
||||
\n\
|
||||
## 规则\n\
|
||||
- 只专注于上述任务,不要探索无关话题\n\
|
||||
- 只在必要时使用工具\n\
|
||||
- 不要使用 delegate 工具(禁止递归委托)\n\
|
||||
- 如果任务无法完成,清楚说明原因\n\
|
||||
- 只返回最终结果,不要描述过程\n\
|
||||
- 超时:{},接近时限时返回部分结果",
|
||||
self.task, self.timeout,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sub-agent available tools description.
|
||||
pub struct SubAgentToolsSection {
|
||||
pub http_get_only: bool,
|
||||
}
|
||||
|
||||
impl PromptSection for SubAgentToolsSection {
|
||||
fn name(&self) -> &str {
|
||||
"sub_agent_tools"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||
let mut s = String::from("## 可用工具\n\n");
|
||||
s.push_str(&ctx.tools.describe_for_prompt());
|
||||
if self.http_get_only {
|
||||
s.push_str(
|
||||
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
|
||||
);
|
||||
}
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
/// Sub-agent skills information, injected when get_skill tool is available.
|
||||
pub struct SubAgentSkillsSection {
|
||||
pub skills_prompt: String,
|
||||
}
|
||||
|
||||
impl PromptSection for SubAgentSkillsSection {
|
||||
fn name(&self) -> &str {
|
||||
"sub_agent_skills"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
self.skills_prompt.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// === Helper Functions ===
|
||||
|
||||
/// Get user config directory (~/.picobot/).
|
||||
@ -409,6 +544,28 @@ pub fn build_system_prompt(
|
||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||
}
|
||||
|
||||
/// Build a system prompt for a sub-agent with all relevant operational sections.
|
||||
pub fn build_sub_agent_system_prompt(
|
||||
task: &str,
|
||||
timeout_human: &str,
|
||||
tools: &ToolRegistry,
|
||||
workspace_dir: &Path,
|
||||
model_name: &str,
|
||||
skills_prompt: Option<String>,
|
||||
http_get_only: bool,
|
||||
) -> String {
|
||||
let ctx = PromptContext {
|
||||
workspace_dir,
|
||||
model_name,
|
||||
tools,
|
||||
session_id: None,
|
||||
memory_context: None,
|
||||
has_compressed_history: false,
|
||||
};
|
||||
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
|
||||
.build(&ctx)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::ChannelManager;
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
|
||||
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
||||
/// and dispatches them to the appropriate Channel
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::providers::ToolCall;
|
||||
|
||||
@ -23,7 +23,9 @@ pub struct ImageUrlBlock {
|
||||
|
||||
impl ContentBlock {
|
||||
pub fn text(content: impl Into<String>) -> Self {
|
||||
Self::Text { text: content.into() }
|
||||
Self::Text {
|
||||
text: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn image_url(url: impl Into<String>) -> Self {
|
||||
@ -49,10 +51,10 @@ pub struct MediaRef {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MediaItem {
|
||||
pub path: String, // Local file path
|
||||
pub media_type: String, // "image", "audio", "file", "video"
|
||||
pub path: String, // Local file path
|
||||
pub media_type: String, // "image", "audio", "file", "video"
|
||||
pub mime_type: Option<String>,
|
||||
pub original_key: Option<String>, // Feishu file_key for download
|
||||
pub original_key: Option<String>, // Feishu file_key for download
|
||||
}
|
||||
|
||||
impl MediaItem {
|
||||
@ -161,7 +163,10 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||
pub fn assistant_with_tool_calls(
|
||||
content: impl Into<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
@ -206,7 +211,11 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
pub fn tool(
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "tool".to_string(),
|
||||
|
||||
@ -2,10 +2,13 @@ pub mod dispatcher;
|
||||
pub mod message;
|
||||
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
|
||||
pub use message::{
|
||||
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
|
||||
OutboundMessage, SourceKind,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
|
||||
// ============================================================================
|
||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||
@ -49,7 +52,8 @@ impl MessageBus {
|
||||
|
||||
/// Consume an inbound message (Agent -> Bus)
|
||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||
let msg = self.inbound_rx
|
||||
let msg = self
|
||||
.inbound_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
|
||||
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
||||
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
|
||||
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
||||
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
|
||||
|
||||
use super::base::{Channel, ChannelError};
|
||||
|
||||
@ -14,6 +14,7 @@ use super::base::{Channel, ChannelError};
|
||||
|
||||
pub(crate) struct Client {
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
chat_id: String,
|
||||
current_session_id: Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
@ -41,23 +42,28 @@ impl CliChatChannel {
|
||||
}
|
||||
|
||||
/// Register a new client connection, returns (session_id, client)
|
||||
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
|
||||
// Generate connection ID (used as chat_id) - use short ID
|
||||
let connection_id = crate::util::short_id();
|
||||
pub(crate) async fn register_client(
|
||||
&self,
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
) -> (String, Arc<Client>) {
|
||||
// Each WebSocket connection gets a stable chat scope. All user input and
|
||||
// dialog controls for this client stay inside that scope unless the
|
||||
// protocol explicitly carries a full session id.
|
||||
let chat_id = crate::util::short_id();
|
||||
|
||||
let client = Arc::new(Client {
|
||||
sender,
|
||||
chat_id: chat_id.clone(),
|
||||
current_session_id: Mutex::new(None),
|
||||
});
|
||||
self.clients.lock().await.push(client.clone());
|
||||
|
||||
// Create initial session via control message
|
||||
let session_id = match self.create_session_via_control(&connection_id, None).await {
|
||||
Ok(id) => id,
|
||||
let session_id = match self.create_session_via_control(&chat_id, None).await {
|
||||
Ok((id, _title)) => id,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to create initial session");
|
||||
// Fall back to old format for backward compatibility
|
||||
connection_id.clone()
|
||||
UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
|
||||
}
|
||||
};
|
||||
|
||||
@ -73,21 +79,19 @@ impl CliChatChannel {
|
||||
/// Handle an inbound message from a client
|
||||
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
|
||||
match parse_inbound(raw_msg) {
|
||||
Ok(inbound) => {
|
||||
match self.handle_ws_inbound(client.clone(), inbound).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to handle inbound message");
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "INTERNAL_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to handle inbound message");
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "INTERNAL_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||
let _ = client
|
||||
@ -101,22 +105,30 @@ impl CliChatChannel {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> {
|
||||
async fn handle_ws_inbound(
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
inbound: WsInbound,
|
||||
) -> Result<(), ChannelError> {
|
||||
let bus = {
|
||||
let guard = self.bus.lock().unwrap();
|
||||
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||
guard
|
||||
.clone()
|
||||
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||
};
|
||||
|
||||
let mut current_session_guard = client.current_session_id.lock().await;
|
||||
|
||||
match inbound {
|
||||
WsInbound::UserInput { content, chat_id, .. } => {
|
||||
WsInbound::UserInput {
|
||||
content, chat_id, ..
|
||||
} => {
|
||||
// All messages (including slash commands) go through the normal inbound flow
|
||||
// SessionManager handles session creation/reuse internally
|
||||
let msg = InboundMessage {
|
||||
channel: self.name().to_string(),
|
||||
sender_id: "cli".to_string(),
|
||||
chat_id: chat_id.unwrap_or_else(crate::util::short_id),
|
||||
chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
|
||||
content,
|
||||
timestamp: crate::bus::message::current_timestamp(),
|
||||
media: Vec::new(),
|
||||
@ -125,19 +137,56 @@ impl CliChatChannel {
|
||||
};
|
||||
bus.publish_inbound(msg).await?;
|
||||
}
|
||||
WsInbound::ClearHistory { chat_id, session_id } => {
|
||||
let target = session_id
|
||||
.or(chat_id)
|
||||
.or(current_session_guard.clone())
|
||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||
|
||||
WsInbound::ClearHistory {
|
||||
chat_id,
|
||||
session_id,
|
||||
} => {
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
let session_id = UnifiedSessionId::parse(&target)
|
||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||
let session_id = if let Some(session_id) = session_id {
|
||||
UnifiedSessionId::parse(&session_id).ok_or_else(|| {
|
||||
ChannelError::Other("Invalid session ID format".to_string())
|
||||
})?
|
||||
} else if let Some(chat_id) = chat_id {
|
||||
let (current_tx, mut current_rx) = mpsc::channel(1);
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::GetCurrentDialog {
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id,
|
||||
},
|
||||
reply_tx: current_tx,
|
||||
})
|
||||
.await?;
|
||||
match current_rx.recv().await {
|
||||
Some(Ok(SessionEvent::CurrentDialog {
|
||||
session_id: Some(session_id),
|
||||
})) => session_id,
|
||||
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
|
||||
return Err(ChannelError::Other("No active session".to_string()));
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
return Err(ChannelError::Other(
|
||||
"Unexpected response type".to_string(),
|
||||
));
|
||||
}
|
||||
Some(Err(e)) => return Err(e),
|
||||
None => {
|
||||
return Err(ChannelError::Other("Control channel closed".to_string()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let target = current_session_guard
|
||||
.clone()
|
||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||
UnifiedSessionId::parse(&target).ok_or_else(|| {
|
||||
ChannelError::Other("Invalid session ID format".to_string())
|
||||
})?
|
||||
};
|
||||
let target = session_id.to_string();
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::ClearHistory { session_id },
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
|
||||
@ -158,24 +207,21 @@ impl CliChatChannel {
|
||||
}
|
||||
}
|
||||
WsInbound::CreateSession { title } => {
|
||||
// Use current session's chat_id if available, otherwise generate new one
|
||||
let chat_id = current_session_guard.clone()
|
||||
.unwrap_or_else(crate::util::short_id);
|
||||
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
|
||||
let (new_id, created_title) = self
|
||||
.create_session_via_control(&client.chat_id, title.as_deref())
|
||||
.await?;
|
||||
*current_session_guard = Some(new_id.clone());
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionCreated {
|
||||
session_id: new_id,
|
||||
title: title.unwrap_or_default(),
|
||||
title: created_title,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
WsInbound::ListSessions { include_archived } => {
|
||||
// List dialogs for the current chat
|
||||
let chat_id = current_session_guard.clone()
|
||||
.unwrap_or_else(|| "".to_string());
|
||||
let chat_id_for_response = chat_id.clone();
|
||||
let chat_id = client.chat_id.clone();
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::ListDialogs {
|
||||
@ -184,13 +230,18 @@ impl CliChatChannel {
|
||||
include_archived,
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => {
|
||||
Some(Ok(SessionEvent::DialogList {
|
||||
dialogs,
|
||||
current_dialog_id,
|
||||
})) => {
|
||||
// Convert DialogInfo to SessionSummary for backward compatibility
|
||||
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| {
|
||||
crate::protocol::SessionSummary {
|
||||
let sessions: Vec<crate::protocol::SessionSummary> = dialogs
|
||||
.into_iter()
|
||||
.map(|d| crate::protocol::SessionSummary {
|
||||
session_id: d.session_id.to_string(),
|
||||
title: d.title,
|
||||
channel_name: d.session_id.channel.clone(),
|
||||
@ -198,11 +249,14 @@ impl CliChatChannel {
|
||||
message_count: d.message_count,
|
||||
last_active_at: d.last_active_at,
|
||||
archived_at: d.archived_at,
|
||||
}
|
||||
}).collect();
|
||||
})
|
||||
.collect();
|
||||
let current_session_id = current_dialog_id.map(|did| {
|
||||
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string()
|
||||
UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
|
||||
});
|
||||
if let Some(ref session_id) = current_session_id {
|
||||
*current_session_guard = Some(session_id.clone());
|
||||
}
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionList {
|
||||
@ -223,39 +277,35 @@ impl CliChatChannel {
|
||||
}
|
||||
}
|
||||
WsInbound::LoadSession { session_id } => {
|
||||
// LoadSession: parse the session_id and get current dialog info
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
let unified_id = UnifiedSessionId::parse(&session_id)
|
||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
|
||||
return Err(ChannelError::Other(
|
||||
"Session does not belong to this client".to_string(),
|
||||
));
|
||||
}
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::GetCurrentDialog {
|
||||
op: SessionCommand::SwitchDialog {
|
||||
channel: unified_id.channel.clone(),
|
||||
chat_id: unified_id.chat_id.clone(),
|
||||
dialog_id: unified_id.dialog_id.clone(),
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => {
|
||||
if let Some(current_session_id) = current_session_id_opt {
|
||||
*current_session_guard = Some(current_session_id.to_string());
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionLoaded {
|
||||
session_id: current_session_id.to_string(),
|
||||
title: "Session".to_string(), // TODO: get actual title
|
||||
message_count: 0, // TODO: get actual count
|
||||
})
|
||||
.await;
|
||||
} else {
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "NO_CURRENT_DIALOG".to_string(),
|
||||
message: "No current dialog".to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
|
||||
*current_session_guard = Some(session_id.to_string());
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionLoaded {
|
||||
session_id: session_id.to_string(),
|
||||
title: "Session".to_string(),
|
||||
message_count: 0,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
// Unexpected response type
|
||||
@ -275,23 +325,30 @@ impl CliChatChannel {
|
||||
}
|
||||
}
|
||||
WsInbound::RenameSession { session_id, title } => {
|
||||
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
||||
ChannelError::Other("No active session".to_string())
|
||||
})?;
|
||||
let target = session_id
|
||||
.or(current_session_guard.clone())
|
||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
let unified_id = UnifiedSessionId::parse(&target)
|
||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() },
|
||||
op: SessionCommand::RenameDialog {
|
||||
session_id: unified_id,
|
||||
title: title.clone(),
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title })
|
||||
.send(WsOutbound::SessionRenamed {
|
||||
session_id: session_id.to_string(),
|
||||
title,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
@ -306,24 +363,43 @@ impl CliChatChannel {
|
||||
}
|
||||
}
|
||||
WsInbound::ArchiveSession { session_id } => {
|
||||
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
||||
ChannelError::Other("No active session".to_string())
|
||||
})?;
|
||||
let target = session_id
|
||||
.or(current_session_guard.clone())
|
||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||
let was_current = current_session_guard.as_deref() == Some(&target);
|
||||
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
let unified_id = UnifiedSessionId::parse(&target)
|
||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::ArchiveDialog { session_id: unified_id },
|
||||
op: SessionCommand::ArchiveDialog {
|
||||
session_id: unified_id,
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() })
|
||||
.send(WsOutbound::SessionArchived {
|
||||
session_id: session_id.to_string(),
|
||||
})
|
||||
.await;
|
||||
if was_current {
|
||||
let (new_id, title) = self
|
||||
.create_session_via_control(&client.chat_id, None)
|
||||
.await?;
|
||||
*current_session_guard = Some(new_id.clone());
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionCreated {
|
||||
session_id: new_id,
|
||||
title,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
// Unexpected response type
|
||||
@ -337,35 +413,42 @@ impl CliChatChannel {
|
||||
}
|
||||
}
|
||||
WsInbound::DeleteSession { session_id } => {
|
||||
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
||||
ChannelError::Other("No active session".to_string())
|
||||
})?;
|
||||
let target = session_id
|
||||
.or(current_session_guard.clone())
|
||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
let unified_id = UnifiedSessionId::parse(&target)
|
||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::DeleteDialog { session_id: unified_id },
|
||||
op: SessionCommand::DeleteDialog {
|
||||
session_id: unified_id,
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() })
|
||||
.send(WsOutbound::SessionDeleted {
|
||||
session_id: session_id.to_string(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// If deleting current session, create a new one
|
||||
if current_session_guard.as_deref() == Some(&target) {
|
||||
drop(reply_rx);
|
||||
if let Ok(new_id) = self.create_session_via_control(&target, None).await {
|
||||
if let Ok((new_id, title)) =
|
||||
self.create_session_via_control(&client.chat_id, None).await
|
||||
{
|
||||
*current_session_guard = Some(new_id.clone());
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SessionCreated {
|
||||
session_id: new_id,
|
||||
title: String::new(),
|
||||
title,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
@ -388,32 +471,45 @@ impl CliChatChannel {
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::GetSlashCommands {
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: "".to_string(),
|
||||
chat_id: client.chat_id.clone(),
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
if let Some(result) = reply_rx.recv().await {
|
||||
match result {
|
||||
Ok(SessionEvent::SlashCommandsList { commands }) => {
|
||||
// Convert to SlashCommand to SlashCommandInfo
|
||||
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| {
|
||||
SlashCommandInfo {
|
||||
let command_infos: Vec<SlashCommandInfo> = commands
|
||||
.into_iter()
|
||||
.map(|cmd| SlashCommandInfo {
|
||||
name: cmd.name.to_string(),
|
||||
description: cmd.description.to_string(),
|
||||
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
|
||||
}
|
||||
}).collect();
|
||||
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await;
|
||||
})
|
||||
.collect();
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::SlashCommandsList {
|
||||
commands: command_infos,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Ok(SessionEvent::Error { code, message }) => {
|
||||
let _ = client.sender.send(WsOutbound::Error { code, message }).await;
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::Error { code, message })
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = client.sender.send(WsOutbound::Error {
|
||||
code: "GET_COMMANDS_ERROR".to_string(),
|
||||
message: e.to_string()
|
||||
}).await;
|
||||
let _ = client
|
||||
.sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "GET_COMMANDS_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@ -427,29 +523,34 @@ impl CliChatChannel {
|
||||
}
|
||||
|
||||
/// Create a session via control message and return the session_id
|
||||
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> {
|
||||
async fn create_session_via_control(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
title: Option<&str>,
|
||||
) -> Result<(String, String), ChannelError> {
|
||||
let bus = {
|
||||
let guard = self.bus.lock().unwrap();
|
||||
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||
guard
|
||||
.clone()
|
||||
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||
};
|
||||
|
||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||
bus.publish_control(ControlMessage {
|
||||
op: SessionCommand::CreateDialog {
|
||||
channel: "cli_chat".to_string(),
|
||||
chat_id: connection_id.to_string(),
|
||||
chat_id: chat_id.to_string(),
|
||||
title: title.map(String::from),
|
||||
},
|
||||
reply_tx,
|
||||
}).await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
match reply_rx.recv().await {
|
||||
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => {
|
||||
Ok(session_id.to_string())
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
Err(ChannelError::Other("Unexpected response type".to_string()))
|
||||
Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
|
||||
Ok((session_id.to_string(), title))
|
||||
}
|
||||
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
|
||||
Some(Err(e)) => Err(e),
|
||||
None => Err(ChannelError::Other("Control channel closed".to_string())),
|
||||
}
|
||||
@ -479,7 +580,11 @@ impl Channel for CliChatChannel {
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let clients = self.clients.lock().await.clone();
|
||||
for client in clients {
|
||||
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
|
||||
if client.chat_id != msg.chat_id {
|
||||
continue;
|
||||
}
|
||||
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
|
||||
{
|
||||
WsOutbound::SystemNotification {
|
||||
content: msg.content.clone(),
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -24,7 +24,10 @@ impl ChannelManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
|
||||
pub fn with_bus(
|
||||
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
|
||||
bus: Arc<MessageBus>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
cli_chat_channel,
|
||||
@ -39,7 +42,10 @@ impl ChannelManager {
|
||||
|
||||
/// Register a channel with the manager
|
||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||
self.channels.write().await.insert(name.to_string(), channel);
|
||||
self.channels
|
||||
.write()
|
||||
.await
|
||||
.insert(name.to_string(), channel);
|
||||
}
|
||||
|
||||
/// Get CLI chat channel
|
||||
@ -56,14 +62,19 @@ impl ChannelManager {
|
||||
// Initialize Feishu channel if enabled
|
||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||
if feishu_config.enabled {
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||
let channel =
|
||||
FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
|
||||
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
||||
})?;
|
||||
|
||||
self.channels
|
||||
.write()
|
||||
.await
|
||||
.insert("feishu".to_string(), Arc::new(channel));
|
||||
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display());
|
||||
tracing::info!(
|
||||
"Feishu channel registered (media_dir: {}/media/feishu)",
|
||||
workspace_dir.display()
|
||||
);
|
||||
} else {
|
||||
tracing::info!("Feishu channel disabled in config");
|
||||
}
|
||||
@ -118,7 +129,10 @@ impl ChannelManager {
|
||||
if let Some(channel) = self.get_channel(channel_name).await {
|
||||
channel.send(msg).await
|
||||
} else {
|
||||
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
||||
Err(ChannelError::Other(format!(
|
||||
"Channel not found: {}",
|
||||
channel_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
pub mod base;
|
||||
pub mod feishu;
|
||||
pub mod cli_chat;
|
||||
pub mod feishu;
|
||||
pub mod manager;
|
||||
pub mod slash_command;
|
||||
|
||||
pub use base::{Channel, ChannelError};
|
||||
pub use manager::ChannelManager;
|
||||
pub use feishu::FeishuChannel;
|
||||
pub use cli_chat::CliChatChannel;
|
||||
pub use slash_command::{parse_slash_command, command_matches};
|
||||
pub use feishu::FeishuChannel;
|
||||
pub use manager::ChannelManager;
|
||||
pub use slash_command::{command_matches, parse_slash_command};
|
||||
|
||||
@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
|
||||
/// 检查内容是否匹配指定命令
|
||||
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
|
||||
let trimmed = content.trim();
|
||||
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
||||
aliases
|
||||
.iter()
|
||||
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -27,7 +29,10 @@ mod tests {
|
||||
fn test_parse_slash_command() {
|
||||
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
|
||||
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
|
||||
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world")));
|
||||
assert_eq!(
|
||||
parse_slash_command("/new hello world"),
|
||||
Some(("new", "hello world"))
|
||||
);
|
||||
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
|
||||
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
|
||||
assert_eq!(parse_slash_command("/?"), Some(("?", "")));
|
||||
|
||||
@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
|
||||
use crossterm::{
|
||||
event::{self, Event},
|
||||
execute,
|
||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use ratatui::{prelude::CrosstermBackend, Terminal};
|
||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||
use std::io;
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
|
||||
WsOutbound::SessionCreated { session_id, .. } => {
|
||||
app.set_current_session(Some(session_id));
|
||||
}
|
||||
WsOutbound::SessionList { sessions, current_session_id } => {
|
||||
WsOutbound::SessionList {
|
||||
sessions,
|
||||
current_session_id,
|
||||
} => {
|
||||
app.set_sessions(sessions);
|
||||
if let Some(id) = current_session_id {
|
||||
app.set_current_session(Some(id));
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use crate::client::tui::app::{App, MessageRole};
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
text::Line,
|
||||
widgets::{Block, Borders, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
widgets::{Block, Borders, Clear, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect) {
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Style},
|
||||
widgets::{Block, Borders, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
widgets::{Block, Borders, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
.sessions
|
||||
.iter()
|
||||
.map(|session| {
|
||||
let is_current = app
|
||||
.current_session_id
|
||||
.as_ref() == Some(&session.session_id);
|
||||
let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
|
||||
let archived = session.archived_at.is_some();
|
||||
|
||||
let mut content = if is_current {
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
widgets::{Block, Borders, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
let (title, style) = if app.pending_quit {
|
||||
let msg = if let Some(session_id) = &app.current_session_id {
|
||||
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id)
|
||||
format!(
|
||||
"PicoBot | Session: {} | Press Ctrl+C again to quit",
|
||||
session_id
|
||||
)
|
||||
} else {
|
||||
"PicoBot | Press Ctrl+C again to quit".to_string()
|
||||
};
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::client::tui::app::{App, MessageRole};
|
||||
use crate::protocol::serialize_inbound;
|
||||
use crate::protocol::WsInbound;
|
||||
use crate::protocol::serialize_inbound;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use futures_util::SinkExt;
|
||||
|
||||
@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
|
||||
|
||||
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||
// Handle Ctrl+C for quit (double press to exit)
|
||||
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL);
|
||||
let is_ctrl_c = key.code == KeyCode::Char('c')
|
||||
&& key
|
||||
.modifiers
|
||||
.contains(crossterm::event::KeyModifiers::CONTROL);
|
||||
if is_ctrl_c {
|
||||
if app.handle_ctrl_c_for_quit() {
|
||||
return;
|
||||
@ -63,9 +66,11 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||
}
|
||||
KeyCode::Char(c) => {
|
||||
app.input_insert_char(c);
|
||||
|
||||
|
||||
// Show command menu when input starts with /
|
||||
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) {
|
||||
if !app.show_command_menu
|
||||
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
|
||||
{
|
||||
app.show_command_menu = true;
|
||||
app.selected_command_idx = 0;
|
||||
} else if app.show_command_menu && !app.input.starts_with('/') {
|
||||
@ -74,7 +79,7 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||
}
|
||||
KeyCode::Backspace => {
|
||||
app.input_delete_char();
|
||||
|
||||
|
||||
// Hide menu if input no longer starts with /
|
||||
if app.show_command_menu && !app.input.starts_with('/') {
|
||||
app.show_command_menu = false;
|
||||
@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) {
|
||||
sender_id: None,
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await;
|
||||
let _ = sender
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use crate::client::tui::app::App;
|
||||
use crate::client::tui::components::*;
|
||||
use ratatui::{
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
Frame,
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
};
|
||||
|
||||
pub fn render_ui(f: &mut Frame, app: &App) {
|
||||
|
||||
@ -152,10 +152,26 @@ pub struct GatewayConfig {
|
||||
pub cleanup_interval_minutes: Option<u64>,
|
||||
#[serde(default, rename = "session_db_path")]
|
||||
pub session_db_path: Option<String>,
|
||||
#[serde(default, rename = "max_concurrent_background_tasks")]
|
||||
pub max_concurrent_background_tasks: usize,
|
||||
#[serde(default)]
|
||||
pub scheduler: Option<SchedulerConfig>,
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
session_ttl_hours: None,
|
||||
cleanup_interval_minutes: None,
|
||||
session_db_path: None,
|
||||
max_concurrent_background_tasks: 10,
|
||||
scheduler: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SchedulerConfig {
|
||||
/// Whether the scheduler is enabled
|
||||
@ -209,19 +225,6 @@ fn default_gateway_url() -> String {
|
||||
"ws://127.0.0.1:19876/ws".to_string()
|
||||
}
|
||||
|
||||
impl Default for GatewayConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
session_ttl_hours: None,
|
||||
cleanup_interval_minutes: None,
|
||||
session_db_path: None,
|
||||
scheduler: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@ -270,12 +273,16 @@ impl Default for MemoryConfig {
|
||||
impl MemoryConfig {
|
||||
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
||||
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
||||
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
|
||||
self.consolidation_provider
|
||||
.clone()
|
||||
.unwrap_or_else(|| default.to_string())
|
||||
}
|
||||
|
||||
/// Resolve consolidation model name, falling back to the main agent's model.
|
||||
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
||||
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
|
||||
self.consolidation_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| default.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@ -363,10 +370,18 @@ impl Default for BrowserConfig {
|
||||
}
|
||||
}
|
||||
|
||||
fn default_recall_limit() -> usize { 5 }
|
||||
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
||||
fn default_timeline_retention_days() -> u64 { 90 }
|
||||
fn default_max_failures_before_degrade() -> usize { 3 }
|
||||
fn default_recall_limit() -> usize {
|
||||
5
|
||||
}
|
||||
fn default_idle_consolidation_minutes() -> u64 {
|
||||
10
|
||||
}
|
||||
fn default_timeline_retention_days() -> u64 {
|
||||
90
|
||||
}
|
||||
fn default_max_failures_before_degrade() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LLMProviderConfig {
|
||||
@ -466,7 +481,11 @@ pub enum ConfigError {
|
||||
impl std::fmt::Display for ConfigError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
||||
ConfigError::ConfigNotFound(path) => write!(
|
||||
f,
|
||||
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
|
||||
path
|
||||
),
|
||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||
|
||||
@ -1,19 +1,19 @@
|
||||
pub mod http;
|
||||
pub mod ws;
|
||||
|
||||
use axum::{Router, routing};
|
||||
use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
||||
use crate::channels::{ChannelManager, CliChatChannel};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
||||
use crate::channels::{ChannelManager, CliChatChannel};
|
||||
use crate::config::{Config, ensure_workspace_dir, expand_path};
|
||||
use crate::logging;
|
||||
use crate::mcp;
|
||||
use crate::memory::MemoryManager;
|
||||
use crate::session::SessionManager;
|
||||
use crate::scheduler::Scheduler;
|
||||
use crate::session::SessionManager;
|
||||
|
||||
pub struct GatewayState {
|
||||
pub config: Config,
|
||||
@ -32,8 +32,13 @@ impl GatewayState {
|
||||
let workspace_path = ensure_workspace_dir(&workspace_path)?;
|
||||
|
||||
// Switch current working directory to workspace
|
||||
std::env::set_current_dir(&workspace_path)
|
||||
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?;
|
||||
std::env::set_current_dir(&workspace_path).map_err(|e| {
|
||||
format!(
|
||||
"Failed to switch to workspace directory {}: {}",
|
||||
workspace_path.display(),
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!("Using workspace directory: {}", workspace_path.display());
|
||||
|
||||
@ -52,8 +57,9 @@ impl GatewayState {
|
||||
workspace_path.join("picobot.db")
|
||||
};
|
||||
let storage = Arc::new(
|
||||
crate::storage::Storage::new(&db_path).await
|
||||
.map_err(|e| format!("failed to initialize session storage: {}", e))?
|
||||
crate::storage::Storage::new(&db_path)
|
||||
.await
|
||||
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
|
||||
);
|
||||
tracing::info!("Session storage: {}", db_path.display());
|
||||
|
||||
@ -91,13 +97,16 @@ impl GatewayState {
|
||||
bus.clone(),
|
||||
memory_manager,
|
||||
browser_config,
|
||||
config.gateway.max_concurrent_background_tasks,
|
||||
)?;
|
||||
let session_manager = Arc::new(session_manager);
|
||||
|
||||
// Create ChannelManager and init channels
|
||||
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
||||
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
||||
channel_manager.init(&config, workspace_path.clone()).await
|
||||
channel_manager
|
||||
.init(&config, workspace_path.clone())
|
||||
.await
|
||||
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
||||
|
||||
// Register send_message tool with available channel names
|
||||
@ -106,9 +115,12 @@ impl GatewayState {
|
||||
session_manager.register_outbound_tool(available_channels);
|
||||
|
||||
// Register chat_manager tool
|
||||
session_manager.tools().register(
|
||||
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
|
||||
);
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::ChatManagerTool::new(
|
||||
storage.clone(),
|
||||
valid_channels.clone(),
|
||||
));
|
||||
|
||||
// Initialize MCP servers — connect and register discovered tools
|
||||
if !config.mcp.servers.is_empty() {
|
||||
@ -129,24 +141,27 @@ impl GatewayState {
|
||||
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
||||
if scheduler_config.enabled {
|
||||
// Register cron tools
|
||||
session_manager.tools().register(
|
||||
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
|
||||
);
|
||||
session_manager.tools().register(
|
||||
crate::tools::cron::CronListTool::new(storage.clone()),
|
||||
);
|
||||
session_manager.tools().register(
|
||||
crate::tools::cron::CronRemoveTool::new(storage.clone()),
|
||||
);
|
||||
session_manager.tools().register(
|
||||
crate::tools::cron::CronEnableTool::new(storage.clone()),
|
||||
);
|
||||
session_manager.tools().register(
|
||||
crate::tools::cron::CronDisableTool::new(storage.clone()),
|
||||
);
|
||||
session_manager.tools().register(
|
||||
crate::tools::cron::CronUpdateTool::new(storage.clone()),
|
||||
);
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::cron::CronAddTool::new(
|
||||
storage.clone(),
|
||||
valid_channels,
|
||||
));
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::cron::CronListTool::new(storage.clone()));
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::cron::CronEnableTool::new(storage.clone()));
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::cron::CronDisableTool::new(storage.clone()));
|
||||
session_manager
|
||||
.tools()
|
||||
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
|
||||
tracing::info!("Cron tools registered");
|
||||
}
|
||||
|
||||
@ -267,71 +282,103 @@ impl GatewayState {
|
||||
}
|
||||
|
||||
/// Handle control messages (session management operations)
|
||||
async fn handle_control_message(
|
||||
session_manager: &SessionManager,
|
||||
msg: ControlMessage,
|
||||
) {
|
||||
async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
|
||||
use crate::session::{SessionCommand::*, SessionEvent};
|
||||
|
||||
let reply_tx = msg.reply_tx;
|
||||
let result: Result<SessionEvent, ChannelError> = match msg.op {
|
||||
CreateDialog { channel, chat_id, title } => {
|
||||
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await
|
||||
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
ListDialogs { channel, chat_id, include_archived } => {
|
||||
session_manager.list_dialogs(&channel, &chat_id, include_archived).await
|
||||
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
GetCurrentDialog { channel, chat_id } => {
|
||||
session_manager.get_current_dialog(&channel, &chat_id).await
|
||||
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
SwitchDialog { channel, chat_id, dialog_id } => {
|
||||
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await
|
||||
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
RenameDialog { session_id, title } => {
|
||||
session_manager.rename_dialog(&session_id, &title).await
|
||||
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
ArchiveDialog { session_id } => {
|
||||
session_manager.archive_dialog(&session_id)
|
||||
.map(|()| SessionEvent::DialogArchived { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
DeleteDialog { session_id } => {
|
||||
session_manager.delete_dialog(&session_id).await
|
||||
.map(|()| SessionEvent::DialogDeleted { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
ClearHistory { session_id } => {
|
||||
session_manager.clear_dialog_history(&session_id)
|
||||
.map(|()| SessionEvent::HistoryCleared { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
GetSlashCommands { channel: _, chat_id: _ } => {
|
||||
CreateDialog {
|
||||
channel,
|
||||
chat_id,
|
||||
title,
|
||||
} => session_manager
|
||||
.create_dialog(&channel, &chat_id, title.as_deref())
|
||||
.await
|
||||
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
ListDialogs {
|
||||
channel,
|
||||
chat_id,
|
||||
include_archived,
|
||||
} => session_manager
|
||||
.list_dialogs(&channel, &chat_id, include_archived)
|
||||
.await
|
||||
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
|
||||
dialogs,
|
||||
current_dialog_id,
|
||||
})
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
GetCurrentDialog { channel, chat_id } => session_manager
|
||||
.get_current_dialog(&channel, &chat_id)
|
||||
.await
|
||||
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
SwitchDialog {
|
||||
channel,
|
||||
chat_id,
|
||||
dialog_id,
|
||||
} => session_manager
|
||||
.switch_dialog(&channel, &chat_id, &dialog_id)
|
||||
.await
|
||||
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
RenameDialog { session_id, title } => session_manager
|
||||
.rename_dialog(&session_id, &title)
|
||||
.await
|
||||
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
ArchiveDialog { session_id } => session_manager
|
||||
.archive_dialog(&session_id)
|
||||
.await
|
||||
.map(|()| SessionEvent::DialogArchived { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
DeleteDialog { session_id } => session_manager
|
||||
.delete_dialog(&session_id)
|
||||
.await
|
||||
.map(|()| SessionEvent::DialogDeleted { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
ClearHistory { session_id } => session_manager
|
||||
.clear_dialog_history(&session_id)
|
||||
.await
|
||||
.map(|()| SessionEvent::HistoryCleared { session_id })
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
GetSlashCommands {
|
||||
channel: _,
|
||||
chat_id: _,
|
||||
} => {
|
||||
let commands = session_manager.get_slash_commands().to_vec();
|
||||
Ok(SessionEvent::SlashCommandsList { commands })
|
||||
}
|
||||
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => {
|
||||
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref())
|
||||
.await
|
||||
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg })
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
}
|
||||
ExecuteSlashCommand {
|
||||
command,
|
||||
args,
|
||||
channel,
|
||||
chat_id,
|
||||
current_session_id,
|
||||
} => session_manager
|
||||
.execute_slash_command(
|
||||
&command,
|
||||
args.as_deref(),
|
||||
&channel,
|
||||
&chat_id,
|
||||
current_session_id.as_ref(),
|
||||
)
|
||||
.await
|
||||
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
|
||||
new_session_id: new_id,
|
||||
message: msg,
|
||||
})
|
||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
||||
};
|
||||
|
||||
let _ = reply_tx.send(result).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub async fn run(
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging
|
||||
logging::init_logging();
|
||||
tracing::info!("Starting PicoBot Gateway");
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
use std::sync::Arc;
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||
use super::GatewayState;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::protocol::serialize_outbound;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use crate::protocol::serialize_outbound;
|
||||
use crate::protocol::WsOutbound;
|
||||
use super::GatewayState;
|
||||
|
||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||
ws.on_upgrade(|socket| async move {
|
||||
@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
|
||||
|
||||
// Send session established message
|
||||
let _ = sender.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.clone(),
|
||||
}).await;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::info!(session_id = %session_id, "CLI session established");
|
||||
|
||||
@ -37,9 +39,10 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
if let Ok(text) = serialize_outbound(&msg)
|
||||
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
14
src/lib.rs
14
src/lib.rs
@ -1,17 +1,17 @@
|
||||
pub mod config;
|
||||
pub mod providers;
|
||||
pub mod bus;
|
||||
pub mod agent;
|
||||
pub mod gateway;
|
||||
pub mod session;
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
pub mod bus;
|
||||
pub mod channels;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod gateway;
|
||||
pub mod logging;
|
||||
pub mod mcp;
|
||||
pub mod memory;
|
||||
pub mod observability;
|
||||
pub mod protocol;
|
||||
pub mod providers;
|
||||
pub mod scheduler;
|
||||
pub mod session;
|
||||
pub mod skills;
|
||||
pub mod storage;
|
||||
pub mod tools;
|
||||
|
||||
@ -1,11 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||
use tracing_subscriber::{
|
||||
fmt,
|
||||
layer::SubscriberExt,
|
||||
util::SubscriberInitExt,
|
||||
fmt::time::LocalTime,
|
||||
EnvFilter,
|
||||
EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
|
||||
};
|
||||
|
||||
/// Get the default log directory path: ~/.picobot/logs
|
||||
@ -27,20 +23,20 @@ pub fn init_logging() {
|
||||
|
||||
// Create log directory if it doesn't exist
|
||||
if !log_dir.exists()
|
||||
&& let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
||||
}
|
||||
&& let Err(e) = std::fs::create_dir_all(&log_dir)
|
||||
{
|
||||
eprintln!(
|
||||
"Warning: Failed to create log directory {}: {}",
|
||||
log_dir.display(),
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
// Create file appender with daily rotation
|
||||
let file_appender = RollingFileAppender::new(
|
||||
Rotation::DAILY,
|
||||
&log_dir,
|
||||
"picobot.log",
|
||||
);
|
||||
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
|
||||
|
||||
// Build subscriber with both console and file output
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let file_layer = fmt::layer()
|
||||
.with_writer(file_appender)
|
||||
@ -66,8 +62,7 @@ pub fn init_logging() {
|
||||
|
||||
/// Initialize logging without file output (console only)
|
||||
pub fn init_logging_console_only() {
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let console_layer = fmt::layer()
|
||||
.with_timer(LocalTime::rfc_3339())
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
use clap::{Parser, CommandFactory};
|
||||
use clap::{CommandFactory, Parser};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "picobot")]
|
||||
#[command(about = "A CLI chatbot", long_about = None)]
|
||||
#[command(version = "1.1.0")]
|
||||
enum Command {
|
||||
/// Connect to gateway
|
||||
Chat {
|
||||
|
||||
@ -92,24 +92,19 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
||||
parts.push(text.text.clone());
|
||||
}
|
||||
RawContent::Image(image) => {
|
||||
parts.push(format!(
|
||||
"[image: {}]",
|
||||
image.mime_type,
|
||||
));
|
||||
parts.push(format!("[image: {}]", image.mime_type,));
|
||||
}
|
||||
RawContent::Resource(resource) => {
|
||||
match &resource.resource {
|
||||
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
||||
parts.push(format!(
|
||||
"[resource text: {}]",
|
||||
text.chars().take(200).collect::<String>(),
|
||||
));
|
||||
}
|
||||
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
||||
parts.push(format!("[resource blob: {}]", uri));
|
||||
}
|
||||
RawContent::Resource(resource) => match &resource.resource {
|
||||
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
||||
parts.push(format!(
|
||||
"[resource text: {}]",
|
||||
text.chars().take(200).collect::<String>(),
|
||||
));
|
||||
}
|
||||
}
|
||||
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
||||
parts.push(format!("[resource blob: {}]", uri));
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
parts.push("[unsupported content]".to_string());
|
||||
}
|
||||
@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
||||
cmd.env(k, v);
|
||||
}
|
||||
|
||||
let service = ()
|
||||
.serve(
|
||||
let service =
|
||||
().serve(
|
||||
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
|
||||
)
|
||||
.await
|
||||
@ -261,14 +256,14 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
||||
} else {
|
||||
StreamableHttpClientTransport::from_config(
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string())
|
||||
.custom_headers(headers_map)
|
||||
.custom_headers(headers_map),
|
||||
)
|
||||
};
|
||||
|
||||
let service = ()
|
||||
.serve(transport)
|
||||
.await
|
||||
.context("failed to connect to HTTP/SSE MCP server")?;
|
||||
let service =
|
||||
().serve(transport)
|
||||
.await
|
||||
.context("failed to connect to HTTP/SSE MCP server")?;
|
||||
|
||||
let peer = service.peer().clone();
|
||||
|
||||
|
||||
@ -102,7 +102,11 @@ mod tests {
|
||||
let dir = tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
|
||||
let mm = Arc::new(MemoryManager::new(
|
||||
storage,
|
||||
"default".into(),
|
||||
"default".into(),
|
||||
));
|
||||
(mm, dir)
|
||||
}
|
||||
|
||||
@ -131,15 +135,9 @@ mod tests {
|
||||
async fn test_upsert_overwrites() {
|
||||
let (mm, _dir) = setup_memory_manager().await;
|
||||
|
||||
mm.store(
|
||||
"dup_key",
|
||||
"original",
|
||||
MemoryCategory::Knowledge,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mm.store(
|
||||
"dup_key",
|
||||
"updated",
|
||||
@ -247,7 +245,12 @@ mod tests {
|
||||
|
||||
// Recall scoped to session A — should get only tl_a
|
||||
let scoped = mm
|
||||
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
|
||||
.recall(
|
||||
"summary",
|
||||
10,
|
||||
Some(MemoryCategory::Timeline),
|
||||
Some("chan:chat:dialog_a"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(scoped.len(), 1);
|
||||
|
||||
@ -20,10 +20,7 @@ pub enum ObserverEvent {
|
||||
success: bool,
|
||||
},
|
||||
/// Emitted when the agent starts processing.
|
||||
AgentStart {
|
||||
provider: String,
|
||||
model: String,
|
||||
},
|
||||
AgentStart { provider: String, model: String },
|
||||
/// Emitted when the agent finishes processing.
|
||||
AgentEnd {
|
||||
provider: String,
|
||||
@ -94,7 +91,11 @@ impl ToolExecutionOutcome {
|
||||
}
|
||||
|
||||
/// Create a failed outcome with duration.
|
||||
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
||||
pub fn failure_with_duration(
|
||||
output: String,
|
||||
error_reason: Option<String>,
|
||||
duration: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
output,
|
||||
success: false,
|
||||
|
||||
@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use super::traits::Usage;
|
||||
use std::sync::Arc;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::storage::Storage;
|
||||
use std::sync::Arc;
|
||||
|
||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
blocks.iter().map(|b| match b {
|
||||
ContentBlock::Text { text } => {
|
||||
serde_json::json!({ "type": "text", "text": text })
|
||||
}
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
convert_image_url_to_anthropic(&image_url.url)
|
||||
}
|
||||
}).collect()
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
ContentBlock::Text { text } => {
|
||||
serde_json::json!({ "type": "text", "text": text })
|
||||
}
|
||||
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||
@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider {
|
||||
};
|
||||
let content = if let Some(ref tc_id) = m.tool_call_id {
|
||||
// Tool result: wrap as tool_result content block
|
||||
let output = m.content.iter()
|
||||
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
|
||||
let output = m
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|b| match b {
|
||||
ContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
vec![serde_json::json!({
|
||||
@ -244,19 +250,18 @@ impl LLMProvider for AnthropicProvider {
|
||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||
|
||||
let resp = req_builder.json(&body).send().await
|
||||
.inspect_err(|e| {
|
||||
let is_timeout = e.is_timeout();
|
||||
tracing::error!(
|
||||
provider = %self.name,
|
||||
model = %self.model_id,
|
||||
url = %url,
|
||||
timeout = is_timeout,
|
||||
error = %e,
|
||||
elapsed_ms = %start.elapsed().as_millis(),
|
||||
"LLM API request failed"
|
||||
);
|
||||
})?;
|
||||
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
||||
let is_timeout = e.is_timeout();
|
||||
tracing::error!(
|
||||
provider = %self.name,
|
||||
model = %self.model_id,
|
||||
url = %url,
|
||||
timeout = is_timeout,
|
||||
error = %e,
|
||||
elapsed_ms = %start.elapsed().as_millis(),
|
||||
"LLM API request failed"
|
||||
);
|
||||
})?;
|
||||
|
||||
let status = resp.status();
|
||||
let body_text = resp.text().await?;
|
||||
@ -281,32 +286,38 @@ impl LLMProvider for AnthropicProvider {
|
||||
"LLM API returned error"
|
||||
);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&body_text), Some(&error_msg),
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
let _ = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&body_text),
|
||||
Some(&error_msg),
|
||||
start.elapsed().as_millis() as u64,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||
}
|
||||
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
||||
.map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
let model = self.model_id.clone();
|
||||
let req = req_body_str.clone();
|
||||
let resp_body = body_text.clone();
|
||||
let dur = start.elapsed().as_millis() as u64;
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
})?;
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
let model = self.model_id.clone();
|
||||
let req = req_body_str.clone();
|
||||
let resp_body = body_text.clone();
|
||||
let dur = start.elapsed().as_millis() as u64;
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = s
|
||||
.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
})?;
|
||||
|
||||
let mut content = String::new();
|
||||
let mut reasoning = None;
|
||||
@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider {
|
||||
reasoning_content: reasoning,
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
||||
prompt_tokens: anthropic_resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.input_tokens)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: anthropic_resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.output_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: anthropic_resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.input_tokens + u.output_tokens)
|
||||
.unwrap_or(0),
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&body_text),
|
||||
None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
let _ = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&body_text),
|
||||
None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
pub mod traits;
|
||||
pub mod openai;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod traits;
|
||||
|
||||
pub use self::openai::OpenAIProvider;
|
||||
pub use self::anthropic::AnthropicProvider;
|
||||
pub use self::openai::OpenAIProvider;
|
||||
|
||||
use crate::config::LLMProviderConfig;
|
||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
||||
pub use traits::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
|
||||
ToolFunction, Usage,
|
||||
};
|
||||
|
||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||
match config.provider_type.as_str() {
|
||||
|
||||
@ -1,29 +1,35 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use super::traits::Usage;
|
||||
use std::sync::Arc;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::storage::Storage;
|
||||
use std::sync::Arc;
|
||||
|
||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
if blocks.len() == 1
|
||||
&& let ContentBlock::Text { text } = &blocks[0] {
|
||||
return Value::String(text.clone());
|
||||
}
|
||||
Value::Array(blocks.iter().map(|b| match b {
|
||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||
}
|
||||
}).collect())
|
||||
&& let ContentBlock::Text { text } = &blocks[0]
|
||||
{
|
||||
return Value::String(text.clone());
|
||||
}
|
||||
Value::Array(
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
pub struct OpenAIProvider {
|
||||
@ -201,10 +207,14 @@ impl LLMProvider for OpenAIProvider {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||
for (j, item) in content.iter().enumerate() {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
|
||||
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
||||
let prefix: String = url_str.chars().take(20).collect();
|
||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||
}
|
||||
&& let Some(url_str) = item
|
||||
.get("image_url")
|
||||
.and_then(|u| u.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
let prefix: String = url_str.chars().take(20).collect();
|
||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -224,19 +234,18 @@ impl LLMProvider for OpenAIProvider {
|
||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||
|
||||
let resp = req_builder.json(&body).send().await
|
||||
.inspect_err(|e| {
|
||||
let is_timeout = e.is_timeout();
|
||||
tracing::error!(
|
||||
provider = %self.name,
|
||||
model = %self.model_id,
|
||||
url = %url,
|
||||
timeout = is_timeout,
|
||||
error = %e,
|
||||
elapsed_ms = %start.elapsed().as_millis(),
|
||||
"LLM API request failed"
|
||||
);
|
||||
})?;
|
||||
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
||||
let is_timeout = e.is_timeout();
|
||||
tracing::error!(
|
||||
provider = %self.name,
|
||||
model = %self.model_id,
|
||||
url = %url,
|
||||
timeout = is_timeout,
|
||||
error = %e,
|
||||
elapsed_ms = %start.elapsed().as_millis(),
|
||||
"LLM API request failed"
|
||||
);
|
||||
})?;
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
@ -253,37 +262,48 @@ impl LLMProvider for OpenAIProvider {
|
||||
"LLM API returned error"
|
||||
);
|
||||
if let Some(ref storage) = self.storage
|
||||
&& let Err(e) = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&text), Some(&error),
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await {
|
||||
tracing::warn!("failed to persist LLM call: {}", e);
|
||||
}
|
||||
&& let Err(e) = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&text),
|
||||
Some(&error),
|
||||
start.elapsed().as_millis() as u64,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("failed to persist LLM call: {}", e);
|
||||
}
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
||||
.map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
let model = self.model_id.clone();
|
||||
let req = req_body_str.clone();
|
||||
let resp = text.clone();
|
||||
let dur = start.elapsed().as_millis() as u64;
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
|
||||
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
})?;
|
||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
let model = self.model_id.clone();
|
||||
let req = req_body_str.clone();
|
||||
let resp = text.clone();
|
||||
let dur = start.elapsed().as_millis() as u64;
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = s
|
||||
.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
})?;
|
||||
|
||||
let first_choice = openai_resp.choices.into_iter().next()
|
||||
let first_choice = openai_resp
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or("no choices in response")?;
|
||||
|
||||
let content = first_choice
|
||||
@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider {
|
||||
.map(|tc| ToolCall {
|
||||
id: tc.id.clone(),
|
||||
name: tc.function.name.clone(),
|
||||
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
||||
arguments: serde_json::from_str(&tc.function.arguments)
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@ -318,13 +339,19 @@ impl LLMProvider for OpenAIProvider {
|
||||
};
|
||||
|
||||
if let Some(ref storage) = self.storage
|
||||
&& let Err(e) = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&text), None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await {
|
||||
tracing::warn!("failed to persist LLM call: {}", e);
|
||||
}
|
||||
&& let Err(e) = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&text),
|
||||
None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("failed to persist LLM call: {}", e);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@ -386,6 +413,9 @@ mod tests {
|
||||
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||
assert_eq!(tool_calls[0]["type"], "function");
|
||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||
assert_eq!(
|
||||
tool_calls[0]["function"]["arguments"],
|
||||
"{\"expression\":\"1+1\"}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::bus::message::ContentBlock;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::bus::message::ContentBlock;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
@ -61,7 +61,11 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
pub fn tool(
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
role: "tool".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
|
||||
@ -5,11 +5,11 @@ use std::time::Instant;
|
||||
use tokio::time;
|
||||
|
||||
use crate::config::SchedulerConfig;
|
||||
use crate::session::session::HandleResult;
|
||||
use crate::session::SessionManager;
|
||||
use crate::session::session::HandleResult;
|
||||
use crate::storage::JobRun;
|
||||
use crate::storage::ScheduledJob;
|
||||
use crate::storage::Storage;
|
||||
use crate::storage::JobRun;
|
||||
|
||||
pub use types::Schedule;
|
||||
|
||||
@ -89,7 +89,11 @@ impl Scheduler {
|
||||
|
||||
let now = now_ms();
|
||||
|
||||
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
|
||||
let due = match self
|
||||
.storage
|
||||
.due_scheduled_jobs(now, self.config.max_concurrent)
|
||||
.await
|
||||
{
|
||||
Ok(jobs) => jobs,
|
||||
Err(e) => {
|
||||
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
||||
@ -107,7 +111,11 @@ impl Scheduler {
|
||||
let start = Instant::now();
|
||||
let started_at = now_ms();
|
||||
|
||||
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
|
||||
if let Err(e) = self
|
||||
.storage
|
||||
.touch_scheduled_job_last_run(&job.id, started_at)
|
||||
.await
|
||||
{
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
||||
continue;
|
||||
}
|
||||
@ -135,7 +143,10 @@ impl Scheduler {
|
||||
match result {
|
||||
Ok(HandleResult::AgentResponse(output)) => {
|
||||
let output_truncated = if output.len() > 8000 {
|
||||
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
||||
format!(
|
||||
"{}...[truncated]",
|
||||
&output[..output.ceil_char_boundary(8000)]
|
||||
)
|
||||
} else {
|
||||
output.clone()
|
||||
};
|
||||
@ -155,7 +166,11 @@ impl Scheduler {
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
|
||||
if let Err(e) = self
|
||||
.storage
|
||||
.set_scheduled_job_last_status(&job.id, "ok", None)
|
||||
.await
|
||||
{
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
||||
}
|
||||
|
||||
@ -199,9 +214,11 @@ impl Scheduler {
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
||||
}
|
||||
|
||||
if let Err(e2) = self.storage.set_scheduled_job_last_status(
|
||||
&job.id, "error", Some(&error_str),
|
||||
).await {
|
||||
if let Err(e2) = self
|
||||
.storage
|
||||
.set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
|
||||
.await
|
||||
{
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
||||
}
|
||||
|
||||
@ -231,17 +248,23 @@ impl Scheduler {
|
||||
self.storage.remove_scheduled_job(&job.id).await?;
|
||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
||||
} else {
|
||||
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job.id, false)
|
||||
.await?;
|
||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
||||
}
|
||||
}
|
||||
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
||||
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
||||
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_next_run(&job.id, next)
|
||||
.await?;
|
||||
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
||||
} else {
|
||||
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
||||
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job.id, false)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,32 +22,20 @@ pub enum SessionCommand {
|
||||
dialog_id: String,
|
||||
},
|
||||
/// Get the current dialog for a chat
|
||||
GetCurrentDialog {
|
||||
channel: String,
|
||||
chat_id: String,
|
||||
},
|
||||
GetCurrentDialog { channel: String, chat_id: String },
|
||||
/// Rename a dialog
|
||||
RenameDialog {
|
||||
session_id: UnifiedSessionId,
|
||||
title: String,
|
||||
},
|
||||
/// Archive a dialog
|
||||
ArchiveDialog {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
ArchiveDialog { session_id: UnifiedSessionId },
|
||||
/// Delete a dialog
|
||||
DeleteDialog {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DeleteDialog { session_id: UnifiedSessionId },
|
||||
/// Clear dialog history
|
||||
ClearHistory {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
ClearHistory { session_id: UnifiedSessionId },
|
||||
/// Get list of available slash commands
|
||||
GetSlashCommands {
|
||||
channel: String,
|
||||
chat_id: String,
|
||||
},
|
||||
GetSlashCommands { channel: String, chat_id: String },
|
||||
/// Execute a slash command
|
||||
ExecuteSlashCommand {
|
||||
command: String,
|
||||
@ -60,7 +48,11 @@ pub enum SessionCommand {
|
||||
|
||||
impl SessionCommand {
|
||||
/// Create a CreateDialog command
|
||||
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self {
|
||||
pub fn create_dialog(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
title: Option<String>,
|
||||
) -> Self {
|
||||
Self::CreateDialog {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
@ -69,7 +61,11 @@ impl SessionCommand {
|
||||
}
|
||||
|
||||
/// Create a ListDialogs command
|
||||
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self {
|
||||
pub fn list_dialogs(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
include_archived: bool,
|
||||
) -> Self {
|
||||
Self::ListDialogs {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use super::session_id::UnifiedSessionId;
|
||||
use super::session::SlashCommand;
|
||||
use super::session_id::UnifiedSessionId;
|
||||
|
||||
/// Dialog information returned by SessionManager
|
||||
#[derive(Debug, Clone)]
|
||||
@ -30,30 +30,20 @@ pub enum SessionEvent {
|
||||
session_id: Option<UnifiedSessionId>,
|
||||
},
|
||||
/// Dialog switched successfully
|
||||
DialogSwitched {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DialogSwitched { session_id: UnifiedSessionId },
|
||||
/// Dialog renamed
|
||||
DialogRenamed {
|
||||
session_id: UnifiedSessionId,
|
||||
title: String,
|
||||
},
|
||||
/// Dialog archived
|
||||
DialogArchived {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DialogArchived { session_id: UnifiedSessionId },
|
||||
/// Dialog deleted
|
||||
DialogDeleted {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DialogDeleted { session_id: UnifiedSessionId },
|
||||
/// Dialog history cleared
|
||||
HistoryCleared {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
HistoryCleared { session_id: UnifiedSessionId },
|
||||
/// List of available slash commands
|
||||
SlashCommandsList {
|
||||
commands: Vec<SlashCommand>,
|
||||
},
|
||||
SlashCommandsList { commands: Vec<SlashCommand> },
|
||||
/// Slash command executed successfully
|
||||
SlashCommandExecuted {
|
||||
new_session_id: Option<UnifiedSessionId>,
|
||||
@ -70,8 +60,5 @@ pub enum SessionEvent {
|
||||
message_count: usize,
|
||||
},
|
||||
/// Error occurred
|
||||
Error {
|
||||
code: String,
|
||||
message: String,
|
||||
},
|
||||
Error { code: String, message: String },
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
pub mod error;
|
||||
pub mod commands;
|
||||
pub mod error;
|
||||
pub mod events;
|
||||
pub mod session;
|
||||
pub mod session_id;
|
||||
|
||||
pub use error::SessionError;
|
||||
pub use commands::SessionCommand;
|
||||
pub use events::{SessionEvent, DialogInfo};
|
||||
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS};
|
||||
pub use error::SessionError;
|
||||
pub use events::{DialogInfo, SessionEvent};
|
||||
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
|
||||
pub use session_id::UnifiedSessionId;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,6 @@
|
||||
///
|
||||
/// For simple cases where only one dialog exists per chat:
|
||||
/// - `dialog_id` defaults to `"default"`
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const DEFAULT_DIALOG_ID: &str = "default";
|
||||
@ -22,7 +21,11 @@ pub struct UnifiedSessionId {
|
||||
|
||||
impl UnifiedSessionId {
|
||||
/// Create a new UnifiedSessionId
|
||||
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self {
|
||||
pub fn new(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
dialog_id: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::path::Path;
|
||||
|
||||
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS};
|
||||
use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
|
||||
|
||||
pub fn install_builtin_skills(target_dir: &Path) {
|
||||
for skill in EMBEDDED_SKILLS {
|
||||
@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) {
|
||||
}
|
||||
|
||||
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
|
||||
let decompressed = zstd::decode_all(skill.data)
|
||||
.map_err(|e| format!("zstd decode: {}", e))?;
|
||||
let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
|
||||
|
||||
let mut archive = tar::Archive::new(decompressed.as_slice());
|
||||
archive
|
||||
|
||||
@ -120,7 +120,11 @@ impl SkillsLoader {
|
||||
let count = loaded.len();
|
||||
let mut replaced = 0usize;
|
||||
for skill in loaded {
|
||||
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
||||
if let Some(existing) = state
|
||||
.loaded_skills
|
||||
.iter_mut()
|
||||
.find(|s| s.name == skill.name)
|
||||
{
|
||||
*existing = skill;
|
||||
replaced += 1;
|
||||
} else {
|
||||
@ -138,33 +142,42 @@ impl SkillsLoader {
|
||||
|
||||
// Load from workspace skills dir (highest priority) — replace same-name skills
|
||||
if let Some(ref ws_dir) = self.workspace_skills_dir
|
||||
&& ws_dir.exists() {
|
||||
let loaded = self.load_skills_from_dir(ws_dir);
|
||||
let count = loaded.len();
|
||||
let mut replaced = 0usize;
|
||||
for skill in loaded {
|
||||
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
||||
*existing = skill;
|
||||
replaced += 1;
|
||||
} else {
|
||||
state.loaded_skills.push(skill);
|
||||
}
|
||||
&& ws_dir.exists()
|
||||
{
|
||||
let loaded = self.load_skills_from_dir(ws_dir);
|
||||
let count = loaded.len();
|
||||
let mut replaced = 0usize;
|
||||
for skill in loaded {
|
||||
if let Some(existing) = state
|
||||
.loaded_skills
|
||||
.iter_mut()
|
||||
.find(|s| s.name == skill.name)
|
||||
{
|
||||
*existing = skill;
|
||||
replaced += 1;
|
||||
} else {
|
||||
state.loaded_skills.push(skill);
|
||||
}
|
||||
tracing::debug!(
|
||||
dir = %ws_dir.display(),
|
||||
count = count,
|
||||
replaced = replaced,
|
||||
"Loaded skills from workspace directory"
|
||||
);
|
||||
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
|
||||
}
|
||||
tracing::debug!(
|
||||
dir = %ws_dir.display(),
|
||||
count = count,
|
||||
replaced = replaced,
|
||||
"Loaded skills from workspace directory"
|
||||
);
|
||||
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
|
||||
}
|
||||
|
||||
state.last_load_time = SystemTime::now();
|
||||
|
||||
if state.loaded_skills.is_empty() {
|
||||
tracing::debug!("No skills found in any skills directory");
|
||||
} else {
|
||||
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
|
||||
tracing::info!(
|
||||
count = state.loaded_skills.len(),
|
||||
"Loaded {} skills total",
|
||||
state.loaded_skills.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,18 +228,20 @@ impl SkillsLoader {
|
||||
let mut max_mtime = None;
|
||||
|
||||
if let Ok(metadata) = std::fs::metadata(dir)
|
||||
&& let Ok(mtime) = metadata.modified() {
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
&& let Ok(mtime) = metadata.modified()
|
||||
{
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if let Ok(metadata) = std::fs::metadata(&path)
|
||||
&& let Ok(mtime) = metadata.modified()
|
||||
&& max_mtime.is_none_or(|current| mtime > current) {
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
&& max_mtime.is_none_or(|current| mtime > current)
|
||||
{
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,7 +259,12 @@ impl SkillsLoader {
|
||||
pub fn get_always_skills(&self) -> Vec<Skill> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
|
||||
state
|
||||
.loaded_skills
|
||||
.iter()
|
||||
.filter(|s| s.always)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific skill by name (checks for changes first)
|
||||
@ -258,7 +278,8 @@ impl SkillsLoader {
|
||||
pub fn list_skills(&self) -> Vec<(String, String)> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills
|
||||
state
|
||||
.loaded_skills
|
||||
.iter()
|
||||
.map(|s| (s.name.clone(), s.description.clone()))
|
||||
.collect()
|
||||
@ -279,15 +300,21 @@ impl SkillsLoader {
|
||||
prompt.push_str("### 目录说明\n\n");
|
||||
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n");
|
||||
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
|
||||
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n");
|
||||
prompt.push_str("安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n");
|
||||
prompt.push_str(
|
||||
"- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n",
|
||||
);
|
||||
prompt.push_str(
|
||||
"安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n",
|
||||
);
|
||||
|
||||
// Always skills summary
|
||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
||||
if !always_skills.is_empty() {
|
||||
prompt.push_str("### 常用技能\n\n");
|
||||
for skill in &always_skills {
|
||||
let path_str = skill.path.as_ref()
|
||||
let path_str = skill
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "—".to_string());
|
||||
prompt.push_str(&format!(
|
||||
@ -300,8 +327,12 @@ impl SkillsLoader {
|
||||
|
||||
// Usage instructions
|
||||
prompt.push_str("### 使用方法\n\n");
|
||||
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n");
|
||||
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n");
|
||||
prompt.push_str(
|
||||
"- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
|
||||
);
|
||||
prompt.push_str(
|
||||
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
|
||||
);
|
||||
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
|
||||
|
||||
// Always skills full content
|
||||
@ -338,25 +369,23 @@ impl SkillsLoader {
|
||||
}
|
||||
|
||||
match std::fs::read_to_string(&skill_file) {
|
||||
Ok(content) => {
|
||||
match self.parse_skill(&path, &content) {
|
||||
Some(skill) => {
|
||||
tracing::debug!(
|
||||
skill = %skill.name,
|
||||
path = %skill_file.display(),
|
||||
always = skill.always,
|
||||
"Loaded skill"
|
||||
);
|
||||
skills.push(skill);
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(
|
||||
path = %skill_file.display(),
|
||||
"Failed to parse skill"
|
||||
);
|
||||
}
|
||||
Ok(content) => match self.parse_skill(&path, &content) {
|
||||
Some(skill) => {
|
||||
tracing::debug!(
|
||||
skill = %skill.name,
|
||||
path = %skill_file.display(),
|
||||
always = skill.always,
|
||||
"Loaded skill"
|
||||
);
|
||||
skills.push(skill);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(
|
||||
path = %skill_file.display(),
|
||||
"Failed to parse skill"
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %skill_file.display(),
|
||||
@ -447,7 +476,6 @@ impl Default for SkillsLoader {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Extract first non-empty, non-heading line as description
|
||||
fn extract_description(content: &str) -> String {
|
||||
content
|
||||
|
||||
19
src/storage/background_task.rs
Normal file
19
src/storage/background_task.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BackgroundTask {
|
||||
pub id: String,
|
||||
pub session_id: String,
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub prompt: String,
|
||||
pub allowed_tools: Option<String>,
|
||||
pub status: String,
|
||||
pub result: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub tool_calls_count: i64,
|
||||
pub iterations: i64,
|
||||
pub started_at: Option<i64>,
|
||||
pub finished_at: Option<i64>,
|
||||
pub created_at: i64,
|
||||
}
|
||||
@ -241,12 +241,11 @@ impl super::Storage {
|
||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
||||
let cutoff_str = cutoff.to_rfc3339();
|
||||
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
|
||||
)
|
||||
.bind(&cutoff_str)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
let result =
|
||||
sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
|
||||
.bind(&cutoff_str)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
@ -276,9 +275,7 @@ impl super::Storage {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_memory_rows(
|
||||
rows: &[sqlx::sqlite::SqliteRow],
|
||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
rows.iter()
|
||||
.map(|row| {
|
||||
Ok(MemoryEntry {
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
pub mod background_task;
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
pub mod message;
|
||||
pub mod scheduler;
|
||||
pub mod session;
|
||||
|
||||
pub use background_task::BackgroundTask;
|
||||
pub use error::StorageError;
|
||||
pub use scheduler::{JobRun, ScheduledJob};
|
||||
|
||||
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use std::path::Path;
|
||||
use tokio::time::{Duration, sleep};
|
||||
|
||||
pub struct Storage {
|
||||
pub(crate) pool: Pool<Sqlite>,
|
||||
@ -40,6 +42,7 @@ impl Storage {
|
||||
last_active_at INTEGER NOT NULL,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
routing_info TEXT,
|
||||
archived_at INTEGER,
|
||||
deleted_at INTEGER,
|
||||
last_consolidated_at INTEGER,
|
||||
last_compressed_message_at INTEGER,
|
||||
@ -90,20 +93,58 @@ impl Storage {
|
||||
.await?;
|
||||
|
||||
// Migration: add source column if upgrading from older schema
|
||||
sqlx::query(
|
||||
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Migration: add reasoning_content column if upgrading from older schema
|
||||
sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Background tasks table — for async sub-agent tasks.
|
||||
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
|
||||
// Session and task association is maintained at the application level.
|
||||
sqlx::query(
|
||||
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS background_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
channel TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
prompt TEXT NOT NULL,
|
||||
allowed_tools TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
result TEXT,
|
||||
error TEXT,
|
||||
tool_calls_count INTEGER DEFAULT 0,
|
||||
iterations INTEGER DEFAULT 0,
|
||||
started_at INTEGER,
|
||||
finished_at INTEGER,
|
||||
created_at INTEGER NOT NULL
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
@ -172,11 +213,19 @@ impl Storage {
|
||||
.await?;
|
||||
|
||||
// Rebuild FTS5 index for any existing records
|
||||
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
// Migration: add last_consolidated_at column if not exists
|
||||
sqlx::query(
|
||||
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
|
||||
r#"
|
||||
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Migration: add last_consolidated_at column if not exists
|
||||
sqlx::query(
|
||||
@ -216,7 +265,10 @@ impl Storage {
|
||||
.await?;
|
||||
|
||||
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
|
||||
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
|
||||
tracing::warn!(
|
||||
"Failed to init scheduler schema (tables may already exist): {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -330,16 +382,20 @@ impl Storage {
|
||||
&self.pool
|
||||
}
|
||||
|
||||
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
|
||||
pub async fn upsert_session(
|
||||
&self,
|
||||
meta: &crate::storage::session::SessionMeta,
|
||||
) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
title = excluded.title,
|
||||
last_active_at = excluded.last_active_at,
|
||||
message_count = excluded.message_count,
|
||||
routing_info = excluded.routing_info,
|
||||
archived_at = excluded.archived_at,
|
||||
deleted_at = excluded.deleted_at,
|
||||
last_consolidated_at = excluded.last_consolidated_at,
|
||||
last_compressed_message_at = excluded.last_compressed_message_at
|
||||
@ -354,6 +410,7 @@ impl Storage {
|
||||
.bind(meta.last_active_at)
|
||||
.bind(meta.message_count)
|
||||
.bind(&meta.routing_info)
|
||||
.bind(meta.archived_at)
|
||||
.bind(meta.deleted_at)
|
||||
.bind(meta.last_consolidated_at)
|
||||
.bind(meta.last_compressed_message_at)
|
||||
@ -363,10 +420,13 @@ impl Storage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
||||
pub async fn get_session(
|
||||
&self,
|
||||
id: &str,
|
||||
) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||
"#,
|
||||
)
|
||||
@ -385,6 +445,7 @@ impl Storage {
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
archived_at: row.get("archived_at"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
@ -396,18 +457,21 @@ impl Storage {
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
limit: i64,
|
||||
include_archived: bool,
|
||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||
AND (? OR archived_at IS NULL)
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
)
|
||||
.bind(channel)
|
||||
.bind(chat_id)
|
||||
.bind(include_archived)
|
||||
.bind(limit)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
@ -424,6 +488,7 @@ impl Storage {
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
archived_at: row.get("archived_at"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
@ -454,13 +519,22 @@ impl Storage {
|
||||
|
||||
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
sqlx::query(
|
||||
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
|
||||
)
|
||||
.bind(now)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
|
||||
.bind(now)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
|
||||
.bind(now)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -472,9 +546,9 @@ impl Storage {
|
||||
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
@ -495,6 +569,7 @@ impl Storage {
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
archived_at: row.get("archived_at"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
@ -503,7 +578,11 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
|
||||
pub async fn append_message(
|
||||
&self,
|
||||
session_id: &str,
|
||||
msg: &crate::storage::message::MessageMeta,
|
||||
) -> Result<i64, StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
||||
@ -630,16 +709,15 @@ impl Storage {
|
||||
offset: i64,
|
||||
limit: i64,
|
||||
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
|
||||
let count_row = sqlx::query(
|
||||
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL",
|
||||
)
|
||||
.fetch_one(self.pool())
|
||||
.await?;
|
||||
let count_row =
|
||||
sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
|
||||
.fetch_one(self.pool())
|
||||
.await?;
|
||||
let total: i64 = count_row.get("total");
|
||||
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||
FROM sessions
|
||||
WHERE deleted_at IS NULL
|
||||
ORDER BY last_active_at DESC
|
||||
@ -663,6 +741,7 @@ impl Storage {
|
||||
last_active_at: row.get("last_active_at"),
|
||||
message_count: row.get("message_count"),
|
||||
routing_info: row.get("routing_info"),
|
||||
archived_at: row.get("archived_at"),
|
||||
deleted_at: row.get("deleted_at"),
|
||||
last_consolidated_at: row.get("last_consolidated_at"),
|
||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||
@ -728,7 +807,10 @@ impl Storage {
|
||||
where_extra.push_str(" AND created_at > ?");
|
||||
}
|
||||
|
||||
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
|
||||
let count_sql = format!(
|
||||
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
|
||||
where_extra
|
||||
);
|
||||
let select_sql = format!(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
@ -816,6 +898,148 @@ impl Storage {
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
// ── Background Task CRUD ──
|
||||
|
||||
pub async fn create_background_task(
|
||||
&self,
|
||||
task: &crate::storage::background_task::BackgroundTask,
|
||||
) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&task.id)
|
||||
.bind(&task.session_id)
|
||||
.bind(&task.channel)
|
||||
.bind(&task.chat_id)
|
||||
.bind(&task.prompt)
|
||||
.bind(&task.allowed_tools)
|
||||
.bind(&task.status)
|
||||
.bind(&task.result)
|
||||
.bind(&task.error)
|
||||
.bind(task.tool_calls_count)
|
||||
.bind(task.iterations)
|
||||
.bind(task.started_at)
|
||||
.bind(task.finished_at)
|
||||
.bind(task.created_at)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_background_task_status(
|
||||
&self,
|
||||
id: &str,
|
||||
status: &str,
|
||||
result: Option<&str>,
|
||||
error: Option<&str>,
|
||||
started_at: Option<i64>,
|
||||
finished_at: Option<i64>,
|
||||
) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE background_tasks
|
||||
SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error),
|
||||
started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at)
|
||||
WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(status)
|
||||
.bind(result)
|
||||
.bind(error)
|
||||
.bind(started_at)
|
||||
.bind(finished_at)
|
||||
.bind(id)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_background_task(
|
||||
&self,
|
||||
id: &str,
|
||||
) -> Result<crate::storage::background_task::BackgroundTask, StorageError> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
|
||||
tool_calls_count, iterations, started_at, finished_at, created_at
|
||||
FROM background_tasks WHERE id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(self.pool())
|
||||
.await?
|
||||
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
|
||||
|
||||
Ok(crate::storage::background_task::BackgroundTask {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
prompt: row.get("prompt"),
|
||||
allowed_tools: row.get("allowed_tools"),
|
||||
status: row.get("status"),
|
||||
result: row.get("result"),
|
||||
error: row.get("error"),
|
||||
tool_calls_count: row.get("tool_calls_count"),
|
||||
iterations: row.get("iterations"),
|
||||
started_at: row.get("started_at"),
|
||||
finished_at: row.get("finished_at"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_background_tasks(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Result<Vec<crate::storage::background_task::BackgroundTask>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
|
||||
tool_calls_count, iterations, started_at, finished_at, created_at
|
||||
FROM background_tasks
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at DESC
|
||||
"#,
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_all(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|row| crate::storage::background_task::BackgroundTask {
|
||||
id: row.get("id"),
|
||||
session_id: row.get("session_id"),
|
||||
channel: row.get("channel"),
|
||||
chat_id: row.get("chat_id"),
|
||||
prompt: row.get("prompt"),
|
||||
allowed_tools: row.get("allowed_tools"),
|
||||
status: row.get("status"),
|
||||
result: row.get("result"),
|
||||
error: row.get("error"),
|
||||
tool_calls_count: row.get("tool_calls_count"),
|
||||
iterations: row.get("iterations"),
|
||||
started_at: row.get("started_at"),
|
||||
finished_at: row.get("finished_at"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result<usize, StorageError> {
|
||||
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms;
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?",
|
||||
)
|
||||
.bind(cutoff)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(result.rows_affected() as usize)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -844,6 +1068,7 @@ mod tests {
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
@ -880,14 +1105,18 @@ mod tests {
|
||||
last_active_at: i as i64 * 1000,
|
||||
message_count: i,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
}
|
||||
|
||||
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
|
||||
let sessions = storage
|
||||
.list_sessions("cli_chat", "sid123", 10, false)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(sessions.len(), 5);
|
||||
// 按 last_active_at DESC 排序
|
||||
assert_eq!(sessions[0].dialog_id, "dialog4");
|
||||
@ -907,6 +1136,7 @@ mod tests {
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
@ -934,6 +1164,7 @@ mod tests {
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
@ -955,7 +1186,10 @@ mod tests {
|
||||
created_at: 1000,
|
||||
};
|
||||
|
||||
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
|
||||
let seq = storage
|
||||
.append_message(&session_meta.id, &msg)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(seq, 1);
|
||||
|
||||
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
||||
@ -977,6 +1211,7 @@ mod tests {
|
||||
last_active_at: 1000,
|
||||
message_count: 0,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
|
||||
@ -165,7 +165,11 @@ impl crate::storage::Storage {
|
||||
}
|
||||
|
||||
/// Update next_run_at and last_run_at for a job.
|
||||
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
|
||||
pub async fn set_scheduled_job_next_run(
|
||||
&self,
|
||||
id: &str,
|
||||
next_run_at: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
let now = now_ms();
|
||||
sqlx::query(
|
||||
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
||||
@ -331,7 +335,9 @@ mod tests {
|
||||
async fn setup_storage() -> Storage {
|
||||
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||
let storage = Storage { pool };
|
||||
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
|
||||
Storage::init_scheduler_schema(storage.pool())
|
||||
.await
|
||||
.unwrap();
|
||||
storage
|
||||
}
|
||||
|
||||
@ -450,7 +456,10 @@ mod tests {
|
||||
updated_at: t,
|
||||
};
|
||||
storage.add_scheduled_job(&job).await.unwrap();
|
||||
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
|
||||
storage
|
||||
.set_scheduled_job_enabled("job-toggle", false)
|
||||
.await
|
||||
.unwrap();
|
||||
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
||||
assert!(!got.enabled);
|
||||
}
|
||||
@ -461,31 +470,55 @@ mod tests {
|
||||
let t = now();
|
||||
let jobs = vec![
|
||||
ScheduledJob {
|
||||
id: "due".into(), name: "due".into(),
|
||||
schedule: Schedule::At { at: t }, prompt: "1".into(),
|
||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: true, delete_after_run: false,
|
||||
next_run_at: t - 1000, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
id: "due".into(),
|
||||
name: "due".into(),
|
||||
schedule: Schedule::At { at: t },
|
||||
prompt: "1".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t - 1000,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
},
|
||||
ScheduledJob {
|
||||
id: "future".into(), name: "future".into(),
|
||||
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
|
||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: true, delete_after_run: false,
|
||||
next_run_at: t + 99999999, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
id: "future".into(),
|
||||
name: "future".into(),
|
||||
schedule: Schedule::At { at: t + 99999999 },
|
||||
prompt: "2".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t + 99999999,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
},
|
||||
ScheduledJob {
|
||||
id: "disabled-due".into(), name: "disabled due".into(),
|
||||
schedule: Schedule::At { at: t }, prompt: "3".into(),
|
||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: false, delete_after_run: false,
|
||||
next_run_at: t - 1000, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
id: "disabled-due".into(),
|
||||
name: "disabled due".into(),
|
||||
schedule: Schedule::At { at: t },
|
||||
prompt: "3".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: false,
|
||||
delete_after_run: false,
|
||||
next_run_at: t - 1000,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
},
|
||||
];
|
||||
for j in &jobs {
|
||||
@ -501,24 +534,39 @@ mod tests {
|
||||
let storage = setup_storage().await;
|
||||
let t = now();
|
||||
let job = ScheduledJob {
|
||||
id: "job-run".into(), name: "run test".into(),
|
||||
id: "job-run".into(),
|
||||
name: "run test".into(),
|
||||
schedule: Schedule::Every { every_ms: 1000 },
|
||||
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: true, delete_after_run: false,
|
||||
next_run_at: t, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
prompt: "hi".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
};
|
||||
storage.add_scheduled_job(&job).await.unwrap();
|
||||
|
||||
let run = super::JobRun {
|
||||
id: 0, job_id: "job-run".into(),
|
||||
started_at: t, finished_at: t + 500,
|
||||
status: "ok".into(), output: Some("hello".into()),
|
||||
error: None, duration_ms: 500,
|
||||
id: 0,
|
||||
job_id: "job-run".into(),
|
||||
started_at: t,
|
||||
finished_at: t + 500,
|
||||
status: "ok".into(),
|
||||
output: Some("hello".into()),
|
||||
error: None,
|
||||
duration_ms: 500,
|
||||
};
|
||||
storage.record_scheduled_job_run(&run).await.unwrap();
|
||||
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
|
||||
let runs = storage
|
||||
.list_scheduled_job_runs("job-run", 10)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(runs.len(), 1);
|
||||
assert_eq!(runs[0].status, "ok");
|
||||
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
||||
@ -529,22 +577,34 @@ mod tests {
|
||||
let storage = setup_storage().await;
|
||||
let t = now();
|
||||
let job = ScheduledJob {
|
||||
id: "job-update".into(), name: "old name".into(),
|
||||
id: "job-update".into(),
|
||||
name: "old name".into(),
|
||||
schedule: Schedule::Every { every_ms: 1000 },
|
||||
prompt: "old prompt".into(), channel: "feishu".into(),
|
||||
chat_id: "oc_1".into(), model: None,
|
||||
enabled: true, delete_after_run: false,
|
||||
next_run_at: t, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
prompt: "old prompt".into(),
|
||||
channel: "feishu".into(),
|
||||
chat_id: "oc_1".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
};
|
||||
storage.add_scheduled_job(&job).await.unwrap();
|
||||
storage.update_scheduled_job(
|
||||
"job-update",
|
||||
Some("new prompt".into()),
|
||||
Some(Schedule::Every { every_ms: 60000 }),
|
||||
None, None, None,
|
||||
).await.unwrap();
|
||||
storage
|
||||
.update_scheduled_job(
|
||||
"job-update",
|
||||
Some("new prompt".into()),
|
||||
Some(Schedule::Every { every_ms: 60000 }),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
||||
assert_eq!(got.prompt, "new prompt");
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ pub struct SessionMeta {
|
||||
pub last_active_at: i64,
|
||||
pub message_count: i64,
|
||||
pub routing_info: Option<String>,
|
||||
pub archived_at: Option<i64>,
|
||||
pub deleted_at: Option<i64>,
|
||||
pub last_consolidated_at: Option<i64>,
|
||||
pub last_compressed_message_at: Option<i64>,
|
||||
|
||||
@ -167,10 +167,7 @@ impl Tool for BashTool {
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Command timed out after {} seconds",
|
||||
timeout_secs
|
||||
)),
|
||||
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@ -249,10 +246,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_pwd_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "pwd" }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
@ -260,7 +254,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_ls_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "ls -la /tmp" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine;
|
||||
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT};
|
||||
use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
|
||||
use fantoccini::key::Key;
|
||||
use fantoccini::{Client, ClientBuilder, Locator};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -63,7 +63,9 @@ impl BrowserTool {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum BrowserAction {
|
||||
Open { url: String },
|
||||
Open {
|
||||
url: String,
|
||||
},
|
||||
Snapshot {
|
||||
#[serde(default)]
|
||||
interactive_only: bool,
|
||||
@ -72,10 +74,20 @@ pub enum BrowserAction {
|
||||
#[serde(default)]
|
||||
depth: Option<i64>,
|
||||
},
|
||||
Click { selector: String },
|
||||
Fill { selector: String, value: String },
|
||||
Type { selector: Option<String>, text: String },
|
||||
GetText { selector: String },
|
||||
Click {
|
||||
selector: String,
|
||||
},
|
||||
Fill {
|
||||
selector: String,
|
||||
value: String,
|
||||
},
|
||||
Type {
|
||||
selector: Option<String>,
|
||||
text: String,
|
||||
},
|
||||
GetText {
|
||||
selector: String,
|
||||
},
|
||||
GetTitle,
|
||||
GetUrl,
|
||||
Screenshot {
|
||||
@ -84,7 +96,9 @@ pub enum BrowserAction {
|
||||
#[serde(default)]
|
||||
return_base64: bool,
|
||||
},
|
||||
Focus { selector: String },
|
||||
Focus {
|
||||
selector: String,
|
||||
},
|
||||
Wait {
|
||||
#[serde(default)]
|
||||
selector: Option<String>,
|
||||
@ -93,9 +107,16 @@ pub enum BrowserAction {
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
},
|
||||
Press { key: String },
|
||||
Hover { selector: String },
|
||||
ClickAt { x: u32, y: u32 },
|
||||
Press {
|
||||
key: String,
|
||||
},
|
||||
Hover {
|
||||
selector: String,
|
||||
},
|
||||
ClickAt {
|
||||
x: u32,
|
||||
y: u32,
|
||||
},
|
||||
Scroll {
|
||||
direction: String,
|
||||
#[serde(default)]
|
||||
@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
||||
.get("interactive_only")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
compact: args
|
||||
.get("compact")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
depth: args
|
||||
.get("depth")
|
||||
.and_then(|v| v.as_i64()),
|
||||
compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
|
||||
depth: args.get("depth").and_then(|v| v.as_i64()),
|
||||
}),
|
||||
"click" => {
|
||||
let selector = args
|
||||
@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
ms: args.get("ms").and_then(|v| v.as_u64()),
|
||||
text: args
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||
}),
|
||||
"press" => {
|
||||
let key = args
|
||||
@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
||||
let x = args
|
||||
.get("x")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32;
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
|
||||
as u32;
|
||||
let y = args
|
||||
.get("y")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32;
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
|
||||
as u32;
|
||||
Ok(BrowserAction::ClickAt { x, y })
|
||||
}
|
||||
other => anyhow::bail!("Unsupported browser action: {}", other),
|
||||
@ -488,7 +503,11 @@ impl BrowserState {
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "fill",
|
||||
output_len = value.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Filled {} with {}", selector, value),
|
||||
@ -573,7 +592,10 @@ impl BrowserState {
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
BrowserAction::Screenshot { path, return_base64 } => {
|
||||
BrowserAction::Screenshot {
|
||||
path,
|
||||
return_base64,
|
||||
} => {
|
||||
let client = self.active_client()?;
|
||||
let png = client.screenshot().await?;
|
||||
let save_path = path.unwrap_or_else(|| {
|
||||
@ -588,14 +610,25 @@ impl BrowserState {
|
||||
tokio::fs::write(&save_path, &png).await?;
|
||||
if return_base64 {
|
||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
|
||||
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "screenshot",
|
||||
output_len = b64.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64),
|
||||
output: format!(
|
||||
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
|
||||
save_path, b64
|
||||
),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "screenshot",
|
||||
output_len = save_path.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Screenshot saved to {}", save_path),
|
||||
@ -611,18 +644,18 @@ impl BrowserState {
|
||||
vec![serde_json::to_value(el)?],
|
||||
)
|
||||
.await?;
|
||||
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "focus",
|
||||
output_len = selector.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Focused {}", selector),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
BrowserAction::Wait {
|
||||
selector,
|
||||
ms,
|
||||
text,
|
||||
} => {
|
||||
BrowserAction::Wait { selector, ms, text } => {
|
||||
if let Some(sel) = selector {
|
||||
let client = self.active_client()?;
|
||||
wait_for_selector(client, &sel).await?;
|
||||
@ -719,9 +752,21 @@ impl BrowserState {
|
||||
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") };
|
||||
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") };
|
||||
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") };
|
||||
let id_str = if id.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("#{id}")
|
||||
};
|
||||
let type_str = if el_type.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("[type={el_type}]")
|
||||
};
|
||||
let text_str = if text.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" ({text})")
|
||||
};
|
||||
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
|
||||
}
|
||||
None => format!("Clicked at ({}, {})", x, y),
|
||||
@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String {
|
||||
}
|
||||
|
||||
fn xpath_contains_text(text: &str) -> String {
|
||||
format!(
|
||||
"//*[contains(normalize-space(.), {})]",
|
||||
xpath_literal(text)
|
||||
)
|
||||
format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
|
||||
}
|
||||
|
||||
fn xpath_literal(input: &str) -> String {
|
||||
@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String {
|
||||
"pagedown" => Key::PageDown.to_string(),
|
||||
"space" => " ".to_string(),
|
||||
other => {
|
||||
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other);
|
||||
tracing::warn!(
|
||||
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
|
||||
other
|
||||
);
|
||||
other.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@ -659,10 +659,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_missing_expression() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate"}))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
|
||||
@ -126,7 +126,10 @@ impl ChatManagerTool {
|
||||
let start_num = offset + 1;
|
||||
let end_num = offset + sessions.len() as i64;
|
||||
|
||||
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num);
|
||||
let mut output = format!(
|
||||
"全部会话 (共 {} 个,第 {}-{} 个):\n",
|
||||
total, start_num, end_num
|
||||
);
|
||||
|
||||
for s in &sessions {
|
||||
let ago = format_duration_ago(now_ms - s.last_active_at);
|
||||
@ -300,9 +303,10 @@ mod tests {
|
||||
last_active_at: now - i * 3600_000,
|
||||
message_count: i * 5,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
};
|
||||
storage.upsert_session(&meta).await.unwrap();
|
||||
}
|
||||
@ -335,6 +339,7 @@ mod tests {
|
||||
last_active_at: now,
|
||||
message_count: 3,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
@ -346,7 +351,11 @@ mod tests {
|
||||
id: format!("msg{}", i),
|
||||
session_id: session_id.to_string(),
|
||||
seq: i as i64 + 1,
|
||||
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||
role: if i == 0 {
|
||||
"user".to_string()
|
||||
} else {
|
||||
"assistant".to_string()
|
||||
},
|
||||
content: format!("消息内容 {}", i),
|
||||
reasoning_content: None,
|
||||
media_refs: None,
|
||||
@ -392,6 +401,7 @@ mod tests {
|
||||
last_active_at: now,
|
||||
message_count: 5,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
@ -403,7 +413,11 @@ mod tests {
|
||||
id: format!("msg{}", i),
|
||||
session_id: session_id.to_string(),
|
||||
seq: i as i64 + 1,
|
||||
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||
role: if i % 2 == 0 {
|
||||
"user".to_string()
|
||||
} else {
|
||||
"assistant".to_string()
|
||||
},
|
||||
content: format!("消息内容 {}", i),
|
||||
reasoning_content: None,
|
||||
media_refs: None,
|
||||
@ -447,6 +461,7 @@ mod tests {
|
||||
last_active_at: now,
|
||||
message_count: 5,
|
||||
routing_info: None,
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
last_consolidated_at: None,
|
||||
last_compressed_message_at: None,
|
||||
@ -492,10 +507,7 @@ mod tests {
|
||||
let (storage, _dir) = create_test_storage().await;
|
||||
let tool = ChatManagerTool::new(storage, vec![]);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({ "action": "unknown" }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Unknown action"));
|
||||
}
|
||||
|
||||
@ -31,10 +31,7 @@ impl ContentSearchTool {
|
||||
for (i, line) in lines.iter().enumerate() {
|
||||
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
||||
let omitted = lines.len() - i;
|
||||
output.push_str(&format!(
|
||||
"\n... ({} matches omitted) ...",
|
||||
omitted
|
||||
));
|
||||
output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
|
||||
break;
|
||||
}
|
||||
if !output.is_empty() {
|
||||
@ -113,18 +110,40 @@ impl Tool for ContentSearchTool {
|
||||
|
||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
|
||||
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
|
||||
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
let case_sensitive = args
|
||||
.get("case_sensitive")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let context_lines = args
|
||||
.get("context_lines")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as usize;
|
||||
let max_results = args
|
||||
.get("max_results")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
|
||||
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
|
||||
let result = self
|
||||
.run_search(
|
||||
pattern,
|
||||
&dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(lines) => {
|
||||
let count = lines.len();
|
||||
let mut output = self.truncate_output(&lines);
|
||||
output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
|
||||
Ok(ToolResult { success: true, output, error: None })
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
@ -146,22 +165,52 @@ impl ContentSearchTool {
|
||||
max_results: usize,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
if which::which("rg").is_ok() {
|
||||
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
||||
match self
|
||||
.search_with_rg(
|
||||
pattern,
|
||||
dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(lines) => return Ok(lines),
|
||||
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
|
||||
}
|
||||
}
|
||||
|
||||
if which::which("grep").is_ok() {
|
||||
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
||||
match self
|
||||
.search_with_grep(
|
||||
pattern,
|
||||
dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
|
||||
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
|
||||
tracing::warn!(
|
||||
"No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
|
||||
);
|
||||
self.search_with_rust(
|
||||
pattern,
|
||||
dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn search_with_rg(
|
||||
@ -176,8 +225,10 @@ impl ContentSearchTool {
|
||||
let mut cmd = Command::new("rg");
|
||||
cmd.arg("-n")
|
||||
.arg("--no-heading")
|
||||
.arg("--color").arg("never")
|
||||
.arg("--max-count").arg(max_results.to_string())
|
||||
.arg("--color")
|
||||
.arg("never")
|
||||
.arg("--max-count")
|
||||
.arg(max_results.to_string())
|
||||
.arg(pattern)
|
||||
.arg(dir)
|
||||
.stdout(Stdio::piped())
|
||||
@ -193,12 +244,9 @@ impl ContentSearchTool {
|
||||
cmd.arg("--glob").arg(fp);
|
||||
}
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
if !output.status.success() && output.status.code() != Some(1) {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
@ -206,7 +254,8 @@ impl ContentSearchTool {
|
||||
}
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = text.lines()
|
||||
let lines: Vec<String> = text
|
||||
.lines()
|
||||
.take(max_results)
|
||||
.map(|l| l.to_string())
|
||||
.collect();
|
||||
@ -242,15 +291,13 @@ impl ContentSearchTool {
|
||||
cmd.arg("--include").arg(fp);
|
||||
}
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = text.lines()
|
||||
let lines: Vec<String> = text
|
||||
.lines()
|
||||
.take(max_results)
|
||||
.map(|l| l.to_string())
|
||||
.collect();
|
||||
@ -280,7 +327,9 @@ impl ContentSearchTool {
|
||||
if case_sensitive {
|
||||
regex::Regex::new(&re_str)
|
||||
} else {
|
||||
regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
|
||||
regex::RegexBuilder::new(&re_str)
|
||||
.case_insensitive(true)
|
||||
.build()
|
||||
}
|
||||
});
|
||||
|
||||
@ -291,7 +340,14 @@ impl ContentSearchTool {
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
|
||||
grep_dir(
|
||||
Path::new(dir),
|
||||
Path::new(dir),
|
||||
&re,
|
||||
file_re.as_ref(),
|
||||
&mut results,
|
||||
max_results,
|
||||
)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
@ -350,16 +406,19 @@ fn grep_dir(
|
||||
|
||||
if path.is_dir() {
|
||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& name.starts_with('.') && name.len() > 1 {
|
||||
continue;
|
||||
}
|
||||
&& name.starts_with('.')
|
||||
&& name.len() > 1
|
||||
{
|
||||
continue;
|
||||
}
|
||||
grep_dir(base, &path, re, file_re, results, max)?;
|
||||
} else if path.is_file() {
|
||||
if let Some(file_re) = file_re
|
||||
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& !file_re.is_match(name) {
|
||||
continue;
|
||||
}
|
||||
&& !file_re.is_match(name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
for (line_num, line) in content.lines().enumerate() {
|
||||
@ -391,8 +450,16 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_content_search_rust_fallback() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
|
||||
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
|
||||
fs::write(
|
||||
dir.path().join("main.rs"),
|
||||
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
|
||||
)
|
||||
.unwrap();
|
||||
fs::write(
|
||||
dir.path().join("lib.rs"),
|
||||
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
|
||||
)
|
||||
.unwrap();
|
||||
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
|
||||
|
||||
let tool = ContentSearchTool::new();
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::scheduler::{next_run_for_schedule, Schedule};
|
||||
use crate::scheduler::{Schedule, next_run_for_schedule};
|
||||
use crate::storage::{ScheduledJob, Storage};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
@ -229,10 +229,7 @@ impl Tool for CronListTool {
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||
let filter = args
|
||||
.get("status")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("all");
|
||||
let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
|
||||
let jobs = self.storage.list_scheduled_jobs().await?;
|
||||
|
||||
let filtered: Vec<&ScheduledJob> = match filter {
|
||||
@ -397,7 +394,9 @@ impl Tool for CronEnableTool {
|
||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||
|
||||
let next = next_run_for_schedule(&job.schedule, now_ms());
|
||||
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job_id, true)
|
||||
.await?;
|
||||
if let Some(n) = next {
|
||||
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
||||
}
|
||||
@ -464,7 +463,9 @@ impl Tool for CronDisableTool {
|
||||
.get_scheduled_job(&job_id)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job_id, false)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
@ -580,7 +581,9 @@ impl Tool for CronUpdateTool {
|
||||
if args.get("schedule").is_some() {
|
||||
let job = self.storage.get_scheduled_job(&job_id).await?;
|
||||
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
||||
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_next_run(&job_id, next)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
@ -765,9 +768,7 @@ mod tests {
|
||||
let job = ScheduledJob {
|
||||
id: "job-update-tool".into(),
|
||||
name: "old".into(),
|
||||
schedule: Schedule::Every {
|
||||
every_ms: 3600000,
|
||||
},
|
||||
schedule: Schedule::Every { every_ms: 3600000 },
|
||||
prompt: "old prompt".into(),
|
||||
channel: "feishu".into(),
|
||||
chat_id: "oc_1".into(),
|
||||
|
||||
390
src/tools/delegate.rs
Normal file
390
src/tools/delegate.rs
Normal file
@ -0,0 +1,390 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct DelegateTool {
|
||||
sub_agent_manager: Arc<SubAgentManager>,
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
pub fn new(sub_agent_manager: Arc<SubAgentManager>) -> Self {
|
||||
Self { sub_agent_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DelegateTool {
|
||||
fn name(&self) -> &str {
|
||||
"delegate"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\
|
||||
inline (阻塞返回结果)、background (异步执行,完成后通知)、\
|
||||
parallel (多个子 Agent 并发执行,聚合结果)。\
|
||||
也可用于查询、取消和列出后台任务。"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["run", "check_task", "cancel_task", "list_tasks"],
|
||||
"description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件)。action=run 时必填"
|
||||
},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["inline", "background", "parallel"],
|
||||
"description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline"
|
||||
},
|
||||
"allowed_tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator"
|
||||
},
|
||||
"max_iterations": {
|
||||
"type": "integer",
|
||||
"description": "最大迭代次数,默认 99"
|
||||
},
|
||||
"timeout_secs": {
|
||||
"type": "integer",
|
||||
"description": "超时秒数,默认 3600(1小时)"
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"description": "并行模式下的多个子任务(仅 mode=parallel 时使用)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": { "type": "string", "description": "子任务描述" },
|
||||
"allowed_tools": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "该子任务的工具列表"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "后台任务ID(action=check_task/cancel_task 时必填)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = args["action"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
|
||||
|
||||
match action {
|
||||
"run" => self.handle_run(&args).await,
|
||||
"check_task" => self.handle_check_task(&args).await,
|
||||
"cancel_task" => self.handle_cancel_task(&args).await,
|
||||
"list_tasks" => self.handle_list_tasks(&args).await,
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
|
||||
action
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DelegateTool {
|
||||
fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result<SubAgentConfig> {
|
||||
let prompt = args["prompt"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
|
||||
.to_string();
|
||||
|
||||
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
|
||||
let timeout_secs = args["timeout_secs"].as_u64();
|
||||
|
||||
Ok(SubAgentConfig {
|
||||
prompt,
|
||||
mode: ExecutionMode::Inline,
|
||||
allowed_tools,
|
||||
max_iterations,
|
||||
timeout_secs,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let mode_str = args["mode"].as_str().unwrap_or("inline");
|
||||
let mode = match mode_str {
|
||||
"inline" => ExecutionMode::Inline,
|
||||
"background" => ExecutionMode::Background,
|
||||
"parallel" => ExecutionMode::Parallel,
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"unknown mode: {}. Supported: inline, background, parallel",
|
||||
mode_str
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Inline => {
|
||||
let config = self.parse_config_from_args(args)?;
|
||||
let result = self
|
||||
.sub_agent_manager
|
||||
.run_inline(config)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
match result.status {
|
||||
TaskStatus::Completed => Ok(ToolResult {
|
||||
success: true,
|
||||
output: result.content,
|
||||
error: None,
|
||||
}),
|
||||
TaskStatus::Failed(err) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: result.content,
|
||||
error: Some(err),
|
||||
}),
|
||||
TaskStatus::TimedOut => Ok(ToolResult {
|
||||
success: false,
|
||||
output: result.content,
|
||||
error: Some("sub-agent timed out".into()),
|
||||
}),
|
||||
TaskStatus::Cancelled => Ok(ToolResult {
|
||||
success: false,
|
||||
output: result.content,
|
||||
error: Some("sub-agent cancelled".into()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
ExecutionMode::Background => {
|
||||
let config = self.parse_config_from_args(args)?;
|
||||
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
|
||||
anyhow::anyhow!("delegate context not available: not in an agent worker")
|
||||
})?;
|
||||
|
||||
let task_id = self
|
||||
.sub_agent_manager
|
||||
.run_background(config, ctx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("后台任务已启动。\ntask_id: {}", task_id),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
ExecutionMode::Parallel => {
|
||||
let tasks = args["tasks"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?;
|
||||
|
||||
let mut configs = Vec::new();
|
||||
for task in tasks {
|
||||
let prompt = task["prompt"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
|
||||
.to_string();
|
||||
let allowed_tools: Option<Vec<String>> =
|
||||
task["allowed_tools"].as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
configs.push(SubAgentConfig {
|
||||
prompt,
|
||||
mode: ExecutionMode::Inline,
|
||||
allowed_tools,
|
||||
max_iterations: args["max_iterations"].as_u64().map(|v| v as usize),
|
||||
timeout_secs: args["timeout_secs"].as_u64(),
|
||||
});
|
||||
}
|
||||
|
||||
let has_args_allowed = args["allowed_tools"].as_array().is_some();
|
||||
for c in &mut configs {
|
||||
if c.allowed_tools.is_none() && has_args_allowed {
|
||||
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let results = self
|
||||
.sub_agent_manager
|
||||
.run_parallel(configs)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
let mut output = String::new();
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let status_icon = match r.status {
|
||||
TaskStatus::Completed => "✅",
|
||||
TaskStatus::Failed(_) => "❌",
|
||||
TaskStatus::TimedOut => "⏱️ 超时",
|
||||
TaskStatus::Cancelled => "🚫 已取消",
|
||||
};
|
||||
output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon));
|
||||
if !r.content.is_empty() {
|
||||
output.push_str(&r.content);
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
if let TaskStatus::Failed(ref err) = r.status {
|
||||
output.push_str(&format!("错误: {}\n\n", err));
|
||||
}
|
||||
}
|
||||
|
||||
let all_success = results
|
||||
.iter()
|
||||
.all(|r| matches!(r.status, TaskStatus::Completed));
|
||||
Ok(ToolResult {
|
||||
success: all_success,
|
||||
output: output.trim().to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let task_id = args["task_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
|
||||
|
||||
match self.sub_agent_manager.check_task(task_id).await {
|
||||
Some(task) => {
|
||||
let status_icon = match task.status.as_str() {
|
||||
"completed" => "✅ 已完成",
|
||||
"failed" => "❌ 失败",
|
||||
"cancelled" => "🚫 已取消",
|
||||
"running" => "🔄 运行中",
|
||||
"pending" => "⏳ 等待中",
|
||||
_ => task.status.as_str(),
|
||||
};
|
||||
let mut output = format!(
|
||||
"任务 ID: {}\n状态: {}\n任务: {}",
|
||||
task.id, status_icon, task.prompt
|
||||
);
|
||||
if let Some(ref result) = task.result {
|
||||
output.push_str(&format!("\n\n结果:\n{}", result));
|
||||
}
|
||||
if let Some(ref error) = task.error {
|
||||
output.push_str(&format!("\n错误: {}", error));
|
||||
}
|
||||
if let Some(started) = task.started_at {
|
||||
if let Some(finished) = task.finished_at {
|
||||
let duration = (finished - started) as f64 / 1000.0;
|
||||
output.push_str(&format!("\n耗时: {:.1}s", duration));
|
||||
}
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
None => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("task not found: {}", task_id)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let task_id = args["task_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
|
||||
|
||||
match self.sub_agent_manager.cancel_task(task_id).await {
|
||||
Ok(true) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("后台任务 {} 已取消", task_id),
|
||||
error: None,
|
||||
}),
|
||||
Ok(false) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)),
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("取消失败: {}", e)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let ctx = crate::agent::sub_agent::get_delegate_context()
|
||||
.map_err(|_| anyhow::anyhow!("delegate context not available"))?;
|
||||
let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await;
|
||||
|
||||
if tasks.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "没有后台任务".to_string(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let mut output = String::from("后台任务列表:\n\n");
|
||||
for task in &tasks {
|
||||
let status_icon = match task.status.as_str() {
|
||||
"completed" => "✅",
|
||||
"failed" => "❌",
|
||||
"cancelled" => "🚫",
|
||||
"running" => "🔄",
|
||||
"pending" => "⏳",
|
||||
_ => "❓",
|
||||
};
|
||||
output.push_str(&format!(
|
||||
"{} {} - {} - {} (created: {})\n",
|
||||
status_icon,
|
||||
&task.id[..std::cmp::min(8, task.id.len())],
|
||||
task.prompt.chars().take(60).collect::<String>(),
|
||||
task.status,
|
||||
task.created_at,
|
||||
));
|
||||
}
|
||||
output.push_str(&format!("\n共 {} 个任务", tasks.len()));
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -243,8 +243,8 @@ impl Tool for FileEditTool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_edit_simple() {
|
||||
|
||||
@ -181,10 +181,7 @@ impl Tool for FileReadTool {
|
||||
}
|
||||
result = lines[..end_idx].join("\n");
|
||||
let truncated = original_len - result.len();
|
||||
result.push_str(&format!(
|
||||
"\n\n... ({} chars truncated) ...",
|
||||
truncated
|
||||
));
|
||||
result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
|
||||
}
|
||||
|
||||
if end < total {
|
||||
@ -196,10 +193,7 @@ impl Tool for FileReadTool {
|
||||
end + 1
|
||||
));
|
||||
} else {
|
||||
result.push_str(&format!(
|
||||
"\n\n(End of file — {} lines total)",
|
||||
total
|
||||
));
|
||||
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
|
||||
}
|
||||
|
||||
if let Some(label) = encoding_label {
|
||||
@ -214,7 +208,7 @@ impl Tool for FileReadTool {
|
||||
}
|
||||
None => {
|
||||
// Truly binary file — base64 encode
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||
let encoded = STANDARD.encode(&bytes);
|
||||
let mime = mime_guess::from_path(&resolved)
|
||||
.first_or_octet_stream()
|
||||
@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_simple_file() {
|
||||
@ -338,10 +332,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_is_directory() {
|
||||
let tool = FileReadTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "path": "." }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "path": "." })).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Not a file"));
|
||||
|
||||
@ -101,17 +101,29 @@ impl Tool for FileSearchTool {
|
||||
};
|
||||
|
||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
|
||||
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
let case_sensitive = args
|
||||
.get("case_sensitive")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true);
|
||||
let max_results = args
|
||||
.get("max_results")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
|
||||
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
|
||||
let result = self
|
||||
.run_search(pattern, &dir, case_sensitive, max_results)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(lines) => {
|
||||
let count = lines.len();
|
||||
let mut output = self.truncate_output(&lines);
|
||||
output.push_str(&format!("\n\n---\n共 {} 个文件", count));
|
||||
Ok(ToolResult { success: true, output, error: None })
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
@ -139,9 +151,12 @@ impl FileSearchTool {
|
||||
};
|
||||
|
||||
if !fd_cmd.is_empty() {
|
||||
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await {
|
||||
match self
|
||||
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
|
||||
.await
|
||||
{
|
||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
|
||||
}
|
||||
}
|
||||
@ -149,13 +164,14 @@ impl FileSearchTool {
|
||||
if which::which("find").is_ok() {
|
||||
match self.search_with_find(pattern, dir, max_results).await {
|
||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::warn!("find failed: {}, falling back", e),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("No fd/find available, using built-in file search (slower)");
|
||||
self.search_with_rust(pattern, dir, case_sensitive, max_results).await
|
||||
self.search_with_rust(pattern, dir, case_sensitive, max_results)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn search_with_fd(
|
||||
@ -167,11 +183,15 @@ impl FileSearchTool {
|
||||
fd_cmd: &str,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
let mut cmd = Command::new(fd_cmd);
|
||||
cmd.arg("--search-path").arg(dir)
|
||||
.arg("--glob").arg(pattern)
|
||||
.arg("--color").arg("never")
|
||||
cmd.arg("--search-path")
|
||||
.arg(dir)
|
||||
.arg("--glob")
|
||||
.arg(pattern)
|
||||
.arg("--color")
|
||||
.arg("never")
|
||||
.arg("--strip-cwd-prefix")
|
||||
.arg("--max-results").arg(max_results.to_string())
|
||||
.arg("--max-results")
|
||||
.arg(max_results.to_string())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
@ -179,12 +199,9 @@ impl FileSearchTool {
|
||||
cmd.arg("--ignore-case");
|
||||
}
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
@ -192,7 +209,8 @@ impl FileSearchTool {
|
||||
}
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = text.lines()
|
||||
let lines: Vec<String> = text
|
||||
.lines()
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(|l| l.to_string())
|
||||
.collect();
|
||||
@ -215,15 +233,13 @@ impl FileSearchTool {
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::null());
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let mut lines: Vec<String> = text.lines()
|
||||
let mut lines: Vec<String> = text
|
||||
.lines()
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(|l| {
|
||||
let p = Path::new(l);
|
||||
@ -254,7 +270,13 @@ impl FileSearchTool {
|
||||
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?;
|
||||
walk_dir(
|
||||
Path::new(dir),
|
||||
Path::new(dir),
|
||||
&re,
|
||||
&mut results,
|
||||
max_results,
|
||||
)?;
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
@ -311,15 +333,18 @@ fn walk_dir(
|
||||
|
||||
if path.is_dir() {
|
||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& name.starts_with('.') && name.len() > 1 {
|
||||
continue;
|
||||
}
|
||||
&& name.starts_with('.')
|
||||
&& name.len() > 1
|
||||
{
|
||||
continue;
|
||||
}
|
||||
walk_dir(base, &path, re, results, max)?;
|
||||
} else if path.is_file() {
|
||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& re.is_match(name) {
|
||||
results.push(rel.to_string_lossy().to_string());
|
||||
}
|
||||
&& re.is_match(name)
|
||||
{
|
||||
results.push(rel.to_string_lossy().to_string());
|
||||
}
|
||||
if results.len() >= max {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@ -90,13 +90,14 @@ impl Tool for FileWriteTool {
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = resolved.parent()
|
||||
&& !parent.exists()
|
||||
&& let Err(e) = std::fs::create_dir_all(parent) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to create parent directory: {}", e)),
|
||||
});
|
||||
}
|
||||
&& let Err(e) = std::fs::create_dir_all(parent)
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to create parent directory: {}", e)),
|
||||
});
|
||||
}
|
||||
|
||||
match std::fs::write(&resolved, content) {
|
||||
Ok(_) => Ok(ToolResult {
|
||||
@ -168,10 +169,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_write_missing_path() {
|
||||
let tool = FileWriteTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "content": "Hello" }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("path"));
|
||||
|
||||
@ -129,7 +129,9 @@ impl GetSkillTool {
|
||||
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
|
||||
for s in &skills {
|
||||
let always_mark = if s.always { " [常驻]" } else { "" };
|
||||
let path_str = s.path.as_ref()
|
||||
let path_str = s
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "—".to_string());
|
||||
output.push_str(&format!(
|
||||
@ -148,10 +150,10 @@ impl GetSkillTool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_existing_skill() {
|
||||
|
||||
@ -50,10 +50,7 @@ impl HttpRequestTool {
|
||||
}
|
||||
|
||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||
return Err(format!(
|
||||
"Host '{}' is not in allowed_domains",
|
||||
host
|
||||
));
|
||||
return Err(format!("Host '{}' is not in allowed_domains", host));
|
||||
}
|
||||
|
||||
Ok(url.to_string())
|
||||
@ -80,11 +77,10 @@ impl HttpRequestTool {
|
||||
for (key, value) in obj {
|
||||
if let Some(str_val) = value.as_str()
|
||||
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
|
||||
&& let Ok(val) =
|
||||
reqwest::header::HeaderValue::from_str(str_val)
|
||||
{
|
||||
header_map.insert(name, val);
|
||||
}
|
||||
&& let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
|
||||
{
|
||||
header_map.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
||||
|
||||
allowed_domains.iter().any(|domain| {
|
||||
host == domain
|
||||
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
||||
|| host
|
||||
.strip_suffix(domain)
|
||||
.is_some_and(|prefix| prefix.ends_with('.'))
|
||||
})
|
||||
}
|
||||
|
||||
@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool {
|
||||
}
|
||||
|
||||
// Check .local TLD
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
if host
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|label| label == "local")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||
|| v4.is_broadcast()
|
||||
|| v4.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,10 +278,7 @@ impl Tool for HttpRequestTool {
|
||||
}
|
||||
};
|
||||
|
||||
let method_str = args
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
||||
|
||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||
let body = args.get("body").and_then(|v| v.as_str());
|
||||
|
||||
@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool {
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||
self.memory
|
||||
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
|
||||
.recall_by_time(
|
||||
since,
|
||||
until,
|
||||
Some(query),
|
||||
limit,
|
||||
Some(MemoryCategory::Knowledge),
|
||||
None,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
|
||||
self.memory
|
||||
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
|
||||
.await?
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool {
|
||||
let formatted = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||
let session = e
|
||||
.session_id
|
||||
.as_deref()
|
||||
.map(|s| format!(" [session: {}]", s))
|
||||
.unwrap_or_default();
|
||||
format!(
|
||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||
e.key,
|
||||
@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool {
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||
self.memory
|
||||
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
|
||||
.recall_by_time(
|
||||
since,
|
||||
until,
|
||||
Some(query),
|
||||
limit,
|
||||
Some(MemoryCategory::Timeline),
|
||||
session_id,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
|
||||
self.memory
|
||||
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
|
||||
.await?
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool {
|
||||
let formatted = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||
let session = e
|
||||
.session_id
|
||||
.as_deref()
|
||||
.map(|s| format!(" [session: {}]", s))
|
||||
.unwrap_or_default();
|
||||
format!(
|
||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||
e.key,
|
||||
|
||||
@ -4,6 +4,7 @@ pub mod calculator;
|
||||
pub mod chat_manager;
|
||||
pub mod content_search;
|
||||
pub mod cron;
|
||||
pub mod delegate;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_search;
|
||||
@ -23,6 +24,7 @@ pub use browser::BrowserTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use chat_manager::ChatManagerTool;
|
||||
pub use content_search::ContentSearchTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_search::FileSearchTool;
|
||||
@ -35,10 +37,11 @@ pub use send_message::SendMessageTool;
|
||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
use std::sync::Arc;
|
||||
use crate::agent::SubAgentManager;
|
||||
use crate::config::BrowserConfig;
|
||||
use crate::memory::MemoryManager;
|
||||
use crate::skills::SkillsLoader;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Create the base tool registry (without send_message).
|
||||
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
||||
@ -46,6 +49,7 @@ use crate::skills::SkillsLoader;
|
||||
pub fn create_default_tools(
|
||||
skills_loader: Arc<SkillsLoader>,
|
||||
memory: Arc<MemoryManager>,
|
||||
sub_agent_manager: Option<Arc<SubAgentManager>>,
|
||||
browser_config: Option<&BrowserConfig>,
|
||||
) -> ToolRegistry {
|
||||
let registry = ToolRegistry::new();
|
||||
@ -76,5 +80,9 @@ pub fn create_default_tools(
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mgr) = sub_agent_manager {
|
||||
registry.register(DelegateTool::new(mgr));
|
||||
}
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
608
src/tools/pty.rs
Normal file
608
src/tools/pty.rs
Normal file
@ -0,0 +1,608 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::io::Write;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
const MAX_OUTPUT_LINES: usize = 50_000;
|
||||
const MAX_CHARS_PER_LINE: usize = 2_000;
|
||||
const MAX_SESSIONS: usize = 10;
|
||||
|
||||
fn guard_command(command: &str) -> Option<String> {
|
||||
let deny_patterns: &[&str] = &[
|
||||
r"\brm\s+-[rf]{1,2}\b",
|
||||
r"\bdel\s+/[fq]\b",
|
||||
r"\brmdir\s+/s\b",
|
||||
r":\(\)\s*\{.*\};\s*:",
|
||||
];
|
||||
let lower = command.to_lowercase();
|
||||
for pattern in deny_patterns {
|
||||
if regex::Regex::new(pattern)
|
||||
.ok()
|
||||
.map(|re| re.is_match(&lower))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Some(format!(
|
||||
"Command blocked by safety guard (dangerous pattern: {})",
|
||||
pattern
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn unescape_control_chars(s: &str) -> String {
|
||||
let mut result = String::with_capacity(s.len());
|
||||
let mut chars = s.chars().peekable();
|
||||
while let Some(c) = chars.next() {
|
||||
if c == '\\' {
|
||||
match chars.next() {
|
||||
Some('n') => result.push('\n'),
|
||||
Some('r') => result.push('\r'),
|
||||
Some('t') => result.push('\t'),
|
||||
Some('x') => {
|
||||
let hex: String = chars.by_ref().take(2).collect();
|
||||
if hex.len() == 2 {
|
||||
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
|
||||
result.push(byte as char);
|
||||
} else {
|
||||
result.push_str(&format!("\\x{}", hex));
|
||||
}
|
||||
} else {
|
||||
result.push_str(&format!("\\x{}", hex));
|
||||
}
|
||||
}
|
||||
Some('\\') => result.push('\\'),
|
||||
Some(other) => {
|
||||
result.push('\\');
|
||||
result.push(other);
|
||||
}
|
||||
None => result.push('\\'),
|
||||
}
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn truncate_command(cmd: &str, max_len: usize) -> String {
|
||||
let cmd = cmd.trim();
|
||||
let first_arg = cmd.split_whitespace().next().unwrap_or(cmd);
|
||||
if first_arg.len() > max_len {
|
||||
format!("{}...", &first_arg[..max_len])
|
||||
} else {
|
||||
first_arg.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn status_str(status: &SessionStatus) -> &str {
|
||||
match status {
|
||||
SessionStatus::Running => "running",
|
||||
SessionStatus::Exited(_) => "exited",
|
||||
SessionStatus::Killed => "killed",
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum SessionStatus {
|
||||
Running,
|
||||
Exited(i32),
|
||||
Killed,
|
||||
}
|
||||
|
||||
struct PtySession {
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
#[allow(dead_code)]
|
||||
command: String,
|
||||
#[allow(dead_code)]
|
||||
started_at: Instant,
|
||||
status: SessionStatus,
|
||||
child: Arc<Mutex<Option<Box<dyn portable_pty::Child + Send + Sync>>>>,
|
||||
writer: Arc<Mutex<Option<Box<dyn Write + Send>>>>,
|
||||
output_buffer: VecDeque<String>,
|
||||
output_total_lines: usize,
|
||||
}
|
||||
|
||||
impl PtySession {
|
||||
fn new(
|
||||
id: String,
|
||||
command: String,
|
||||
child: Box<dyn portable_pty::Child + Send + Sync>,
|
||||
writer: Box<dyn Write + Send>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
command,
|
||||
started_at: Instant::now(),
|
||||
status: SessionStatus::Running,
|
||||
child: Arc::new(Mutex::new(Some(child))),
|
||||
writer: Arc::new(Mutex::new(Some(writer))),
|
||||
output_buffer: VecDeque::new(),
|
||||
output_total_lines: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_line(&mut self, line: String) {
|
||||
let line = if line.len() > MAX_CHARS_PER_LINE {
|
||||
format!("{}...<truncated>", &line[..MAX_CHARS_PER_LINE])
|
||||
} else {
|
||||
line
|
||||
};
|
||||
self.output_total_lines += 1;
|
||||
if self.output_buffer.len() >= MAX_OUTPUT_LINES {
|
||||
self.output_buffer.pop_front();
|
||||
}
|
||||
self.output_buffer.push_back(line);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PtyManager {
|
||||
sessions: Mutex<HashMap<String, Arc<Mutex<PtySession>>>>,
|
||||
}
|
||||
|
||||
impl PtyManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cleanup_all(&self) {
|
||||
let sessions: Vec<Arc<Mutex<PtySession>>> = {
|
||||
let mut guard = self.sessions.lock().unwrap();
|
||||
guard.drain().map(|(_, s)| s).collect()
|
||||
};
|
||||
for session in sessions {
|
||||
let mut guard = session.lock().unwrap();
|
||||
let mut child_guard = guard.child.lock().unwrap();
|
||||
if let Some(ref mut child) = *child_guard {
|
||||
let _ = child.kill();
|
||||
}
|
||||
*child_guard = None;
|
||||
guard.status = SessionStatus::Killed;
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn(&self, command: &str) -> Result<String, String> {
|
||||
if let Some(reason) = guard_command(command) {
|
||||
return Err(reason);
|
||||
}
|
||||
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
if sessions.len() >= MAX_SESSIONS {
|
||||
return Err(format!(
|
||||
"Max sessions ({}) reached, kill some sessions first",
|
||||
MAX_SESSIONS
|
||||
));
|
||||
}
|
||||
|
||||
let session_id = format!("pty_{}", crate::util::short_id());
|
||||
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
|
||||
|
||||
let pty_system = portable_pty::native_pty_system();
|
||||
let pty_pair = pty_system
|
||||
.openpty(portable_pty::PtySize {
|
||||
rows: 24,
|
||||
cols: 80,
|
||||
pixel_width: 0,
|
||||
pixel_height: 0,
|
||||
})
|
||||
.map_err(|e| format!("Failed to open PTY: {}", e))?;
|
||||
|
||||
let mut cmd = portable_pty::CommandBuilder::new("bash");
|
||||
cmd.args(&["-c", command]);
|
||||
cmd.cwd(cwd);
|
||||
|
||||
let child = pty_pair
|
||||
.slave
|
||||
.spawn_command(cmd)
|
||||
.map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||
|
||||
let pid = child.process_id().unwrap_or(0);
|
||||
|
||||
let writer = pty_pair
|
||||
.master
|
||||
.take_writer()
|
||||
.map_err(|e| format!("Failed to take writer: {}", e))?;
|
||||
|
||||
let reader = pty_pair
|
||||
.master
|
||||
.try_clone_reader()
|
||||
.map_err(|e| format!("Failed to clone reader: {}", e))?;
|
||||
|
||||
let session_id_clone = session_id.clone();
|
||||
let session = PtySession::new(session_id_clone, command.to_string(), child, writer);
|
||||
let session = Arc::new(Mutex::new(session));
|
||||
sessions.insert(session_id.clone(), session.clone());
|
||||
|
||||
let session_for_reader = session.clone();
|
||||
let child_for_reader = session_for_reader.lock().unwrap().child.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut reader = reader;
|
||||
let mut buf = [0u8; 4096];
|
||||
let mut partial = String::new();
|
||||
loop {
|
||||
use std::io::Read;
|
||||
match reader.read(&mut buf) {
|
||||
Ok(0) => break,
|
||||
Ok(n) => {
|
||||
let s = String::from_utf8_lossy(&buf[..n]);
|
||||
partial.push_str(&s);
|
||||
let lines: Vec<&str> = partial.split('\n').collect();
|
||||
if lines.len() <= 1 {
|
||||
continue;
|
||||
}
|
||||
let complete = lines.len() - 1;
|
||||
let mut guard = session_for_reader.lock().unwrap();
|
||||
for line in lines[..complete].iter() {
|
||||
guard.push_line(line.to_string());
|
||||
}
|
||||
partial = lines[complete].to_string();
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
let mut guard = session_for_reader.lock().unwrap();
|
||||
if !partial.is_empty() {
|
||||
guard.push_line(partial);
|
||||
}
|
||||
let exit_code = {
|
||||
let mut cg = child_for_reader.lock().unwrap();
|
||||
if let Some(ref mut c) = *cg {
|
||||
c.wait().map(|s| s.exit_code() as i32).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1));
|
||||
});
|
||||
|
||||
Ok(format!("session_id: {}, pid: {}", session_id, pid))
|
||||
}
|
||||
|
||||
fn write(&self, session_id: &str, data: &str) -> Result<String, String> {
|
||||
let unescaped = unescape_control_chars(data);
|
||||
let byte_count = unescaped.len();
|
||||
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
let mut guard = session.lock().unwrap();
|
||||
if guard.status != SessionStatus::Running {
|
||||
return Err("Session is not running".to_string());
|
||||
}
|
||||
|
||||
let writer = guard.writer.clone();
|
||||
drop(guard);
|
||||
drop(sessions);
|
||||
|
||||
let mut writer_guard = writer.lock().unwrap();
|
||||
match *writer_guard {
|
||||
Some(ref mut w) => {
|
||||
w.write_all(unescaped.as_bytes())
|
||||
.map_err(|e| format!("Write error: {}", e))?;
|
||||
w.flush().map_err(|e| format!("Flush error: {}", e))?;
|
||||
}
|
||||
None => return Err("Writer not available".to_string()),
|
||||
}
|
||||
|
||||
Ok(format!("OK, wrote {} bytes", byte_count))
|
||||
}
|
||||
|
||||
fn read(
|
||||
&self,
|
||||
session_id: &str,
|
||||
offset: usize,
|
||||
limit: usize,
|
||||
) -> Result<String, String> {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
let guard = session.lock().unwrap();
|
||||
|
||||
let total = guard.output_total_lines;
|
||||
let buffer_len = guard.output_buffer.len();
|
||||
let start = 0_usize.max(offset);
|
||||
let skip_old = total.saturating_sub(buffer_len);
|
||||
let view_start = start.saturating_sub(skip_old);
|
||||
|
||||
let lines: Vec<String> = guard
|
||||
.output_buffer
|
||||
.iter()
|
||||
.skip(view_start)
|
||||
.take(limit)
|
||||
.enumerate()
|
||||
.map(|(i, line)| format!("{}: {}", start + i, line))
|
||||
.collect();
|
||||
|
||||
let displayed = start + lines.len();
|
||||
let has_more = displayed < total;
|
||||
|
||||
let mut output = format!(
|
||||
"# Lines {}-{} (共 {} 行{})\n",
|
||||
start,
|
||||
displayed.saturating_sub(1),
|
||||
total,
|
||||
if has_more { ",还有更多" } else { "" }
|
||||
);
|
||||
output.push_str(&lines.join("\n"));
|
||||
if has_more {
|
||||
output.push_str(&format!(
|
||||
"\n[还有 {} 行未显示,用 offset={} 继续读取]",
|
||||
total.saturating_sub(displayed),
|
||||
displayed
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn kill(&self, session_id: &str) -> Result<String, String> {
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
||||
let mut guard = session.lock().unwrap();
|
||||
|
||||
let mut child_guard = guard.child.lock().unwrap();
|
||||
if let Some(ref mut child) = *child_guard {
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
}
|
||||
*child_guard = None;
|
||||
guard.status = SessionStatus::Killed;
|
||||
drop(child_guard);
|
||||
drop(guard);
|
||||
sessions.remove(session_id);
|
||||
|
||||
Ok(format!("Session {} killed", session_id))
|
||||
}
|
||||
|
||||
fn list(&self) -> String {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
if sessions.is_empty() {
|
||||
return "No active PTY sessions".to_string();
|
||||
}
|
||||
|
||||
let mut lines: Vec<String> = sessions
|
||||
.iter()
|
||||
.map(|(id, session)| {
|
||||
let guard = session.lock().unwrap();
|
||||
let age = Instant::now().duration_since(guard.started_at);
|
||||
let age_str = if age.as_secs() < 60 {
|
||||
format!("{}s ago", age.as_secs())
|
||||
} else {
|
||||
format!("{}m ago", age.as_secs() / 60)
|
||||
};
|
||||
format!(
|
||||
"{:<14} {:<12} {:<10} {:<6} lines {}",
|
||||
id,
|
||||
truncate_command(&guard.command, 10),
|
||||
status_str(&guard.status),
|
||||
guard.output_total_lines,
|
||||
age_str,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
lines.sort();
|
||||
lines.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PtyManager {
|
||||
fn drop(&mut self) {
|
||||
self.cleanup_all();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PtyTool {
|
||||
pty_manager: Arc<PtyManager>,
|
||||
}
|
||||
|
||||
impl PtyTool {
|
||||
pub fn new(pty_manager: Arc<PtyManager>) -> Self {
|
||||
Self { pty_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for PtyTool {
|
||||
fn name(&self) -> &str {
|
||||
"pty"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["spawn", "write", "read", "kill", "list"],
|
||||
"description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "会话ID (write/read/kill 需要)"
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的命令 (spawn 需要)"
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)",
|
||||
"minimum": 0
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "读取的最大行数 (read 可选,默认 500)",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
fn exclusive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: action".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match action {
|
||||
"spawn" => {
|
||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: command".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
match self.pty_manager.spawn(command) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"write" => {
|
||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: session_id".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let data = match args.get("data").and_then(|v| v.as_str()) {
|
||||
Some(d) => d,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: data".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
match self.pty_manager.write(session_id, data) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"read" => {
|
||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: session_id".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
let offset = args
|
||||
.get("offset")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as usize;
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(500) as usize;
|
||||
match self.pty_manager.read(session_id, offset, limit) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"kill" => {
|
||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: session_id".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
match self.pty_manager.kill(session_id) {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
"list" => {
|
||||
let output = self.pty_manager.list();
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown action: {}", action)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -17,7 +17,15 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
||||
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
|
||||
self.tools
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(tool.name().to_string(), Arc::new(tool));
|
||||
}
|
||||
|
||||
/// Register an existing Arc-wrapped tool by name
|
||||
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
|
||||
self.tools.lock().unwrap().insert(name, tool);
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
|
||||
@ -62,6 +70,17 @@ impl ToolRegistry {
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 生成工具列表描述,用于子 Agent 系统提示词
|
||||
pub fn describe_for_prompt(&self) -> String {
|
||||
let mut entries: Vec<String> = self
|
||||
.iter()
|
||||
.into_iter()
|
||||
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
|
||||
.collect();
|
||||
entries.sort();
|
||||
entries.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolRegistry {
|
||||
|
||||
@ -115,9 +115,11 @@ impl SchemaCleanr {
|
||||
}
|
||||
|
||||
if let Some(Value::String(t)) = obj.get("type")
|
||||
&& t == "object" && !obj.contains_key("properties") {
|
||||
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||
}
|
||||
&& t == "object"
|
||||
&& !obj.contains_key("properties")
|
||||
{
|
||||
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -173,9 +175,10 @@ impl SchemaCleanr {
|
||||
|
||||
// Handle anyOf/oneOf simplification
|
||||
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
|
||||
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
||||
return simplified;
|
||||
}
|
||||
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
|
||||
{
|
||||
return simplified;
|
||||
}
|
||||
|
||||
// Build cleaned object
|
||||
let mut cleaned = Map::new();
|
||||
@ -243,12 +246,13 @@ impl SchemaCleanr {
|
||||
}
|
||||
|
||||
if let Some(def_name) = Self::parse_local_ref(ref_value)
|
||||
&& let Some(definition) = defs.get(def_name.as_str()) {
|
||||
ref_stack.insert(ref_value.to_string());
|
||||
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||
ref_stack.remove(ref_value);
|
||||
return Self::preserve_meta(obj, cleaned);
|
||||
}
|
||||
&& let Some(definition) = defs.get(def_name.as_str())
|
||||
{
|
||||
ref_stack.insert(ref_value.to_string());
|
||||
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||
ref_stack.remove(ref_value);
|
||||
return Self::preserve_meta(obj, cleaned);
|
||||
}
|
||||
|
||||
tracing::warn!("Cannot resolve $ref: {}", ref_value);
|
||||
Self::preserve_meta(obj, Value::Object(Map::new()))
|
||||
@ -340,13 +344,16 @@ impl SchemaCleanr {
|
||||
return true;
|
||||
}
|
||||
if let Some(Value::Array(arr)) = obj.get("enum")
|
||||
&& arr.len() == 1 && matches!(arr[0], Value::Null) {
|
||||
return true;
|
||||
}
|
||||
&& arr.len() == 1
|
||||
&& matches!(arr[0], Value::Null)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if let Some(Value::String(t)) = obj.get("type")
|
||||
&& t == "null" {
|
||||
return true;
|
||||
}
|
||||
&& t == "null"
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
@ -403,7 +410,10 @@ impl SchemaCleanr {
|
||||
|
||||
match non_null.len() {
|
||||
0 => Value::String("null".to_string()),
|
||||
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
||||
1 => non_null
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or(Value::String("null".to_string())),
|
||||
_ => Value::Array(non_null),
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mime_guess::mime;
|
||||
@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
|
||||
match parts.len() {
|
||||
2 => {
|
||||
if parts[0].is_empty() || parts[1].is_empty() {
|
||||
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
|
||||
Err(format!(
|
||||
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
|
||||
raw
|
||||
))
|
||||
} else {
|
||||
Ok((parts[0], parts[1], None))
|
||||
}
|
||||
}
|
||||
3 => {
|
||||
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
||||
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
|
||||
Err(format!(
|
||||
"Invalid target_chat_id format '{}': all three parts must not be empty",
|
||||
raw
|
||||
))
|
||||
} else {
|
||||
Ok((parts[0], parts[1], Some(parts[2])))
|
||||
}
|
||||
@ -98,8 +104,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
||||
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
||||
|
||||
// 1. Parse target_chat_id
|
||||
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
let (channel, chat_id, dialog_id) =
|
||||
parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
// 2. Validate channel
|
||||
if !self.available_channels.contains(channel) {
|
||||
@ -109,7 +115,11 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
||||
error: Some(format!(
|
||||
"Channel '{}' is not available. Available channels: {}",
|
||||
channel,
|
||||
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||
self.available_channels
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
@ -129,7 +139,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
||||
let media = parse_files_arg(&args);
|
||||
|
||||
// 4. Send via messenger
|
||||
match self.messenger
|
||||
match self
|
||||
.messenger
|
||||
.send_message(channel, chat_id, dialog_id, content, source, media)
|
||||
.await
|
||||
{
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use async_trait::async_trait;
|
||||
use crate::bus::{MediaItem, MessageSource};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolResult {
|
||||
|
||||
@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
if host
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|label| label == "local")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
use picobot::config::{Config, LLMProviderConfig};
|
||||
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn load_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_simple_completion() {
|
||||
let config = load_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||
@ -57,8 +56,7 @@ async fn test_openai_simple_completion() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_conversation() {
|
||||
let config = load_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -82,7 +80,9 @@ async fn test_openai_conversation() {
|
||||
async fn test_config_load() {
|
||||
// Test that config.json can be loaded and provider config created
|
||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||
let provider_config = config
|
||||
.get_provider_config("default")
|
||||
.expect("Failed to get provider config");
|
||||
|
||||
assert_eq!(provider_config.provider_type, "openai");
|
||||
assert_eq!(provider_config.name, "aliyun");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use picobot::providers::{ChatCompletionRequest, Message};
|
||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||
use picobot::providers::{ChatCompletionRequest, Message};
|
||||
|
||||
/// Test that message with special characters is properly escaped
|
||||
#[test]
|
||||
@ -19,7 +19,9 @@ fn test_message_special_characters() {
|
||||
#[test]
|
||||
fn test_multiline_system_prompt() {
|
||||
let messages = vec![
|
||||
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
|
||||
Message::system(
|
||||
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
|
||||
),
|
||||
Message::user("Hi"),
|
||||
];
|
||||
|
||||
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
|
||||
#[test]
|
||||
fn test_chat_request_serialization() {
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![
|
||||
Message::system("You are helpful"),
|
||||
Message::user("Hello"),
|
||||
],
|
||||
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(100),
|
||||
tools: None,
|
||||
|
||||
@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
|
||||
/// Verify that next_run_for_schedule produces valid future timestamps.
|
||||
#[test]
|
||||
fn test_next_run_always_future() {
|
||||
use picobot::scheduler::{next_run_for_schedule, Schedule};
|
||||
use picobot::scheduler::{Schedule, next_run_for_schedule};
|
||||
|
||||
let now = 1700000000000_i64;
|
||||
|
||||
@ -56,6 +56,10 @@ fn test_next_run_always_future() {
|
||||
for s in &schedules {
|
||||
let next = next_run_for_schedule(s, now);
|
||||
assert!(next.is_some(), "expected next run for {:?}", s);
|
||||
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
|
||||
assert!(
|
||||
next.unwrap() > now,
|
||||
"next run should be after now for {:?}",
|
||||
s
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
||||
use picobot::config::LLMProviderConfig;
|
||||
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_tool_call() {
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -68,7 +67,11 @@ async fn test_openai_tool_call() {
|
||||
let response = provider.chat(request).await.unwrap();
|
||||
|
||||
// Should have tool calls
|
||||
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
||||
assert!(
|
||||
!response.tool_calls.is_empty(),
|
||||
"Expected tool call, got: {}",
|
||||
response.content
|
||||
);
|
||||
|
||||
let tool_call = &response.tool_calls[0];
|
||||
assert_eq!(tool_call.name, "get_weather");
|
||||
@ -78,8 +81,7 @@ async fn test_openai_tool_call() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_tool_call_with_manual_execution() {
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
};
|
||||
|
||||
let response1 = provider.chat(request1).await.unwrap();
|
||||
let tool_call = response1.tool_calls.first()
|
||||
.expect("Expected tool call");
|
||||
let tool_call = response1.tool_calls.first().expect("Expected tool call");
|
||||
assert_eq!(tool_call.name, "get_weather");
|
||||
|
||||
// Second request with tool result
|
||||
@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_no_tool_when_not_provided() {
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user