Compare commits
No commits in common. "8f4ee79d8d6f880824df1770c001a6554fe9d719" and "a77c02682605b6a901f821694293d18867868534" have entirely different histories.
8f4ee79d8d
...
a77c026826
@ -1,28 +0,0 @@
|
|||||||
# Git
|
|
||||||
.git
|
|
||||||
.gitignore
|
|
||||||
|
|
||||||
# Build artifacts
|
|
||||||
target/
|
|
||||||
!target/release/picobot
|
|
||||||
|
|
||||||
# IDE
|
|
||||||
.vscode/
|
|
||||||
.idea/
|
|
||||||
*.swp
|
|
||||||
*.swo
|
|
||||||
|
|
||||||
# Docs and references
|
|
||||||
docs/
|
|
||||||
reference/
|
|
||||||
|
|
||||||
# Test files
|
|
||||||
tests/
|
|
||||||
|
|
||||||
# Misc
|
|
||||||
*.md
|
|
||||||
*.txt
|
|
||||||
.opencode/
|
|
||||||
CLAUDE.md
|
|
||||||
AGENTS.md
|
|
||||||
ARCHITECTURE_REVIEW.md
|
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,8 +1,6 @@
|
|||||||
/target
|
/target
|
||||||
docker_build/
|
|
||||||
reference/**
|
reference/**
|
||||||
.env
|
.env
|
||||||
*.env
|
*.env
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
.worktrees/
|
.worktrees/
|
||||||
design
|
|
||||||
|
|||||||
128
ARCHITECTURE_REVIEW.md
Normal file
128
ARCHITECTURE_REVIEW.md
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# 架构审查报告
|
||||||
|
|
||||||
|
> 生成时间: 2026-04-26
|
||||||
|
> 更新时间: 2026-04-26
|
||||||
|
|
||||||
|
## 审查摘要
|
||||||
|
|
||||||
|
本报告识别了当前代码库中的架构不合理、冗余和无效代码的问题。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 问题清单
|
||||||
|
|
||||||
|
### 已修复
|
||||||
|
|
||||||
|
#### ✅ #1 OutboundDispatcher 重复维护 Channel 注册表
|
||||||
|
|
||||||
|
**修复方案**: `OutboundDispatcher` 现在从 `ChannelManager` 获取 channels,而不是自己维护一份注册表。
|
||||||
|
|
||||||
|
**修改文件**:
|
||||||
|
- `src/bus/dispatcher.rs` - 移除 `channels` 字段,改用 `ChannelManager`
|
||||||
|
- `src/channels/manager.rs` - 添加 `register_channel` 方法
|
||||||
|
- `src/gateway/mod.rs` - 简化 dispatcher 初始化
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ✅ #2 CliChatChannel 持有独立的 SessionStore
|
||||||
|
|
||||||
|
**修复方案**: `CliChatChannel` 的 `SessionStore` 通过依赖注入从 `ChannelManager` 获取,而不是独立持有。
|
||||||
|
|
||||||
|
**修改文件**:
|
||||||
|
- `src/channels/cli_chat.rs` - 添加 `set_store()` 方法
|
||||||
|
- `src/channels/manager.rs` - 添加 `cli_chat_channel` 字段
|
||||||
|
- `src/gateway/mod.rs` - 重构 channel 初始化流程
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ✅ #3 MessageBus 被创建两次引用
|
||||||
|
|
||||||
|
**修复方案**: 移除 `GatewayState.bus` 字段,直接使用 `channel_manager.bus()`。
|
||||||
|
|
||||||
|
**修改文件**:
|
||||||
|
- `src/gateway/mod.rs` - 移除冗余的 `bus` 字段
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ✅ #4 GatewayState 同时持有 channel_manager 和 cli_chat_channel
|
||||||
|
|
||||||
|
**修复方案**: `cli_chat_channel` 只通过 `ChannelManager` 管理,`GatewayState` 不再单独持有。
|
||||||
|
|
||||||
|
**修改文件**:
|
||||||
|
- `src/gateway/mod.rs` - 移除 `cli_chat_channel` 字段,添加 `cli_chat_channel()` getter 方法
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 高优先级(待修复)
|
||||||
|
|
||||||
|
#### ❌ Session 每次重建都创建新的 LLM Provider
|
||||||
|
|
||||||
|
**文件**: `src/gateway/session.rs:349-361`
|
||||||
|
|
||||||
|
**问题**: 每当 session TTL 过期(默认4小时),就会销毁并重建 session,同时创建新的 LLM provider 连接。
|
||||||
|
|
||||||
|
**建议**: Provider 应该池化复用,不随 session 销毁而重建。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ❌ CliChatChannel::send 广播给所有客户端
|
||||||
|
|
||||||
|
**文件**: `src/channels/cli_chat.rs:279-289`
|
||||||
|
|
||||||
|
**问题**: `OutboundMessage` 有 `chat_id` 字段用于路由,但实现广播给所有客户端,而不是只发给对应 chat_id 的客户端。
|
||||||
|
|
||||||
|
**建议**: 根据 `chat_id` 过滤客户端,只发送给对应的客户端。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 中优先级(待修复)
|
||||||
|
|
||||||
|
#### ❌ default_tools() 每次调用创建新 ToolRegistry
|
||||||
|
|
||||||
|
**文件**: `src/gateway/session.rs:212-227`
|
||||||
|
|
||||||
|
**建议**: 如果工具列表是只读的,直接 clone Arc;如果需要修改,需要澄清设计意图。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 低优先级(待修复)
|
||||||
|
|
||||||
|
#### ❌ FeishuChannel::new 接收未使用的 provider_config
|
||||||
|
|
||||||
|
**文件**: `src/channels/feishu.rs:175-178`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ❌ OutboundDispatcher::send_with_retry 永不执行的 unreachable
|
||||||
|
|
||||||
|
**文件**: `src/bus/dispatcher.rs:81`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ❌ Channel trait 的 `is_running` 使用 std::sync::Mutex
|
||||||
|
|
||||||
|
**文件**: `src/channels/base.rs:38` vs `src/channels/cli_chat.rs:265-267`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ❌ LoopDetector 硬编码在 AgentLoop 中
|
||||||
|
|
||||||
|
**文件**: `src/agent/agent_loop.rs:88-172`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### ❌ InboundMessage 和 OutboundMessage 结构重复
|
||||||
|
|
||||||
|
**文件**: `src/bus/message.rs`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 问题统计
|
||||||
|
|
||||||
|
| 状态 | 优先级 | 数量 |
|
||||||
|
|------|--------|------|
|
||||||
|
| ✅ 已修复 | - | 4 |
|
||||||
|
| ❌ 待修复 | 高 | 2 |
|
||||||
|
| ❌ 待修复 | 中 | 1 |
|
||||||
|
| ❌ 待修复 | 低 | 5 |
|
||||||
|
| **总计** | - | **12** |
|
||||||
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "picobot"
|
name = "picobot"
|
||||||
version = "1.1.0"
|
version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@ -12,8 +12,6 @@ serde_json = "1.0"
|
|||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
thiserror = "2.0.18"
|
thiserror = "2.0.18"
|
||||||
tokio = { version = "1.52", features = ["full"] }
|
tokio = { version = "1.52", features = ["full"] }
|
||||||
tokio-util = { version = "0.7", features = ["rt"] }
|
|
||||||
dashmap = "6.1"
|
|
||||||
uuid = { version = "1.23", features = ["v4"] }
|
uuid = { version = "1.23", features = ["v4"] }
|
||||||
axum = { version = "0.8", features = ["ws"] }
|
axum = { version = "0.8", features = ["ws"] }
|
||||||
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
|
tokio-tungstenite = { version = "0.29.0", features = ["rustls-tls-webpki-roots", "rustls"] }
|
||||||
@ -51,7 +49,6 @@ encoding_rs = "0.8"
|
|||||||
zstd = "0.13"
|
zstd = "0.13"
|
||||||
tar = "0.4"
|
tar = "0.4"
|
||||||
fantoccini = { version = "0.22", default-features = false, features = ["rustls-tls"] }
|
fantoccini = { version = "0.22", default-features = false, features = ["rustls-tls"] }
|
||||||
portable-pty = "0.9"
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
zstd = "0.13"
|
zstd = "0.13"
|
||||||
|
|||||||
110
Dockerfile
110
Dockerfile
@ -1,110 +0,0 @@
|
|||||||
# =============================================================================
|
|
||||||
# PicoBot Docker Image
|
|
||||||
# =============================================================================
|
|
||||||
# Build binary on host:
|
|
||||||
# cargo build --release
|
|
||||||
#
|
|
||||||
# Build image:
|
|
||||||
# docker build -t picobot .
|
|
||||||
#
|
|
||||||
# Run gateway: docker run -d -v ~/.picobot:/app/.picobot -p 19876:19876 picobot gateway
|
|
||||||
# Run chat: docker run -it -v ~/.picobot:/app/.picobot picobot chat
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
FROM debian:trixie-slim
|
|
||||||
|
|
||||||
LABEL org.opencontainers.image.title="PicoBot"
|
|
||||||
LABEL org.opencontainers.image.description="AI agent gateway and chat client"
|
|
||||||
LABEL org.opencontainers.image.source="https://github.com/your-repo/picobot"
|
|
||||||
|
|
||||||
# Avoid interactive prompts
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
|
||||||
|
|
||||||
# Configure domestic mirrors for pip, uv, npm (China)
|
|
||||||
ENV PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
|
|
||||||
ENV UV_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/
|
|
||||||
|
|
||||||
# Install base tools, Python, and uv in one layer to reduce duplication
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
ca-certificates \
|
|
||||||
tini \
|
|
||||||
curl \
|
|
||||||
gnupg \
|
|
||||||
git \
|
|
||||||
jq \
|
|
||||||
tree \
|
|
||||||
zip \
|
|
||||||
unzip \
|
|
||||||
sqlite3 \
|
|
||||||
openssh-client \
|
|
||||||
sshpass \
|
|
||||||
dnsutils \
|
|
||||||
poppler-utils \
|
|
||||||
fonts-wqy-zenhei \
|
|
||||||
fonts-wqy-microhei \
|
|
||||||
python3 \
|
|
||||||
python3-pip \
|
|
||||||
python3-venv \
|
|
||||||
&& rm -rf /var/lib/apt/lists/* \
|
|
||||||
&& pip3 install --no-cache-dir --break-system-packages uv
|
|
||||||
|
|
||||||
# Install Node.js and npx
|
|
||||||
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
|
|
||||||
&& apt-get install -y --no-install-recommends nodejs \
|
|
||||||
&& npm config set registry https://registry.npmmirror.com \
|
|
||||||
&& npm cache clean --force \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Install himalaya (CLI email client) from local file
|
|
||||||
COPY docker_build/himalaya.x86_64-linux.tgz /tmp/himalaya.tgz
|
|
||||||
RUN tar -xzf /tmp/himalaya.tgz -C /usr/local/bin \
|
|
||||||
&& chmod +x /usr/local/bin/himalaya \
|
|
||||||
&& rm -f /tmp/himalaya.tgz
|
|
||||||
|
|
||||||
# Install fd (alternative to find)
|
|
||||||
RUN curl -fsSL https://github.com/sharkdp/fd/releases/download/v9.0.0/fd-v9.0.0-x86_64-unknown-linux-gnu.tar.gz | \
|
|
||||||
tar -xz --strip-components=1 -C /usr/local/bin \
|
|
||||||
&& chmod +x /usr/local/bin/fd
|
|
||||||
|
|
||||||
# Install ripgrep (rg)
|
|
||||||
RUN curl -fsSL https://github.com/BurntSushi/ripgrep/releases/download/14.1.0/ripgrep-14.1.0-x86_64-unknown-linux-musl.tar.gz | \
|
|
||||||
tar -xz --strip-components=1 -C /usr/local/bin \
|
|
||||||
&& chmod +x /usr/local/bin/rg
|
|
||||||
|
|
||||||
# Install Chromium and chromedriver for browser automation
|
|
||||||
# Debian's chromium package is real (not a snap shim like Ubuntu 24.04)
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
chromium \
|
|
||||||
chromium-driver \
|
|
||||||
&& ln -sf /usr/bin/chromium /usr/local/bin/chrome \
|
|
||||||
&& ln -sf /usr/bin/chromedriver /usr/local/bin/chromedriver \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Create non-root user
|
|
||||||
RUN useradd -m -s /bin/bash app
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Copy pre-built binary from host
|
|
||||||
COPY target/release/picobot /app/picobot
|
|
||||||
|
|
||||||
# Copy config template
|
|
||||||
COPY resources/templates/config.example.json /app/config.json.example
|
|
||||||
|
|
||||||
# Create required directories
|
|
||||||
RUN mkdir -p /app/.picobot/workspace /app/.picobot/media /app/.picobot/tmp && \
|
|
||||||
chown -R app:app /app
|
|
||||||
|
|
||||||
USER app
|
|
||||||
ENV HOME=/app
|
|
||||||
|
|
||||||
# Environment variables for Chromium in containers
|
|
||||||
ENV CHROME_BIN=/usr/bin/chromium
|
|
||||||
ENV TMPDIR=/app/.picobot/tmp
|
|
||||||
|
|
||||||
ENTRYPOINT ["/app/picobot"]
|
|
||||||
CMD ["gateway"]
|
|
||||||
|
|
||||||
EXPOSE 19876
|
|
||||||
|
|
||||||
ENV RUST_LOG=info
|
|
||||||
501
README.md
501
README.md
@ -1,102 +1,143 @@
|
|||||||
# PicoBot
|
# PicoBot
|
||||||
|
|
||||||
PicoBot is a Rust-based personal AI assistant runtime. It runs a local gateway, connects chat channels such as the terminal TUI and Feishu/Lark, persists sessions in SQLite, and gives the agent a tool system for files, shell commands, web access, memory, scheduling, skills, MCP tools, and delegated sub-agents.
|
A multi-channel AI agent framework with a WebSocket gateway and TUI client, supporting OpenAI-compatible and Anthropic LLM providers, tool calling, session persistence, and cron-based scheduling.
|
||||||
|
|
||||||
## What It Does
|
## System Architecture
|
||||||
|
|
||||||
- Runs as a gateway server on `127.0.0.1:19876` by default.
|
```mermaid
|
||||||
- Provides a Ratatui terminal client over WebSocket.
|
graph TB
|
||||||
- Supports Feishu/Lark messages, reactions, file upload/download, and media references.
|
subgraph Clients
|
||||||
- Calls OpenAI-compatible providers and Anthropic Messages API providers.
|
TUI["🖥️ CLI Chat (TUI)"]
|
||||||
- Persists conversations, messages, memories, scheduled jobs, LLM call metadata, and background sub-agent tasks in SQLite.
|
FS["📱 Feishu/Lark"]
|
||||||
- Loads skills from workspace, user, and shared skill directories, with built-in skills installed on first use.
|
end
|
||||||
- Compresses long contexts and stores timeline summaries for later recall.
|
|
||||||
- Can register tools discovered from configured MCP servers.
|
|
||||||
|
|
||||||
## Architecture
|
subgraph Gateway["Gateway Server (127.0.0.1:19876)"]
|
||||||
|
HTTP["HTTP Endpoints<br/>GET /health<br/>GET /ws (WebSocket upgrade)"]
|
||||||
|
WS["WebSocket Handler"]
|
||||||
|
CD["ChannelManager"]
|
||||||
|
SP["SessionManager"]
|
||||||
|
AL["AgentLoop"]
|
||||||
|
end
|
||||||
|
|
||||||
```text
|
subgraph Bus["MessageBus"]
|
||||||
Channel -> MessageBus -> SessionManager -> AgentLoop -> LLM Provider
|
IB["Inbound Channel"]
|
||||||
| |
|
OB["Outbound Channel"]
|
||||||
| v
|
CC["Control Channel"]
|
||||||
| Tools
|
end
|
||||||
v
|
|
||||||
SQLite
|
|
||||||
|
|
||||||
Control messages -> SessionManager -> MessageBus -> OutboundDispatcher -> Channel
|
subgraph Storage
|
||||||
|
SQLite[("SQLite<br/>picobot.db")]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph AI["AI Providers"]
|
||||||
|
OpenAI["OpenAI / DashScope"]
|
||||||
|
Anthropic["Anthropic Claude"]
|
||||||
|
end
|
||||||
|
|
||||||
|
TUI <-->|WebSocket| WS
|
||||||
|
FS <-->|Webhook| HTTP
|
||||||
|
|
||||||
|
CD -->|InboundMessage| IB
|
||||||
|
IB -->|DialogEvent| SP
|
||||||
|
CC -->|ControlMessage| SP
|
||||||
|
SP <--> AL
|
||||||
|
AL -->|API Call| OpenAI
|
||||||
|
AL -->|API Call| Anthropic
|
||||||
|
AL -->|Tool Call| Tools
|
||||||
|
SP -->|OutboundMessage| OB
|
||||||
|
OB --> CD
|
||||||
|
SP --> SQLite
|
||||||
|
Tools --> SQLite
|
||||||
|
|
||||||
|
subgraph Tools
|
||||||
|
Bash["Bash"]
|
||||||
|
FileIO["File Read/Write/Edit"]
|
||||||
|
Web["HTTP Request / Web Fetch"]
|
||||||
|
Calc["Calculator"]
|
||||||
|
Skill["Get Skill"]
|
||||||
|
Msg["Send Message"]
|
||||||
|
Cron["Cron Jobs"]
|
||||||
|
end
|
||||||
```
|
```
|
||||||
|
|
||||||
The main runtime boundary is:
|
### Core Data Flow
|
||||||
|
|
||||||
- `channels` only receive and send external messages.
|
```mermaid
|
||||||
- `bus` is an async queue, not a router.
|
sequenceDiagram
|
||||||
- `session` owns dialog lifecycle, persistence, memory recall, prompt assembly, compression, and task cancellation.
|
participant Channel as Channel<br/>(CLI/Feishu)
|
||||||
- `agent` runs the stateless LLM/tool loop.
|
participant Bus as MessageBus
|
||||||
- `providers` are HTTP clients for model APIs.
|
participant SM as SessionManager
|
||||||
- `tools` execute agent actions and return string results.
|
participant AL as AgentLoop
|
||||||
- `storage` owns SQLite schema and CRUD.
|
participant LLM as LLM Provider
|
||||||
- `scheduler` polls due jobs and feeds prompts back into sessions.
|
participant Tool as Tools
|
||||||
|
|
||||||
|
Channel->>Bus: InboundMessage (user input)
|
||||||
|
Bus->>SM: DialogEvent
|
||||||
|
SM->>SM: Load/Resolve Session
|
||||||
|
SM->>AL: Process (session state)
|
||||||
|
AL->>LLM: ChatCompletionRequest
|
||||||
|
LLM-->>AL: response / tool_calls
|
||||||
|
alt Tool Calls
|
||||||
|
AL->>Tool: execute tool
|
||||||
|
Tool-->>AL: result
|
||||||
|
AL->>LLM: continue with tool result
|
||||||
|
end
|
||||||
|
AL-->>SM: AgentProcessResult (text + token count)
|
||||||
|
SM->>SM: Persist to SQLite
|
||||||
|
SM->>Bus: OutboundMessage
|
||||||
|
Bus->>Channel: response to user
|
||||||
|
```
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
### Channels
|
### Multi-Channel Support
|
||||||
|
- **CLI Chat Client** — Full TUI with session management, Markdown rendering, slash commands
|
||||||
|
- **Feishu (Lark)** — Webhook-based integration with typing indicators and media support
|
||||||
|
|
||||||
- `cli_chat`: terminal TUI client connected through `/ws`.
|
### Multi-Provider LLM
|
||||||
- `feishu`: Feishu/Lark channel with configurable allow list, media directory, and reaction emoji.
|
- OpenAI-compatible API (GPT-4, DashScope, Volcengine, etc.)
|
||||||
|
- Anthropic Messages API (Claude)
|
||||||
|
- Cross-provider JSON Schema normalization for tool calling compatibility
|
||||||
|
|
||||||
### LLM Providers
|
### Session Management
|
||||||
|
- Multi-session conversations per channel/chat
|
||||||
|
- Create, switch, rename, archive, delete dialogs via slash commands or WebSocket
|
||||||
|
- SQLite-persisted session history with automatic TTL-based cleanup
|
||||||
|
- Context compression for long conversations approaching token limits
|
||||||
|
|
||||||
- OpenAI-compatible chat completions, including DashScope, Volcengine, and similar APIs.
|
### Tool System
|
||||||
- Anthropic Messages API.
|
| Tool | Description |
|
||||||
- Model-specific `input_type` metadata for text/image capability checks.
|
|------|-------------|
|
||||||
- JSON Schema cleanup for cross-provider tool compatibility.
|
| `bash` | Execute shell commands in workspace |
|
||||||
|
| `file_read` | Read file contents |
|
||||||
|
| `file_write` | Create/overwrite files |
|
||||||
|
| `file_edit` | Precise string substitution in files |
|
||||||
|
| `http_request` | Make HTTP API requests |
|
||||||
|
| `web_fetch` | Fetch and parse web pages |
|
||||||
|
| `calculator` | Evaluate mathematical expressions |
|
||||||
|
| `get_skill` | Load agent skills from local skill files |
|
||||||
|
| `send_message` | Send messages to other channels |
|
||||||
|
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs |
|
||||||
|
|
||||||
### Sessions And Memory
|
### Scheduling
|
||||||
|
- Cron-based recurring jobs with optional timezone support
|
||||||
|
- One-shot (`at`) and interval (`every`) schedules
|
||||||
|
- Jobs trigger agent processing via specified channel/chat
|
||||||
|
|
||||||
- Session IDs use `<channel>:<chat_id>:<dialog_id>`.
|
### Skills System
|
||||||
- Each channel/chat can have multiple dialogs.
|
- Load Markdown skill files from `~/.picobot/skills` and `~/.agents/skills`
|
||||||
- Dialog operations include create, list, switch, rename, delete, compact, dump, info, and stop.
|
- Skills inject specialized system prompts for specific tasks
|
||||||
- Session history is persisted to SQLite and can be incrementally restored after compression.
|
- Automatic hot-reload on file changes
|
||||||
- Knowledge memories are recalled into the system prompt each turn.
|
|
||||||
- Timeline memories are produced by context compression and can be searched later.
|
|
||||||
|
|
||||||
### Tools
|
### Observability
|
||||||
|
- Observer pattern for agent and tool telemetry
|
||||||
Base tools registered for the agent:
|
- Events: `AgentStart`, `AgentEnd`, `ToolCallStart`, `ToolCall`
|
||||||
|
- Structured JSON logging with file rotation
|
||||||
| Tool | Purpose |
|
|
||||||
|------|---------|
|
|
||||||
| `calculator` | Math expressions and statistics |
|
|
||||||
| `file_read` / `file_write` / `file_edit` | Workspace file operations |
|
|
||||||
| `file_search` / `content_search` | File and content search |
|
|
||||||
| `bash` | Run shell commands in the workspace |
|
|
||||||
| `http_request` | HTTP API requests |
|
|
||||||
| `web_fetch` | Fetch and extract web page text |
|
|
||||||
| `get_skill` | List or load local skills |
|
|
||||||
| `memory_store` / `memory_recall` / `timeline_recall` / `memory_forget` | Long-term memory operations |
|
|
||||||
| `delegate` | Run inline, background, or parallel sub-agents |
|
|
||||||
| `send_message` | Send outbound messages to configured channels |
|
|
||||||
| `chat_manager` | Inspect sessions, channels, and stored messages |
|
|
||||||
| `cron_add/list/remove/enable/disable/update` | Manage scheduled jobs when scheduler is enabled |
|
|
||||||
| `browser` | Optional WebDriver browser automation when enabled |
|
|
||||||
| MCP tools | Dynamically registered from configured MCP servers |
|
|
||||||
|
|
||||||
### Skills
|
|
||||||
|
|
||||||
Skills are directories containing `SKILL.md`. Load priority is:
|
|
||||||
|
|
||||||
1. `{workspace}/skills`
|
|
||||||
2. `~/.picobot/skills`
|
|
||||||
3. `~/.agents/skills`
|
|
||||||
|
|
||||||
Same-name skills in higher-priority locations override lower-priority ones. Built-in skills from `resources/skills` are embedded into the binary and installed into `~/.picobot/skills` if missing.
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
- Rust nightly (edition 2024) — use `rustup` to install
|
||||||
- Rust toolchain with edition 2024 support.
|
|
||||||
- A configured LLM provider API key.
|
|
||||||
|
|
||||||
### Build
|
### Build
|
||||||
|
|
||||||
@ -106,9 +147,7 @@ cargo build
|
|||||||
|
|
||||||
### Configure
|
### Configure
|
||||||
|
|
||||||
PicoBot loads `~/.picobot/config.json` first, then falls back to `./config.json`. On gateway startup, a template is released to `~/.picobot/config.example.json` if it does not exist. The source template is [resources/templates/config.example.json](/home/xiaoxixi/code/PicoBot/resources/templates/config.example.json).
|
1. Create `config.json` (or `~/.picobot/config.json`):
|
||||||
|
|
||||||
Minimal example:
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@ -116,16 +155,14 @@ Minimal example:
|
|||||||
"openai": {
|
"openai": {
|
||||||
"type": "openai",
|
"type": "openai",
|
||||||
"base_url": "https://api.openai.com/v1",
|
"base_url": "https://api.openai.com/v1",
|
||||||
"api_key": "<OPENAI_API_KEY>",
|
"api_key": "<OPENAI_API_KEY>"
|
||||||
"extra_headers": {}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"models": {
|
"models": {
|
||||||
"gpt-4o": {
|
"gpt-4o": {
|
||||||
"model_id": "gpt-4o",
|
"model_id": "gpt-4o",
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096
|
||||||
"input_type": ["text", "image"]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"agents": {
|
"agents": {
|
||||||
@ -135,157 +172,251 @@ Minimal example:
|
|||||||
"max_tool_iterations": 99,
|
"max_tool_iterations": 99,
|
||||||
"token_limit": 128000
|
"token_limit": 128000
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"workspace_dir": "~/.picobot/workspace"
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
The `.env` file in the current directory is parsed by PicoBot itself. Values like `<OPENAI_API_KEY>` in JSON are replaced from the process environment after `.env` is loaded.
|
2. Set API keys via `.env` file (one `KEY=VALUE` per line):
|
||||||
|
|
||||||
|
```env
|
||||||
|
OPENAI_API_KEY=sk-xxxxx
|
||||||
|
```
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
|
||||||
|
**Start gateway server:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run -- gateway
|
cargo run -- gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
The gateway switches the process working directory to `workspace_dir` and stores `picobot.db` there by default.
|
Binds `127.0.0.1:19876` by default. Override with `--host` and `--port`.
|
||||||
|
|
||||||
In another terminal:
|
**Connect CLI client:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run -- chat
|
cargo run -- chat
|
||||||
```
|
```
|
||||||
|
|
||||||
The client connects to `ws://127.0.0.1:19876/ws` by default. Override with `--gateway-url`.
|
Connects to `ws://127.0.0.1:19876/ws`. Override with `--gateway-url`.
|
||||||
|
|
||||||
## Configuration
|
## Configuration Reference
|
||||||
|
|
||||||
Top-level config fields:
|
Config load order: `~/.picobot/config.json` → `./config.json` (fallback).
|
||||||
|
|
||||||
| Field | Purpose |
|
### Full Config Structure
|
||||||
|-------|---------|
|
|
||||||
| `providers` | Named LLM provider configs |
|
|
||||||
| `models` | Named model configs |
|
|
||||||
| `agents` | Agent-to-provider/model binding |
|
|
||||||
| `gateway` | Bind address, session DB path, cleanup, scheduler, background task limits |
|
|
||||||
| `client` | Default WebSocket URL for the TUI client |
|
|
||||||
| `channels` | Channel configs, currently Feishu/Lark |
|
|
||||||
| `memory` | Recall and consolidation settings |
|
|
||||||
| `mcp` | MCP server configs |
|
|
||||||
| `browser` | Optional WebDriver browser tool config |
|
|
||||||
| `workspace_dir` | Workspace used for file tools, shell commands, DB default, and workspace skills |
|
|
||||||
|
|
||||||
Important defaults:
|
```mermaid
|
||||||
|
graph LR
|
||||||
|
Config["config.json"]
|
||||||
|
Config --> Providers["providers<br/>ProviderConfig{}"]
|
||||||
|
Config --> Models["models<br/>ModelConfig{}"]
|
||||||
|
Config --> Agents["agents<br/>AgentConfig{}"]
|
||||||
|
Config --> Gateway["gateway<br/>GatewayConfig"]
|
||||||
|
Config --> Client["client<br/>ClientConfig"]
|
||||||
|
Config --> Channels["channels<br/>ChannelConfig{}"]
|
||||||
|
Config --> Workspace["workspace_dir"]
|
||||||
|
|
||||||
| Key | Default |
|
Providers --> PT["type (openai / anthropic)<br/>base_url<br/>api_key<br/>extra_headers"]
|
||||||
|-----|---------|
|
Models --> MT["model_id<br/>temperature<br/>max_tokens"]
|
||||||
| `gateway.host` | `127.0.0.1` |
|
Agents --> AT["provider (ref)<br/>model (ref)<br/>max_tool_iterations<br/>token_limit"]
|
||||||
| `gateway.port` | `19876` |
|
Gateway --> GT["host / port<br/>session_db_path<br/>scheduler"]
|
||||||
| `gateway.max_concurrent_background_tasks` | `10` |
|
Channels --> CT["feishu: app_id, app_secret<br/>allow_from, agent, media_dir"]
|
||||||
| `gateway.scheduler.enabled` | `true` if `scheduler` is omitted and defaulted |
|
```
|
||||||
| `client.gateway_url` | `ws://127.0.0.1:19876/ws` |
|
|
||||||
| `memory.recall_limit` | `5` |
|
|
||||||
| `memory.timeline_retention_days` | `90` |
|
|
||||||
| `mcp.tool_timeout_secs` | `180` |
|
|
||||||
| `browser.enabled` | `false` |
|
|
||||||
|
|
||||||
MCP servers support `stdio`, `sse`, and `streamable-http` transports. Browser automation requires a compatible Chrome/Chromium and chromedriver/WebDriver endpoint.
|
### Environment Variables
|
||||||
|
|
||||||
|
The `.env` file in the working directory is loaded manually (not via dotenv crate). Placeholders in `config.json` written as `<VAR_NAME>` are substituted at load time.
|
||||||
|
|
||||||
|
### Gateway Config
|
||||||
|
|
||||||
|
| Key | Type | Default | Description |
|
||||||
|
|-----|------|---------|-------------|
|
||||||
|
| `host` | string | `127.0.0.1` | Bind address |
|
||||||
|
| `port` | u16 | `19876` | Listen port |
|
||||||
|
| `session_db_path` | string | workspace `picobot.db` | SQLite database path |
|
||||||
|
| `scheduler.enabled` | bool | `false` | Enable cron scheduler |
|
||||||
|
|
||||||
|
### Agent Config
|
||||||
|
|
||||||
|
| Key | Type | Default | Description |
|
||||||
|
|-----|------|---------|-------------|
|
||||||
|
| `provider` | string | — | Provider name (key in `providers`) |
|
||||||
|
| `model` | string | — | Model name (key in `models`) |
|
||||||
|
| `max_tool_iterations` | number | `99` | Max tool call iterations per turn |
|
||||||
|
| `token_limit` | number | `128000` | Context window token limit |
|
||||||
|
|
||||||
## Slash Commands
|
## Slash Commands
|
||||||
|
|
||||||
Available from CLI chat and channel text messages:
|
Available in CLI chat and Feishu:
|
||||||
|
|
||||||
| Command | Description |
|
| Command | Alias | Description |
|
||||||
|---------|-------------|
|
|---------|-------|-------------|
|
||||||
| `/new` | Create a new dialog |
|
| `/new` | `/刷新` | Create a new dialog |
|
||||||
| `/sessions` | List recent dialogs |
|
| `/list` | `/对话列表` | List all dialogs |
|
||||||
| `/switch <dialog_id>` | Switch dialog |
|
| `/switch <id>` | — | Switch to a dialog |
|
||||||
| `/rename <title>` | Rename current dialog |
|
| `/rename <title>` | — | Rename current dialog |
|
||||||
| `/delete` | Delete current dialog |
|
| `/archive` | — | Archive current dialog |
|
||||||
| `/compact` | Manually trigger context compression |
|
| `/delete` | — | Delete current dialog |
|
||||||
| `/info` | Show current dialog information |
|
| `/clear` | `/清空` | Clear current dialog history |
|
||||||
| `/dump` | Save current dialog as Markdown |
|
|
||||||
| `/?`, `/help` | Show help |
|
|
||||||
| `/mcp` | Show MCP server and tool status |
|
|
||||||
| `/stop` | Stop active tasks and clear queued messages |
|
|
||||||
|
|
||||||
## WebSocket API
|
## WebSocket Protocol
|
||||||
|
|
||||||
The gateway exposes:
|
The gateway exposes a WebSocket endpoint at `/ws`. Messages use typed JSON with a `type` discriminator field.
|
||||||
|
|
||||||
|
### Client → Server (WsInbound)
|
||||||
|
|
||||||
|
| Type | Fields |
|
||||||
|
|------|--------|
|
||||||
|
| `user_input` | `content`, `channel?`, `chat_id?`, `sender_id?` |
|
||||||
|
| `create_session` | `title?` |
|
||||||
|
| `list_sessions` | `include_archived` |
|
||||||
|
| `load_session` | `session_id` |
|
||||||
|
| `rename_session` | `session_id?`, `title` |
|
||||||
|
| `archive_session` | `session_id?` |
|
||||||
|
| `delete_session` | `session_id?` |
|
||||||
|
| `clear_history` | `chat_id?`, `session_id?` |
|
||||||
|
| `get_slash_commands` | — |
|
||||||
|
| `ping` | — |
|
||||||
|
|
||||||
|
### Server → Client (WsOutbound)
|
||||||
|
|
||||||
|
| Type | Fields |
|
||||||
|
|------|--------|
|
||||||
|
| `assistant_response` | `session_id`, `response`, `tokens_used?`, `tool_calls?` |
|
||||||
|
| `session_list` | `sessions[]` |
|
||||||
|
| `session_loaded` | `session_id`, `messages[]` |
|
||||||
|
| `session_created` | `session_id`, `title` |
|
||||||
|
| `session_renamed` | `session_id`, `title` |
|
||||||
|
| `session_archived` | `session_id` |
|
||||||
|
| `session_deleted` | `session_id` |
|
||||||
|
| `slash_commands` | `commands[]` |
|
||||||
|
| `error` | `message` |
|
||||||
|
| `pong` | — |
|
||||||
|
|
||||||
|
## HTTP Endpoints
|
||||||
|
|
||||||
| Method | Path | Description |
|
| Method | Path | Description |
|
||||||
|--------|------|-------------|
|
|--------|------|-------------|
|
||||||
| `GET` | `/health` | Returns service health and version |
|
| `GET` | `/health` | Health check — returns `{"status":"ok","version":"x.y.z"}` |
|
||||||
| `GET` | `/ws` | WebSocket upgrade for chat clients |
|
| `GET` | `/ws` | WebSocket upgrade for chat clients |
|
||||||
|
|
||||||
Inbound WebSocket message types:
|
|
||||||
|
|
||||||
| Type | Main fields |
|
|
||||||
|------|-------------|
|
|
||||||
| `user_input` | `content`, optional `channel`, `chat_id`, `sender_id` |
|
|
||||||
| `clear_history` | optional `chat_id`, `session_id` |
|
|
||||||
| `create_session` | optional `title` |
|
|
||||||
| `list_sessions` | `include_archived` |
|
|
||||||
| `load_session` | `session_id` |
|
|
||||||
| `rename_session` | optional `session_id`, `title` |
|
|
||||||
| `archive_session` | optional `session_id` |
|
|
||||||
| `delete_session` | optional `session_id` |
|
|
||||||
| `get_slash_commands` | none |
|
|
||||||
| `ping` | none |
|
|
||||||
|
|
||||||
Outbound WebSocket message types include `assistant_response`, `error`, `session_established`, `session_created`, `session_list`, `session_loaded`, `session_renamed`, `session_archived`, `session_deleted`, `history_cleared`, `slash_commands_list`, `pong`, `command_executed`, and `system_notification`.
|
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Unit tests
|
# Unit tests (no external dependencies)
|
||||||
cargo test --lib
|
cargo test --lib
|
||||||
|
|
||||||
# Integration tests require real API keys in tests/test.env
|
# Integration tests (require API keys)
|
||||||
cp tests/test.env.example tests/test.env
|
cp tests/test.env.example tests/test.env
|
||||||
|
# Fill in your API keys in tests/test.env
|
||||||
cargo test --test test_integration -- --ignored
|
cargo test --test test_integration -- --ignored
|
||||||
cargo test --test test_tool_calling -- --ignored
|
cargo test --test test_tool_calling -- --ignored
|
||||||
cargo test --test test_request_format -- --ignored
|
cargo test --test test_request_format -- --ignored
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
cargo test -- --ignored
|
||||||
```
|
```
|
||||||
|
|
||||||
Integration tests are ignored by default because they make real provider calls.
|
Integration tests are `#[ignore]` by default because they make real API calls.
|
||||||
|
|
||||||
## Project Layout
|
## Project Structure
|
||||||
|
|
||||||
```text
|
```
|
||||||
src/
|
├── src/
|
||||||
agent/ LLM loop, context compression, system prompts, media handling, sub-agents
|
│ ├── main.rs # CLI entrypoint (clap-based subcommands)
|
||||||
bus/ Inbound, outbound, and control message queues
|
│ ├── lib.rs # Module declarations
|
||||||
channels/ CLI chat and Feishu/Lark integrations
|
│ ├── gateway/ # HTTP/WS server, GatewayState initialization
|
||||||
client/ Ratatui terminal UI
|
│ │ ├── mod.rs
|
||||||
config/ Config loading, env substitution, path expansion
|
│ │ ├── http.rs # Health endpoint
|
||||||
gateway/ Axum HTTP/WebSocket server and GatewayState wiring
|
│ │ └── ws.rs # WebSocket handler
|
||||||
mcp/ MCP client connections and tool wrappers
|
│ ├── client/ # TUI chat client
|
||||||
memory/ Memory manager and memory types
|
│ │ ├── mod.rs
|
||||||
observability/ Agent/tool telemetry observer interfaces
|
│ │ └── tui/ # Ratatui-based terminal UI
|
||||||
providers/ OpenAI-compatible and Anthropic clients
|
│ ├── channels/ # Channel integrations
|
||||||
scheduler/ Scheduled job runtime
|
│ │ ├── base.rs # Channel trait
|
||||||
session/ Session lifecycle, dialog commands, persistence integration
|
│ │ ├── cli_chat.rs # CLI WebSocket channel
|
||||||
skills/ Skill loading and embedded built-in skill installation
|
│ │ ├── feishu.rs # Feishu/Lark webhook channel
|
||||||
storage/ SQLite schema and CRUD
|
│ │ ├── manager.rs # ChannelManager
|
||||||
tools/ Agent tool implementations
|
│ │ └── slash_command.rs # Slash command parser
|
||||||
resources/
|
│ ├── bus/ # Async message bus
|
||||||
skills/ Built-in skills embedded at build time
|
│ │ ├── mod.rs # MessageBus (tokio mpsc channels)
|
||||||
templates/ Config, AGENTS.md, and USER.md templates released on first run
|
│ │ ├── message.rs # Message types
|
||||||
tests/ Unit and ignored integration tests
|
│ │ └── dispatcher.rs # OutboundDispatcher
|
||||||
reference/ Third-party reference code; do not modify as project source
|
│ ├── session/ # Session & dialog management
|
||||||
|
│ │ ├── mod.rs
|
||||||
|
│ │ ├── session.rs # Session, SessionManager
|
||||||
|
│ │ ├── session_id.rs # UnifiedSessionId
|
||||||
|
│ │ ├── commands.rs # SessionCommand enum
|
||||||
|
│ │ └── events.rs # SessionEvent, DialogInfo
|
||||||
|
│ ├── agent/ # LLM interaction loop
|
||||||
|
│ │ ├── mod.rs
|
||||||
|
│ │ ├── agent_loop.rs # AgentLoop (stateless)
|
||||||
|
│ │ ├── context_compressor.rs # Token estimation & summarization
|
||||||
|
│ │ └── system_prompt.rs # System prompt builder
|
||||||
|
│ ├── providers/ # LLM API clients
|
||||||
|
│ │ ├── mod.rs # Factory: create_provider()
|
||||||
|
│ │ ├── traits.rs # LLMProvider trait
|
||||||
|
│ │ ├── openai.rs # OpenAI-compatible client
|
||||||
|
│ │ └── anthropic.rs # Anthropic Messages API client
|
||||||
|
│ ├── tools/ # Agent tools
|
||||||
|
│ │ ├── mod.rs # create_default_tools()
|
||||||
|
│ │ ├── registry.rs # ToolRegistry
|
||||||
|
│ │ ├── traits.rs # Tool trait, ToolResult
|
||||||
|
│ │ ├── schema.rs # Cross-provider JSON Schema cleaner
|
||||||
|
│ │ ├── bash.rs # Shell command execution
|
||||||
|
│ │ ├── calculator.rs # Math expression evaluator
|
||||||
|
│ │ ├── chat_manager.rs # Session management tool
|
||||||
|
│ │ ├── cron.rs # Cron job management tools
|
||||||
|
│ │ ├── file_read.rs # File reader
|
||||||
|
│ │ ├── file_write.rs # File writer
|
||||||
|
│ │ ├── file_edit.rs # File editor (string substitution)
|
||||||
|
│ │ ├── get_skill.rs # Skill loader tool
|
||||||
|
│ │ ├── http_request.rs # HTTP request tool
|
||||||
|
│ │ ├── send_message.rs # Cross-channel messaging
|
||||||
|
│ │ └── web_fetch.rs # Web page fetcher
|
||||||
|
│ ├── skills/ # Skills loading from markdown files
|
||||||
|
│ │ └── mod.rs # SkillsLoader, Skill
|
||||||
|
│ ├── storage/ # SQLite persistence
|
||||||
|
│ │ ├── mod.rs # Storage, schema init
|
||||||
|
│ │ ├── session.rs # Session CRUD operations
|
||||||
|
│ │ ├── message.rs # Message persistence
|
||||||
|
│ │ ├── scheduler.rs # ScheduledJob, JobRun storage
|
||||||
|
│ │ └── error.rs # StorageError
|
||||||
|
│ ├── scheduler/ # Cron scheduler runtime
|
||||||
|
│ │ ├── mod.rs # Scheduler, next_run_for_schedule()
|
||||||
|
│ │ └── types.rs # Schedule enum (At/Every/Cron)
|
||||||
|
│ ├── observability/ # Telemetry observer pattern
|
||||||
|
│ │ └── mod.rs # Observer trait, ObserverEvent, MultiObserver
|
||||||
|
│ ├── protocol.rs # WebSocket message types (WsInbound/WsOutbound)
|
||||||
|
│ ├── config/ # Config loading & env substitution
|
||||||
|
│ │ └── mod.rs # Config, LLMProviderConfig, load_env_file()
|
||||||
|
│ └── logging.rs # Tracing subscriber init with file rotation
|
||||||
|
├── tests/
|
||||||
|
│ ├── test_integration.rs # LLM provider integration tests
|
||||||
|
│ ├── test_tool_calling.rs # Tool calling integration tests
|
||||||
|
│ ├── test_request_format.rs # Request format tests
|
||||||
|
│ ├── test_scheduler.rs # Scheduler unit tests
|
||||||
|
│ ├── test.env.example # Test environment template
|
||||||
|
│ └── test.env # Actual test keys (gitignored)
|
||||||
|
├── reference/ # Third-party reference code (do not modify)
|
||||||
|
├── resources/ # Assets embedded in binary
|
||||||
|
│ └── templates/ # Templates released to ~/.picobot/ on first run
|
||||||
|
├── config.example.json # Full config example
|
||||||
|
└── Cargo.toml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Key Dependencies
|
## Key Dependencies
|
||||||
|
|
||||||
| Crate | Purpose |
|
| Crate | Purpose |
|
||||||
|-------|---------|
|
|-------|---------|
|
||||||
| `axum`, `tokio`, `tokio-tungstenite` | Gateway and WebSocket runtime |
|
| `axum` + `tokio-tungstenite` | HTTP server & WebSocket |
|
||||||
| `sqlx` | SQLite persistence |
|
| `sqlx` (SQLite) | Session/Message/Job persistence |
|
||||||
| `reqwest` | LLM and HTTP clients |
|
| `reqwest` (rustls) | LLM API & external HTTP calls |
|
||||||
| `ratatui`, `crossterm`, `termimad` | Terminal UI |
|
| `ratatui` + `crossterm` | Terminal UI |
|
||||||
| `rmcp` | MCP client support |
|
| `clap` | CLI argument parsing |
|
||||||
| `fantoccini` | Optional browser automation |
|
| `tracing` + `tracing-subscriber` | Structured logging |
|
||||||
| `cron`, `chrono-tz` | Scheduling |
|
| `cron` + `chrono-tz` | Cron schedule parsing |
|
||||||
| `jieba-rs` | Chinese tokenization for memory search |
|
| `meval` | Mathematical expression evaluation |
|
||||||
| `zstd`, `tar` | Embedded built-in skill packaging |
|
| `uuid` | Session/Dialog ID generation |
|
||||||
|
| `dirs` | Platform config directory resolution |
|
||||||
|
|||||||
@ -1,16 +0,0 @@
|
|||||||
services:
|
|
||||||
picobot:
|
|
||||||
image: picobot:latest
|
|
||||||
container_name: picobot
|
|
||||||
restart: unless-stopped
|
|
||||||
ports:
|
|
||||||
- "19876:19876"
|
|
||||||
volumes:
|
|
||||||
- ~/.picobot/config.json:/app/.picobot/config.json:ro
|
|
||||||
- picobot_data:/app/.picobot
|
|
||||||
environment:
|
|
||||||
- RUST_LOG=info
|
|
||||||
command: gateway
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
picobot_data:
|
|
||||||
@ -1,359 +0,0 @@
|
|||||||
# PicoBot 代码质量分析报告
|
|
||||||
|
|
||||||
审查日期:2026-06-15
|
|
||||||
|
|
||||||
## 结论摘要
|
|
||||||
|
|
||||||
PicoBot 的总体架构方向是清晰的:Gateway 负责装配,Channel 只做收发,MessageBus 解耦输入输出,SessionManager 管理会话,AgentLoop 保持无状态并执行工具,Storage 统一持久化。这条主线是成立的,也已经具备较完整的 AI 助手运行时能力。
|
|
||||||
|
|
||||||
当前主要质量风险集中在三类:
|
|
||||||
|
|
||||||
1. 会话/CLI 路由语义不一致,导致多客户端隔离、加载会话、当前会话追踪不可靠。
|
|
||||||
2. 若干公开控制接口是空实现或弱实现,协议层暴露的能力和后端实际行为不匹配。
|
|
||||||
3. 工具和后台任务的资源边界偏弱,文件、shell、HTTP、长期任务在异常情况下容易突破预期的安全或稳定性边界。
|
|
||||||
|
|
||||||
如果只安排一轮修复,优先处理会话路由和控制接口。这些问题会直接影响用户看到的行为;工具安全和大模块拆分可以作为第二阶段。
|
|
||||||
|
|
||||||
## 修复状态
|
|
||||||
|
|
||||||
- 已修复:CLI 会话路由现在按每个 WebSocket client 的稳定 `chat_id` 隔离,普通输入、创建、列表、加载和 outbound 投递不再混用完整 `session_id` 与 `chat_id`。
|
|
||||||
- 已修复:Dialog 控制接口已补齐当前会话查询、列表 current 标记、归档、清空历史和 `/delete` 删除当前会话后新建的行为;`include_archived` 现在由 Storage 查询生效。
|
|
||||||
- 待处理:工具文件边界、Session 锁粒度、Bash 超时进程清理等仍是后续质量风险。
|
|
||||||
|
|
||||||
## 主要发现
|
|
||||||
|
|
||||||
### 已修复:CLI 会话路由会破坏会话连续性和多客户端隔离
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/channels/cli_chat.rs:113-126`
|
|
||||||
- `src/channels/cli_chat.rs:160-164`
|
|
||||||
- `src/channels/cli_chat.rs:225-249`
|
|
||||||
- `src/channels/cli_chat.rs:479-494`
|
|
||||||
- `src/session/session.rs:1305-1310`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
`Client.current_session_id` 存的是完整 session id,但 CLI channel 在多个地方把它当作 `chat_id` 使用。普通用户输入如果没有显式传 `chat_id`,会在 `src/channels/cli_chat.rs:119` 生成新的短 ID,而不是复用当前 client 的 chat scope。`CreateSession` 又把当前完整 session id 当成新会话的 chat_id。`LoadSession` 解析了传入 session id,但随后调用 `GetCurrentDialog`,而后端 `get_current_dialog()` 固定返回 `None`。
|
|
||||||
|
|
||||||
同时,`send()` 会把所有 `OutboundMessage` 广播给所有 CLI WebSocket client,没有按 `msg.chat_id` 或 client 当前会话过滤。这意味着一个客户端的回复可能出现在另一个客户端里。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
- CLI 多轮对话可能落入不同 chat scope。
|
|
||||||
- 创建/列出/加载会话得到的结果可能不符合 UI 预期。
|
|
||||||
- 多个 CLI 客户端同时连接时存在串话。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 将 client 状态拆成 `chat_id` 和 `current_session_id`,不要混用。
|
|
||||||
- 注册 client 时生成稳定 `chat_id`,后续 `UserInput` 默认复用它。
|
|
||||||
- `send()` 按 `OutboundMessage.chat_id` 精确投递;必要时维护 `chat_id -> clients` 映射。
|
|
||||||
- `LoadSession` 应直接切换到指定 session,或通过 `SwitchDialog` 使用其中的 `dialog_id`。
|
|
||||||
- 为 CLI WebSocket 增加多客户端路由测试。
|
|
||||||
|
|
||||||
### 已修复:Dialog 控制接口与协议承诺不一致
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/session/session.rs:996-997`
|
|
||||||
- `src/session/session.rs:1305-1310`
|
|
||||||
- `src/session/session.rs:1329-1349`
|
|
||||||
- `src/session/session.rs:1378-1384`
|
|
||||||
- `src/channels/cli_chat.rs:128-158`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
后端暴露了 create/list/load/rename/archive/delete/clear 等 dialog 操作,但部分行为是空实现或语义错位:
|
|
||||||
|
|
||||||
- `/delete` 只创建新 session,并没有删除当前 session。
|
|
||||||
- `get_current_dialog()` 固定返回 `Ok(None)`。
|
|
||||||
- `list_dialogs()` 忽略 `include_archived`,且总是返回 `current_dialog_id = None`。
|
|
||||||
- `archive_dialog()` 是空操作。
|
|
||||||
- `clear_dialog_history()` 直接返回不可用,但 WebSocket 协议仍暴露 `clear_history`。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
用户通过 slash command 和 WebSocket 调用同一类能力时,会得到不一致结果。前端难以基于协议实现可靠状态同步。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 明确“archive/clear 是否支持”。不支持就从协议和命令列表移除;支持就实现到底。
|
|
||||||
- `/delete` 应调用 `delete_dialog(current_session_id)`,再创建一个新的 current session。
|
|
||||||
- `get_current_dialog()` 应读取 `current_sessions[channel:chat_id]` 并解析为 `UnifiedSessionId`。
|
|
||||||
- `list_dialogs()` 返回真实 current dialog,并补上 archived 模型或移除 archived 参数。
|
|
||||||
|
|
||||||
### 高优先级:工具文件边界不符合“工作目录内工具”的架构约束
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/tools/mod.rs:56-62`
|
|
||||||
- `src/tools/path_utils.rs:3-23`
|
|
||||||
- `src/tools/bash.rs:146-185`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
文件工具默认通过 `FileReadTool::new()`、`FileWriteTool::new()` 等注册,没有传入 workspace allowlist。`resolve_path()` 对绝对路径直接放行;即使传入 allowlist,也只是做 `Path::starts_with()` 的词法判断,没有 canonicalize,不能防御 `..`、符号链接等路径逃逸。
|
|
||||||
|
|
||||||
`bash` 默认工作目录是 `"."`,Gateway 启动时切到 workspace,这对相对路径有效,但 shell 命令仍然可以访问绝对路径。当前 denylist 只挡少数危险模式,不构成权限边界。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
Agent 工具实际可以读写 workspace 外文件,和文档/架构里的“工作目录内操作”不一致。对于个人助手这可能是有意设计,但如果未来接入外部渠道、多用户或 MCP,风险会放大。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 工具注册时传入 `workspace_dir`,默认所有文件工具限制在 workspace。
|
|
||||||
- `resolve_path()` 使用 `std::fs::canonicalize` 或 `path_absolutize` 风格逻辑,并处理目标文件不存在时的父目录 canonicalize。
|
|
||||||
- 写工具禁止跟随危险符号链接,或至少在文档中明确该能力是全文件系统权限。
|
|
||||||
- shell 工具如果保留,应在配置中显式开关,并区分本地可信模式和渠道暴露模式。
|
|
||||||
|
|
||||||
### 中高优先级:Session 锁内执行过多异步操作
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/session/session.rs:1001-1018`
|
|
||||||
- `src/session/session.rs:1604-1711`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
`/compact` 在持有 session mutex 时执行压缩和持久化。agent worker 的 Phase 1 也在持有 session mutex 时执行用户消息落库、memory recall、上下文压缩、session meta 持久化和 agent 创建。其中 `compress_if_needed()` 可能触发 LLM 摘要,属于慢操作。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
- 同一 session 的 slash command、stop、消息排队、状态查询会被慢操作阻塞。
|
|
||||||
- 当压缩或存储出现抖动时,用户感觉像“卡死”。
|
|
||||||
- 后续如果在这些慢操作里间接需要 session 状态,容易形成锁顺序问题。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 锁内只做内存状态快照和必要的状态标记。
|
|
||||||
- 将 memory recall、压缩、LLM 摘要放到锁外执行。
|
|
||||||
- 锁外完成后重新加锁提交结果,并用 generation/version 检测期间是否被 `/stop` 或新任务替换。
|
|
||||||
|
|
||||||
### 中优先级:Bash 超时不会显式终止子进程
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/tools/bash.rs:150-174`
|
|
||||||
- `src/tools/bash.rs:180-207`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
`timeout()` 包裹的是 `run_command()` future。超时后 future 被取消,但代码没有持有 child 句柄并显式 `kill()` / `wait()`。对于已经启动的长运行命令或子进程树,可能留下后台进程。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
长任务、服务进程或卡住的 shell 命令会泄漏进程和资源,后续工具调用的行为也会变得不可预测。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 使用 `tokio::process::Child` 的 `kill_on_drop(true)`。
|
|
||||||
- 超时分支显式 kill child 并 wait。
|
|
||||||
- 对 shell 子进程树使用进程组隔离,必要时杀整个进程组。
|
|
||||||
- 对需要持久进程的场景使用 PTY 工具,不混用 bash 的一次性语义。
|
|
||||||
|
|
||||||
### 中优先级:文件读取对大二进制文件没有输出上限
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/tools/file_read.rs:121-131`
|
|
||||||
- `src/tools/file_read.rs:214-229`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
`file_read` 先 `std::fs::read()` 读取整个文件。文本路径有 `MAX_CHARS` 截断,但二进制路径会完整 base64 编码后返回,没有大小限制。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
读取大文件会造成内存膨胀、响应膨胀、上下文污染,甚至拖垮进程。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 先检查 metadata size,超过阈值直接返回提示。
|
|
||||||
- 二进制文件默认只返回 mime、大小和建议操作;需要内容时提供显式 `max_bytes` 参数。
|
|
||||||
- 对文本读取也改成流式按行读取,而不是整文件读入。
|
|
||||||
|
|
||||||
### 中优先级:HTTP 私网防护只检查字面 host,未做 DNS 解析校验
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/tools/http_request.rs:31-59`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
`http_request` 阻止 localhost、私网 IP 字面量和 `.local`,但普通域名不会解析后检查最终 IP。DNS rebinding 或内网域名解析到私网地址时,当前校验拦不住。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
如果该工具暴露给非完全可信输入,存在 SSRF 风险。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 请求前解析域名,拒绝私网、loopback、link-local、multicast、unspecified 地址。
|
|
||||||
- 禁止或限制重定向,重定向后的每个 URL 重新校验。
|
|
||||||
- 对 `http_request` 和 `web_fetch` 复用同一套 URL 安全策略。
|
|
||||||
|
|
||||||
### 中优先级:后台任务和主循环缺少监督与优雅关闭
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/bus/mod.rs:51-99`
|
|
||||||
- `src/gateway/mod.rs:187-244`
|
|
||||||
- `src/gateway/mod.rs:247-266`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
Gateway 中多个长期任务通过 `tokio::spawn` 启动后没有保存 JoinHandle,也没有统一 cancellation token。MessageBus 的 `consume_*()` 在 channel 关闭时使用 `expect()` panic。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
- 某个后台 loop 异常退出后,Gateway 不一定能发现。
|
|
||||||
- 关闭流程只能 stop channel,无法系统性停止 scheduler、dispatcher、agent workers、notification publishers。
|
|
||||||
- bus channel 关闭时更像崩溃,而不是可恢复状态。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 引入 runtime supervisor,保存 JoinHandle 并集中处理退出原因。
|
|
||||||
- 用 `CancellationToken` 贯穿 Gateway 子任务。
|
|
||||||
- `consume_*()` 返回 `Result<Option<T>>`,由调用方决定退出或重启。
|
|
||||||
|
|
||||||
### 中低优先级:Cron 计算函数没有按入参 `from` 计算 cron 下一次时间
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/scheduler/mod.rs:18-40`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
`next_run_for_schedule(schedule, from)` 的注释说基于 `from` 计算,但 cron 分支创建了 `from_dt` 后没有传给 `cron_schedule`,实际使用的是 `upcoming(Utc)` 或 `upcoming(tz)` 的当前时间。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
单元测试或补偿调度传入历史/未来时间时,结果不符合函数契约。线上 reschedule 当前使用 now,影响较小,但函数语义是错的。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 使用 `cron_schedule.after(&from_dt).next()` 或等价 API。
|
|
||||||
- timezone 分支用 `from_dt.with_timezone(&tz)` 作为 after 起点。
|
|
||||||
- 增加固定时间输入的单元测试,避免受系统时间影响。
|
|
||||||
|
|
||||||
### 中低优先级:存在未接入或半接入代码,增加维护噪音
|
|
||||||
|
|
||||||
位置:
|
|
||||||
|
|
||||||
- `src/tools/pty.rs`
|
|
||||||
- `src/tools/mod.rs:1-20`
|
|
||||||
- `src/tools/mod.rs:49-88`
|
|
||||||
|
|
||||||
问题:
|
|
||||||
|
|
||||||
仓库里有完整 `pty.rs`,但 `tools/mod.rs` 没有声明 `pub mod pty`,`create_default_tools()` 也没有注册 PTY 工具。类似情况会让文档、计划和实现状态难以判断。
|
|
||||||
|
|
||||||
影响:
|
|
||||||
|
|
||||||
维护者会误以为功能已上线。未来改动容易遗漏测试和注册路径。
|
|
||||||
|
|
||||||
建议:
|
|
||||||
|
|
||||||
- 若 PTY 是要发布的功能:接入模块导出、注册、配置开关、测试和文档。
|
|
||||||
- 若暂不发布:移动到设计文档或 feature branch,避免主干保留死代码。
|
|
||||||
|
|
||||||
## 架构评价
|
|
||||||
|
|
||||||
### 做得好的地方
|
|
||||||
|
|
||||||
- 模块分层方向清楚:Channel、Bus、Session、Agent、Provider、Tool、Storage 边界基本可理解。
|
|
||||||
- AgentLoop 设计为无状态,历史由 SessionManager 管理,这一点利于恢复、压缩和测试。
|
|
||||||
- Provider 抽象简单直接,OpenAI-compatible 与 Anthropic 的差异被限制在 provider 层。
|
|
||||||
- Storage 集中初始化 schema,便于部署单二进制应用。
|
|
||||||
- Skill、memory、MCP、delegate 这几条扩展线已经形成统一的 ToolRegistry 接入点。
|
|
||||||
|
|
||||||
### 主要架构债务
|
|
||||||
|
|
||||||
- SessionManager 承担过多职责:会话生命周期、命令解析、memory recall、压缩、agent worker、任务取消、send_message 目标解析都在一个 2000 行文件内。
|
|
||||||
- Channel 和 Session 对 chat_id/session_id/dialog_id 的边界没有类型保护,导致 CLI 层混用字符串。
|
|
||||||
- Tool 权限模型不够显式:工具是否能访问全文件系统、是否能联网、是否能修改状态主要靠工具自身约定。
|
|
||||||
- 后台任务生命周期分散:gateway loop、agent worker、notification publisher、scheduler、sub-agent task 各自 spawn,缺少统一管理。
|
|
||||||
|
|
||||||
## 模块级分析
|
|
||||||
|
|
||||||
### gateway
|
|
||||||
|
|
||||||
`GatewayState::new()` 是清晰的装配中心:配置、workspace、storage、memory、bus、session manager、channels、MCP、scheduler 都在这里接线。问题是启动后任务监督不足,且 scheduler 默认 `unwrap_or_default()` 会在省略 `gateway.scheduler` 时启用调度器,这和“省略配置是否代表开启”需要产品层确认。
|
|
||||||
|
|
||||||
### channels
|
|
||||||
|
|
||||||
Feishu channel 功能较厚,单文件接近 2000 行,建议后续按 API client、message parsing、media handling、outbound rendering 拆分。CLI channel 目前是质量风险最高的 channel,核心问题是会话身份混用和广播投递。
|
|
||||||
|
|
||||||
### bus
|
|
||||||
|
|
||||||
MessageBus 简洁,但当前消费者 API 通过 mutex 包住 receiver 并 `expect()`,更像“单消费者内部队列”。这没问题,但应该把“只能有一个 consumer”写进类型/文档,并把关闭作为正常状态处理。
|
|
||||||
|
|
||||||
### session
|
|
||||||
|
|
||||||
这是系统核心,也是债务最集中的模块。建议把 `session.rs` 拆成:
|
|
||||||
|
|
||||||
- `manager.rs`:SessionManager 状态和 dialog 生命周期
|
|
||||||
- `worker.rs`:per-session agent worker 和 cancellation
|
|
||||||
- `commands.rs`:slash command 执行
|
|
||||||
- `outbound.rs`:OutboundMessenger 实现
|
|
||||||
- `restore.rs`:storage 恢复与 tool call chain repair
|
|
||||||
|
|
||||||
拆分之前,先补行为测试,尤其是 CLI/WS session lifecycle。
|
|
||||||
|
|
||||||
### agent
|
|
||||||
|
|
||||||
AgentLoop 的职责相对聚焦:请求模型、执行工具、回填 tool result、循环直到 final response。需要关注的是工具并发的语义:`read_only()` 目前是工具自己声明,副作用工具不能错标。LoopDetector 有帮助,但属于 runtime guard,不应替代工具层的资源限制。
|
|
||||||
|
|
||||||
### providers
|
|
||||||
|
|
||||||
Provider 层整体可维护。OpenAI/Anthropic 的请求构造逻辑可以继续保留在 provider 内。建议补充请求脱敏策略:当前 debug log 和 `llm_calls` 会持久化完整 request/response,可能包含用户隐私、API 返回内容和文件内容。
|
|
||||||
|
|
||||||
### tools
|
|
||||||
|
|
||||||
工具体系覆盖面很强,但需要明确权限模型。建议新增统一的 `ToolExecutionContext`,包含 workspace、channel、session_id、权限策略、网络策略、输出预算。现在很多策略散落在各工具构造函数里,默认值容易失控。
|
|
||||||
|
|
||||||
### storage
|
|
||||||
|
|
||||||
Storage schema 初始化实用,但迁移方式是“CREATE IF NOT EXISTS + ALTER IGNORE”,适合早期迭代,不适合长期演进。建议引入 schema version 表或 sqlx migrations,至少把每次迁移记录下来。
|
|
||||||
|
|
||||||
### skills
|
|
||||||
|
|
||||||
Skill 加载优先级清晰,内置 skill 打包也实用。需要注意 `SkillsLoader` 使用同步文件系统扫描和 `std::sync::Mutex`,在请求路径频繁 `reload_if_changed()` 时可能造成阻塞。短期可以接受,长期建议缓存刷新放到后台 watcher。
|
|
||||||
|
|
||||||
## 建议修复路线
|
|
||||||
|
|
||||||
### P0:先修会话正确性
|
|
||||||
|
|
||||||
1. 修正 CLI `chat_id/current_session_id` 数据模型。
|
|
||||||
2. 修正 CLI 出站按 client/chat_id 投递。
|
|
||||||
3. 实现 `get_current_dialog()`、`list_dialogs()` current 返回。
|
|
||||||
4. 修正 `/delete`、`clear_history`、`archive` 的真实行为或从协议移除。
|
|
||||||
5. 增加 WebSocket session lifecycle 测试。
|
|
||||||
|
|
||||||
### P1:收紧工具和资源边界
|
|
||||||
|
|
||||||
1. 文件工具默认限制 workspace,路径 canonicalize。
|
|
||||||
2. bash 超时杀进程,必要时引入进程组。
|
|
||||||
3. file_read 增加文件大小上限和二进制输出上限。
|
|
||||||
4. HTTP/web 工具增加 DNS 解析后的私网校验和重定向校验。
|
|
||||||
5. 明确高危工具的配置开关。
|
|
||||||
|
|
||||||
### P2:降低架构复杂度
|
|
||||||
|
|
||||||
1. 拆分 `session.rs`、`feishu.rs`、`storage/mod.rs`、`browser.rs`。
|
|
||||||
2. 引入任务 supervisor 和统一 shutdown token。
|
|
||||||
3. 引入正式数据库迁移。
|
|
||||||
4. 增加工具注册快照测试,避免死代码和文档漂移。
|
|
||||||
|
|
||||||
## 建议测试补充
|
|
||||||
|
|
||||||
- CLI 多客户端并发:两个 WebSocket client 同时发消息,互不串话。
|
|
||||||
- CLI 不传 chat_id 的连续对话:所有消息应进入同一 session。
|
|
||||||
- Load/switch/list/delete/clear 的完整 WebSocket 流程。
|
|
||||||
- `/delete` 后旧 session 软删除、新 session 成为 current。
|
|
||||||
- 文件路径逃逸:`../`、绝对路径、符号链接、workspace 前缀欺骗。
|
|
||||||
- bash timeout 后检查子进程不存在。
|
|
||||||
- cron `next_run_for_schedule()` 使用固定 `from` 的 deterministic 测试。
|
|
||||||
- HTTP 工具对 DNS 解析到 `127.0.0.1` / `10.0.0.0/8` 的域名拒绝测试。
|
|
||||||
40
docs/plans/2026-04-26-client-refactor-design.md
Normal file
40
docs/plans/2026-04-26-client-refactor-design.md
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# 客户端代码整合设计
|
||||||
|
|
||||||
|
## 目标
|
||||||
|
|
||||||
|
将分散在 `src/cli/` 和 `src/client/` 的客户端代码整合到 `src/client/` 目录。
|
||||||
|
|
||||||
|
## 变更
|
||||||
|
|
||||||
|
### 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── client/ # 整合后的客户端模块
|
||||||
|
│ ├── mod.rs # 主程序入口 (run 函数)
|
||||||
|
│ ├── input.rs # InputHandler + InputCommand (从 cli/input.rs 合并)
|
||||||
|
│ └── channel.rs # CliChannel (从 cli/channel.rs 合并)
|
||||||
|
├── cli/ # 删除
|
||||||
|
└── protocol.rs # 保留
|
||||||
|
```
|
||||||
|
|
||||||
|
### 关键变更
|
||||||
|
|
||||||
|
| 变更 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `InputEvent::Message(String)` | 简化为只携带文本内容,不再使用 `ChatMessage` |
|
||||||
|
| `cli` 模块删除 | 代码合并到 `client` |
|
||||||
|
| 解耦 | `client` 不再依赖 `bus::ChatMessage` |
|
||||||
|
|
||||||
|
## 实施步骤
|
||||||
|
|
||||||
|
1. 创建 `src/client/input.rs` - 从 `cli/input.rs` 合并,修改 `InputEvent::Message` 为 `String`
|
||||||
|
2. 创建 `src/client/channel.rs` - 从 `cli/channel.rs` 直接复制
|
||||||
|
3. 更新 `src/client/mod.rs` - 更新 import
|
||||||
|
4. 更新 `src/lib.rs` - 删除 `pub mod cli;`
|
||||||
|
5. 删除 `src/cli/` 目录
|
||||||
|
|
||||||
|
## 验证
|
||||||
|
|
||||||
|
- `cargo build` 通过
|
||||||
|
- 功能保持不变
|
||||||
877
docs/plans/2026-04-28-phase1-storage-implementation.md
Normal file
877
docs/plans/2026-04-28-phase1-storage-implementation.md
Normal file
@ -0,0 +1,877 @@
|
|||||||
|
# Phase 1: Storage 基础 实现计划
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** 创建 `src/storage/` 模块,实现 SQLite 持久化层,为后续 Session 扩展提供 Storage 基础设施。
|
||||||
|
|
||||||
|
**Architecture:** 使用 `sqlx` + `sqlite`,通过 `SqlitePool` 实现异步连接池,所有 Storage 操作均为 async,在 `Storage` 内部管理连接池的生命周期。
|
||||||
|
|
||||||
|
**Tech Stack:** `sqlx` (sqlite, tokio), `serde`, `chrono` (时间戳), `tokio::time::sleep` (重试退避)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 1: 添加依赖
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `Cargo.toml:36` (在 `[dependencies]` 末尾添加)
|
||||||
|
|
||||||
|
**Step 1: 添加 sqlx + sqlite 依赖**
|
||||||
|
|
||||||
|
在 `Cargo.toml` 末尾添加:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
sqlx = { version = "0.8", features = ["sqlite", "tokio", "macros", "chrono"] }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 运行 cargo check 验证依赖**
|
||||||
|
|
||||||
|
Run: `cargo check 2>&1`
|
||||||
|
Expected: 无报错,依赖解析成功
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add Cargo.toml
|
||||||
|
git commit -m "deps: 添加 sqlx + sqlite 依赖"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 2: 创建 Storage Error 类型
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `src/storage/error.rs`
|
||||||
|
|
||||||
|
**Step 1: 编写 StorageError 枚举和测试**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum StorageError {
|
||||||
|
#[error("session not found: {0}")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
#[error("session already exists: {0}")]
|
||||||
|
AlreadyExists(String),
|
||||||
|
|
||||||
|
#[error("database error: {0}")]
|
||||||
|
Database(#[from] sqlx::Error),
|
||||||
|
|
||||||
|
#[error("serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1 | head -30`
|
||||||
|
Expected: 报错 "cannot find module `storage`"(因为模块未创建),这是预期的
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/error.rs
|
||||||
|
git commit -m "feat(storage): 添加 StorageError 类型"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 3: 创建 Storage 模块骨架
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `src/storage/mod.rs`
|
||||||
|
- Create: `src/storage/session.rs`
|
||||||
|
- Create: `src/storage/message.rs`
|
||||||
|
|
||||||
|
**Step 1: 创建 `src/storage/mod.rs`**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub mod error;
|
||||||
|
pub mod session;
|
||||||
|
pub mod message;
|
||||||
|
|
||||||
|
pub use error::StorageError;
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 创建 `src/storage/session.rs`(空壳)**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Session CRUD 操作占位符
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 创建 `src/storage/message.rs`(空壳)**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Message CRUD 操作占位符
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 在 `src/lib.rs` 中添加 storage 模块**
|
||||||
|
|
||||||
|
在 `src/lib.rs` 末尾添加:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub mod storage;
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1`
|
||||||
|
Expected: 编译成功(空壳模块)
|
||||||
|
|
||||||
|
**Step 6: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/ src/lib.rs
|
||||||
|
git commit -m "feat(storage): 创建 storage 模块骨架"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 4: 实现 Storage 主结构
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/mod.rs`
|
||||||
|
|
||||||
|
**Step 1: 编写 Storage 结构和初始化逻辑**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use sqlx::{Pool, Sqlite, SqlitePool};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
pub struct Storage {
|
||||||
|
pool: Pool<Sqlite>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
/// 打开或创建数据库
|
||||||
|
pub async fn new(db_path: &Path) -> Result<Self, StorageError> {
|
||||||
|
let database_url = format!("sqlite:{}?mode=rwc", db_path.display());
|
||||||
|
let pool = SqlitePool::connect(&database_url).await?;
|
||||||
|
|
||||||
|
let storage = Self { pool };
|
||||||
|
storage.init_schema().await?;
|
||||||
|
Ok(storage)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 初始化数据库 schema
|
||||||
|
async fn init_schema(&self) -> Result<(), StorageError> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
channel TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
dialog_id TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL DEFAULT '新对话',
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
last_active_at INTEGER NOT NULL,
|
||||||
|
message_count INTEGER DEFAULT 0,
|
||||||
|
routing_info TEXT,
|
||||||
|
deleted_at INTEGER,
|
||||||
|
UNIQUE(channel, chat_id, dialog_id)
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sessions_chat
|
||||||
|
ON sessions(channel, chat_id, deleted_at)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
seq INTEGER NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
media_refs TEXT,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tool_calls TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
|
||||||
|
ON messages(session_id, seq)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取连接池引用(供内部 CRUD 使用)
|
||||||
|
pub(crate) fn pool(&self) -> &Pool<Sqlite> {
|
||||||
|
&self.pool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1`
|
||||||
|
Expected: 编译成功
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/mod.rs
|
||||||
|
git commit -m "feat(storage): 实现 Storage 主结构和初始化"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 5: 定义 SessionMeta 和 MessageMeta 数据结构
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/session.rs`
|
||||||
|
- Modify: `src/storage/message.rs`
|
||||||
|
|
||||||
|
**Step 1: 在 `session.rs` 中定义 SessionMeta**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SessionMeta {
|
||||||
|
pub id: String,
|
||||||
|
pub channel: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub dialog_id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub created_at: i64,
|
||||||
|
pub last_active_at: i64,
|
||||||
|
pub message_count: i64,
|
||||||
|
pub routing_info: Option<String>,
|
||||||
|
pub deleted_at: Option<i64>,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 在 `message.rs` 中定义 MessageMeta**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MessageMeta {
|
||||||
|
pub id: String,
|
||||||
|
pub session_id: String,
|
||||||
|
pub seq: i64,
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
pub media_refs: Option<String>,
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
pub tool_name: Option<String>,
|
||||||
|
pub tool_calls: Option<String>,
|
||||||
|
pub created_at: i64,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1`
|
||||||
|
Expected: 编译成功
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/session.rs src/storage/message.rs
|
||||||
|
git commit -m "feat(storage): 定义 SessionMeta 和 MessageMeta 数据结构"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 6: 实现 Session CRUD 操作
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/session.rs`
|
||||||
|
|
||||||
|
**Step 1: 编写 upsert_session**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use sqlx::Row;
|
||||||
|
use super::SessionMeta;
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
pub async fn upsert_session(&self, meta: &SessionMeta) -> Result<(), StorageError> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
title = excluded.title,
|
||||||
|
last_active_at = excluded.last_active_at,
|
||||||
|
message_count = excluded.message_count,
|
||||||
|
routing_info = excluded.routing_info,
|
||||||
|
deleted_at = excluded.deleted_at
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&meta.id)
|
||||||
|
.bind(&meta.channel)
|
||||||
|
.bind(&meta.chat_id)
|
||||||
|
.bind(&meta.dialog_id)
|
||||||
|
.bind(&meta.title)
|
||||||
|
.bind(meta.created_at)
|
||||||
|
.bind(meta.last_active_at)
|
||||||
|
.bind(meta.message_count)
|
||||||
|
.bind(&meta.routing_info)
|
||||||
|
.bind(meta.deleted_at)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_session(&self, id: &str) -> Result<SessionMeta, StorageError> {
|
||||||
|
let row = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||||
|
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(self.pool())
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
|
||||||
|
|
||||||
|
Ok(SessionMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
channel: row.get("channel"),
|
||||||
|
chat_id: row.get("chat_id"),
|
||||||
|
dialog_id: row.get("dialog_id"),
|
||||||
|
title: row.get("title"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
last_active_at: row.get("last_active_at"),
|
||||||
|
message_count: row.get("message_count"),
|
||||||
|
routing_info: row.get("routing_info"),
|
||||||
|
deleted_at: row.get("deleted_at"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_sessions(
|
||||||
|
&self,
|
||||||
|
channel: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
limit: i64,
|
||||||
|
) -> Result<Vec<SessionMeta>, StorageError> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||||
|
FROM sessions
|
||||||
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||||
|
ORDER BY last_active_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(channel)
|
||||||
|
.bind(chat_id)
|
||||||
|
.bind(limit)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| SessionMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
channel: row.get("channel"),
|
||||||
|
chat_id: row.get("chat_id"),
|
||||||
|
dialog_id: row.get("dialog_id"),
|
||||||
|
title: row.get("title"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
last_active_at: row.get("last_active_at"),
|
||||||
|
message_count: row.get("message_count"),
|
||||||
|
routing_info: row.get("routing_info"),
|
||||||
|
deleted_at: row.get("deleted_at"),
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn touch_session(
|
||||||
|
&self,
|
||||||
|
id: &str,
|
||||||
|
message_count: i64,
|
||||||
|
last_active_at: i64,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
UPDATE sessions SET message_count = ?, last_active_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(message_count)
|
||||||
|
.bind(last_active_at)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
sqlx::query(
|
||||||
|
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
|
||||||
|
)
|
||||||
|
.bind(now)
|
||||||
|
.bind(id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 查找 channel:chat_id 下最近活跃且未过期的 session
|
||||||
|
pub async fn find_active_session(
|
||||||
|
&self,
|
||||||
|
channel: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
ttl_millis: i64,
|
||||||
|
) -> Result<Option<SessionMeta>, StorageError> {
|
||||||
|
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_millis;
|
||||||
|
let row = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at
|
||||||
|
FROM sessions
|
||||||
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND last_active_at > ?
|
||||||
|
ORDER BY last_active_at DESC
|
||||||
|
LIMIT 1
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(channel)
|
||||||
|
.bind(chat_id)
|
||||||
|
.bind(cutoff)
|
||||||
|
.fetch_optional(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
match row {
|
||||||
|
Some(row) => Ok(Some(SessionMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
channel: row.get("channel"),
|
||||||
|
chat_id: row.get("chat_id"),
|
||||||
|
dialog_id: row.get("dialog_id"),
|
||||||
|
title: row.get("title"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
last_active_at: row.get("last_active_at"),
|
||||||
|
message_count: row.get("message_count"),
|
||||||
|
routing_info: row.get("routing_info"),
|
||||||
|
deleted_at: row.get("deleted_at"),
|
||||||
|
})),
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> 注意:`Storage` 的 CRUD 方法需要能访问 `pool()`,但目前 `pool()` 是 `pub(crate)`。在 `mod.rs` 中为 `session.rs` 实现 `Storage` 的 CRUD,所以同模块内可访问。
|
||||||
|
|
||||||
|
**Step 2: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1`
|
||||||
|
Expected: 编译成功
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/session.rs
|
||||||
|
git commit -m "feat(storage): 实现 Session CRUD 操作"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 7: 实现 Message CRUD 操作
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/message.rs`
|
||||||
|
|
||||||
|
**Step 1: 编写 Message CRUD**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use sqlx::Row;
|
||||||
|
use super::MessageMeta;
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
pub async fn append_message(&self, session_id: &str, msg: &MessageMeta) -> Result<i64, StorageError> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&msg.id)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(msg.seq)
|
||||||
|
.bind(&msg.role)
|
||||||
|
.bind(&msg.content)
|
||||||
|
.bind(&msg.media_refs)
|
||||||
|
.bind(&msg.tool_call_id)
|
||||||
|
.bind(&msg.tool_name)
|
||||||
|
.bind(&msg.tool_calls)
|
||||||
|
.bind(msg.created_at)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(msg.seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn append_messages(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
msgs: &[MessageMeta],
|
||||||
|
) -> Result<Vec<i64>, StorageError> {
|
||||||
|
let mut seqs = Vec::with_capacity(msgs.len());
|
||||||
|
for msg in msgs {
|
||||||
|
let seq = self.append_message(session_id, msg).await?;
|
||||||
|
seqs.push(seq);
|
||||||
|
}
|
||||||
|
Ok(seqs)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn load_messages(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
from_seq: i64,
|
||||||
|
) -> Result<Vec<MessageMeta>, StorageError> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ? AND seq >= ?
|
||||||
|
ORDER BY seq ASC
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(from_seq)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| MessageMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
session_id: row.get("session_id"),
|
||||||
|
seq: row.get("seq"),
|
||||||
|
role: row.get("role"),
|
||||||
|
content: row.get("content"),
|
||||||
|
media_refs: row.get("media_refs"),
|
||||||
|
tool_call_id: row.get("tool_call_id"),
|
||||||
|
tool_name: row.get("tool_name"),
|
||||||
|
tool_calls: row.get("tool_calls"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
|
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
||||||
|
.bind(session_id)
|
||||||
|
.execute(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> 注意:同样在 `mod.rs` 中实现,这样 `Storage` 的方法对 `message.rs` 中的 impl 可见。
|
||||||
|
|
||||||
|
**Step 2: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1`
|
||||||
|
Expected: 编译成功
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/message.rs
|
||||||
|
git commit -m "feat(storage): 实现 Message CRUD 操作"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 8: 实现写入重试逻辑
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/mod.rs`
|
||||||
|
|
||||||
|
**Step 1: 在 Storage 中添加带重试的 append_message**
|
||||||
|
|
||||||
|
在 `mod.rs` 的 `Storage` impl 块中添加:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
/// 追加消息,带重试逻辑
|
||||||
|
/// 重试 3 次(100/200/300ms 退避),仍失败返回错误
|
||||||
|
pub async fn append_message_with_retry(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
msg: &MessageMeta,
|
||||||
|
) -> Result<i64, StorageError> {
|
||||||
|
let delays = [100, 200, 300];
|
||||||
|
|
||||||
|
for (i, delay) in delays.iter().enumerate() {
|
||||||
|
match self.append_message(session_id, msg).await {
|
||||||
|
Ok(seq) => return Ok(seq),
|
||||||
|
Err(e) if i < delays.len() - 1 => {
|
||||||
|
sleep(Duration::from_millis(*delay)).await;
|
||||||
|
tracing::warn!("Storage write failed, retrying: {}", e);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Storage write failed after retries: {}", e);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> 注意:需要 `use sqlx::Row;` 在 `mod.rs` 中。
|
||||||
|
|
||||||
|
**Step 2: 验证编译**
|
||||||
|
|
||||||
|
Run: `cargo build --lib 2>&1`
|
||||||
|
Expected: 编译成功
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/mod.rs
|
||||||
|
git commit -m "feat(storage): 添加写入重试逻辑"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 9: 编写 Storage 单元测试
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/mod.rs`(添加测试模块)
|
||||||
|
|
||||||
|
**Step 1: 编写 Storage 集成测试**
|
||||||
|
|
||||||
|
在 `src/storage/mod.rs` 末尾添加:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
async fn create_test_storage() -> (Storage, impl Fn()) {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let db_path = dir.path().join("test.db");
|
||||||
|
let storage = Storage::new(&db_path).await.unwrap();
|
||||||
|
|
||||||
|
let cleanup = || {
|
||||||
|
drop(dir);
|
||||||
|
};
|
||||||
|
|
||||||
|
(storage, cleanup)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_upsert_and_get_session() {
|
||||||
|
let (storage, cleanup) = create_test_storage().await;
|
||||||
|
defer { cleanup(); }
|
||||||
|
|
||||||
|
let meta = SessionMeta {
|
||||||
|
id: "cli_chat:sid123:dialog1".to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid123".to_string(),
|
||||||
|
dialog_id: "dialog1".to_string(),
|
||||||
|
title: "测试会话".to_string(),
|
||||||
|
created_at: 1000,
|
||||||
|
last_active_at: 1000,
|
||||||
|
message_count: 0,
|
||||||
|
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
let loaded = storage.get_session(&meta.id).await.unwrap();
|
||||||
|
assert_eq!(loaded.title, "测试会话");
|
||||||
|
assert_eq!(loaded.channel, "cli_chat");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_nonexistent_session() {
|
||||||
|
let (storage, cleanup) = create_test_storage().await;
|
||||||
|
defer { cleanup(); }
|
||||||
|
|
||||||
|
let result = storage.get_session("nonexistent").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
matches!(result.unwrap_err(), StorageError::NotFound(_));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_sessions() {
|
||||||
|
let (storage, cleanup) = create_test_storage().await;
|
||||||
|
defer { cleanup(); }
|
||||||
|
|
||||||
|
for i in 0..5 {
|
||||||
|
let meta = SessionMeta {
|
||||||
|
id: format!("cli_chat:sid123:dialog{}", i),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid123".to_string(),
|
||||||
|
dialog_id: format!("dialog{}", i),
|
||||||
|
title: format!("会话{}", i),
|
||||||
|
created_at: i as i64 * 1000,
|
||||||
|
last_active_at: i as i64 * 1000,
|
||||||
|
message_count: i,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
|
||||||
|
assert_eq!(sessions.len(), 5);
|
||||||
|
// 按 last_active_at DESC 排序
|
||||||
|
assert_eq!(sessions[0].dialog_id, "dialog4");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_soft_delete() {
|
||||||
|
let (storage, cleanup) = create_test_storage().await;
|
||||||
|
defer { cleanup(); }
|
||||||
|
|
||||||
|
let meta = SessionMeta {
|
||||||
|
id: "cli_chat:sid123:dialog1".to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid123".to_string(),
|
||||||
|
dialog_id: "dialog1".to_string(),
|
||||||
|
title: "测试".to_string(),
|
||||||
|
created_at: 1000,
|
||||||
|
last_active_at: 1000,
|
||||||
|
message_count: 0,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
storage.soft_delete_session(&meta.id).await.unwrap();
|
||||||
|
|
||||||
|
let result = storage.get_session(&meta.id).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
matches!(result.unwrap_err(), StorageError::NotFound(_));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_append_and_load_messages() {
|
||||||
|
let (storage, cleanup) = create_test_storage().await;
|
||||||
|
defer { cleanup(); }
|
||||||
|
|
||||||
|
let session_meta = SessionMeta {
|
||||||
|
id: "cli_chat:sid123:dialog1".to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid123".to_string(),
|
||||||
|
dialog_id: "dialog1".to_string(),
|
||||||
|
title: "测试".to_string(),
|
||||||
|
created_at: 1000,
|
||||||
|
last_active_at: 1000,
|
||||||
|
message_count: 0,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&session_meta).await.unwrap();
|
||||||
|
|
||||||
|
let msg = MessageMeta {
|
||||||
|
id: "msg1".to_string(),
|
||||||
|
session_id: session_meta.id.clone(),
|
||||||
|
seq: 1,
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "你好".to_string(),
|
||||||
|
media_refs: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
created_at: 1000,
|
||||||
|
};
|
||||||
|
|
||||||
|
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
|
||||||
|
assert_eq!(seq, 1);
|
||||||
|
|
||||||
|
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
||||||
|
assert_eq!(loaded.len(), 1);
|
||||||
|
assert_eq!(loaded[0].content, "你好");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_touch_session() {
|
||||||
|
let (storage, cleanup) = create_test_storage().await;
|
||||||
|
defer { cleanup(); }
|
||||||
|
|
||||||
|
let meta = SessionMeta {
|
||||||
|
id: "cli_chat:sid123:dialog1".to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid123".to_string(),
|
||||||
|
dialog_id: "dialog1".to_string(),
|
||||||
|
title: "测试".to_string(),
|
||||||
|
created_at: 1000,
|
||||||
|
last_active_at: 1000,
|
||||||
|
message_count: 0,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
storage.touch_session(&meta.id, 5, 2000).await.unwrap();
|
||||||
|
|
||||||
|
let loaded = storage.get_session(&meta.id).await.unwrap();
|
||||||
|
assert_eq!(loaded.message_count, 5);
|
||||||
|
assert_eq!(loaded.last_active_at, 2000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> 需要在 `Cargo.toml` 中添加 `tempfile` 依赖(已存在)。`defer` 宏可自己实现一个简单的:`fn defer<F: FnOnce()>(f: F) { f() }`
|
||||||
|
|
||||||
|
**Step 2: 运行测试**
|
||||||
|
|
||||||
|
Run: `cargo test storage::tests --lib 2>&1`
|
||||||
|
Expected: 所有 7 个测试 PASS
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/mod.rs
|
||||||
|
git commit -m "test(storage): 编写 Storage 单元测试"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 汇总
|
||||||
|
|
||||||
|
| Task | 改动文件 | 关键交付物 |
|
||||||
|
|------|----------|-----------|
|
||||||
|
| 1 | `Cargo.toml` | sqlx 依赖 |
|
||||||
|
| 2 | `src/storage/error.rs` | StorageError |
|
||||||
|
| 3 | `src/storage/{mod.rs,session.rs,message.rs}`, `src/lib.rs` | 模块骨架 |
|
||||||
|
| 4 | `src/storage/mod.rs` | Storage 结构 + init_schema |
|
||||||
|
| 5 | `src/storage/session.rs`, `message.rs` | SessionMeta, MessageMeta |
|
||||||
|
| 6 | `src/storage/session.rs` | Session CRUD |
|
||||||
|
| 7 | `src/storage/message.rs` | Message CRUD |
|
||||||
|
| 8 | `src/storage/mod.rs` | append_message_with_retry |
|
||||||
|
| 9 | `src/storage/mod.rs` | 7 个单元测试 |
|
||||||
|
|
||||||
|
**Phase 1 完成后:** Storage 模块可独立使用,具备完整的持久化能力,可安全地集成到 Session 和 SessionManager 中。
|
||||||
278
docs/plans/2026-04-28-session-persistence-design.md
Normal file
278
docs/plans/2026-04-28-session-persistence-design.md
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
# Session 持久化设计方案
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
为 PicoBot 添加 SQLite 持久化层,实现 Session 数据的持久化、完整 Dialog 生命周期管理、消息实时落盘、以及基于 TTL 的自动内存清理。
|
||||||
|
|
||||||
|
## 核心概念
|
||||||
|
|
||||||
|
```
|
||||||
|
UnifiedSessionId = {channel}:{chat_id}:{dialog_id}
|
||||||
|
Session = Dialog(两者等价,不再分层)
|
||||||
|
```
|
||||||
|
|
||||||
|
每个 Session 独立管理自己的消息历史、LLM 配置和路由信息。
|
||||||
|
|
||||||
|
## 数据库 Schema
|
||||||
|
|
||||||
|
### sessions 表
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
channel TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
dialog_id TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL DEFAULT '新对话',
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
last_active_at INTEGER NOT NULL,
|
||||||
|
message_count INTEGER DEFAULT 0,
|
||||||
|
routing_info TEXT,
|
||||||
|
deleted_at INTEGER,
|
||||||
|
UNIQUE(channel, chat_id, dialog_id)
|
||||||
|
);
|
||||||
|
CREATE INDEX idx_sessions_chat ON sessions(channel, chat_id, deleted_at);
|
||||||
|
```
|
||||||
|
|
||||||
|
### messages 表
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE messages (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
seq INTEGER NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
media_refs TEXT,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tool_calls TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE INDEX idx_messages_session_seq ON messages(session_id, seq);
|
||||||
|
```
|
||||||
|
|
||||||
|
## Storage API
|
||||||
|
|
||||||
|
### Session 操作
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `new(db_path) -> Storage` | 打开/创建数据库 |
|
||||||
|
| `upsert_session(meta) -> Result<(), StorageError>` | 插入或更新 session 元数据 |
|
||||||
|
| `get_session(id) -> Result<SessionMeta, StorageError>` | 获取单个 session |
|
||||||
|
| `list_sessions(channel, chat_id, limit) -> Result<Vec<SessionMeta>>` | 最近 N 条 |
|
||||||
|
| `touch_session(id, message_count, last_active_at)` | 更新计数和最后活跃时间 |
|
||||||
|
| `soft_delete_session(id) -> Result<(), StorageError>` | 软删除 |
|
||||||
|
|
||||||
|
### Message 操作
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `append_message(session_id, msg) -> Result<i64, StorageError>` | 追加单条消息,返回 seq |
|
||||||
|
| `append_messages(session_id, msgs) -> Result<Vec<i64>, StorageError>` | 批量追加 |
|
||||||
|
| `load_messages(session_id, from_seq) -> Result<Vec<MessageMeta>>` | 从指定 seq 加载 |
|
||||||
|
| `clear_messages(session_id) -> Result<(), StorageError>` | 清除消息(保留 session) |
|
||||||
|
|
||||||
|
### 写入失败处理
|
||||||
|
|
||||||
|
重试 3 次(100/200/300ms 退避),仍失败则发送系统通知告警。
|
||||||
|
|
||||||
|
## Session 结构
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct Session {
|
||||||
|
pub id: UnifiedSessionId,
|
||||||
|
pub title: String,
|
||||||
|
pub created_at: i64,
|
||||||
|
pub last_active_at: i64,
|
||||||
|
pub message_count: i64, // 用户消息计数
|
||||||
|
pub total_message_count: i64, // 含系统消息
|
||||||
|
|
||||||
|
messages: Vec<ChatMessage>, // 内存消息历史
|
||||||
|
seq_counter: i64, // 下一个消息的 seq
|
||||||
|
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
compressor: ContextCompressor,
|
||||||
|
user_tx: mpsc::Sender<WsOutbound>,
|
||||||
|
storage: Arc<Storage>, // 持久化 sink
|
||||||
|
routing_info: String, // JSON 路由信息
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 初始化流程
|
||||||
|
|
||||||
|
```
|
||||||
|
new() 或 from_storage()
|
||||||
|
↓
|
||||||
|
注入 storage 引用
|
||||||
|
↓
|
||||||
|
创建 provider, tools, compressor
|
||||||
|
↓
|
||||||
|
从 Storage 加载 messages(from_seq = 0)
|
||||||
|
↓
|
||||||
|
设置 seq_counter = messages.len() + 1
|
||||||
|
↓
|
||||||
|
返回 Session 实例
|
||||||
|
```
|
||||||
|
|
||||||
|
## handle_message 流程
|
||||||
|
|
||||||
|
```
|
||||||
|
handle_message(channel, chat_id, sender_id, content, media)
|
||||||
|
│
|
||||||
|
├── 1. 确定 dialog_id
|
||||||
|
│ │
|
||||||
|
│ ├── 显式传入 dialog_id → 使用
|
||||||
|
│ └── 无 dialog_id
|
||||||
|
│ ├── 查找 channel:chat_id 下最近活跃且未过期的 session
|
||||||
|
│ ├── 找到 → 使用该 session
|
||||||
|
│ └── 未找到 → 创建新 session(dialog_id = 新随机 ID)
|
||||||
|
│
|
||||||
|
├── 2. 获取或创建 Session
|
||||||
|
│ 有 → 更新 session_timestamps
|
||||||
|
│ 无 → 从 Storage 恢复 或 创建新 Session
|
||||||
|
│
|
||||||
|
├── 3. 追加用户消息并持久化
|
||||||
|
│ seq = seq_counter; seq_counter += 1
|
||||||
|
│ Storage.append_message()(失败重试 → 告警)
|
||||||
|
│ messages.push(user_msg)
|
||||||
|
│ message_count += 1
|
||||||
|
│
|
||||||
|
├── 4. 检查 title 自动生成
|
||||||
|
│ message_count == 10 且 title == 默认值 → LLM 生成 → 更新 title → 写回 Storage
|
||||||
|
│
|
||||||
|
├── 5. 注入 skills_prompt
|
||||||
|
│
|
||||||
|
├── 6. 新 session 注入欢迎消息(系统消息,不计入 message_count)
|
||||||
|
│
|
||||||
|
├── 7. 上下文压缩(如需要)
|
||||||
|
│
|
||||||
|
├── 8. 调用 AgentLoop
|
||||||
|
│
|
||||||
|
├── 9. 持久化 Agent 响应
|
||||||
|
│
|
||||||
|
└── 10. 返回响应
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dialog 生命周期命令
|
||||||
|
|
||||||
|
| 命令 | 行为 |
|
||||||
|
|------|------|
|
||||||
|
| `/new [标题]` | 创建新 dialog(新随机 dialog_id),新建 Session |
|
||||||
|
| `/sessions` | 列出 channel:chat_id 下最近 10 条 session(按 last_active_at 倒序) |
|
||||||
|
| `/switch <dialog_id>` | 切换到指定 session(从 Storage 恢复或内存命中) |
|
||||||
|
| `/rename <新标题>` | 重命名当前 session |
|
||||||
|
| `/delete` | 软删除当前 session(内存移除 + Storage 标记 deleted_at) |
|
||||||
|
| `/info` | 显示当前 session 信息 |
|
||||||
|
| `/compact` | 手动触发上下文压缩 |
|
||||||
|
|
||||||
|
## 路由信息
|
||||||
|
|
||||||
|
每种 Channel 在创建 Session 时注入路由信息:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// CLI
|
||||||
|
routing_info = json!({"type": "cli", "ws_sender_id": "xxx"})
|
||||||
|
|
||||||
|
// Feishu
|
||||||
|
routing_info = json!({"type": "feishu", "open_conversation_id": "oc_xxx", "tenant_key": "xxx"})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Title 自动生成
|
||||||
|
|
||||||
|
调用时机:
|
||||||
|
1. Session 首次创建时(初始 title = "新对话")
|
||||||
|
2. `message_count` 达到 10 且 title 仍为默认值时,自动更新
|
||||||
|
|
||||||
|
生成 Prompt:
|
||||||
|
```
|
||||||
|
给定以下对话历史,生成一个简短的会话标题(5-15 个中文字符),
|
||||||
|
概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
|
||||||
|
|
||||||
|
历史:
|
||||||
|
{messages}
|
||||||
|
```
|
||||||
|
|
||||||
|
## TTL 清理
|
||||||
|
|
||||||
|
- 内存 session 超时 → 释放内存,Storage 记录保留
|
||||||
|
- 用户切换回该 session → 从 Storage 重新加载到内存
|
||||||
|
- Storage 中的 session 记录通过 `deleted_at` 软删除,不会物理删除
|
||||||
|
|
||||||
|
## 文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── storage/
|
||||||
|
│ ├── mod.rs # Storage 主模块
|
||||||
|
│ ├── session.rs # Session CRUD
|
||||||
|
│ ├── message.rs # Message CRUD
|
||||||
|
│ └── error.rs # StorageError
|
||||||
|
│
|
||||||
|
└── session/
|
||||||
|
├── mod.rs # 导出 Session, SessionManager
|
||||||
|
├── session.rs # Session, SessionManager 实现
|
||||||
|
├── session_id.rs # UnifiedSessionId
|
||||||
|
├── commands.rs # SessionCommand
|
||||||
|
├── events.rs # SessionEvent, DialogInfo
|
||||||
|
└── error.rs # SessionError
|
||||||
|
```
|
||||||
|
|
||||||
|
## 实现顺序
|
||||||
|
|
||||||
|
### Phase 1: Storage 基础
|
||||||
|
1. 添加 `sqlx` + `sqlite` 依赖
|
||||||
|
2. 实现 `Storage` 结构(连接池、初始化)
|
||||||
|
3. Session CRUD + Message CRUD
|
||||||
|
4. 写入重试逻辑
|
||||||
|
5. 单元测试
|
||||||
|
|
||||||
|
### Phase 2: Session 扩展
|
||||||
|
1. 扩展 `Session` 结构(添加 storage、routing_info、计数字段、seq_counter)
|
||||||
|
2. `from_storage()` 恢复逻辑
|
||||||
|
3. `add_message` 持久化集成
|
||||||
|
4. `send_system_notification` 接口
|
||||||
|
5. Title 自动生成
|
||||||
|
|
||||||
|
### Phase 3: SessionManager 完善
|
||||||
|
1. 注入 `Arc<Storage>`
|
||||||
|
2. 实现 `list_dialogs()`
|
||||||
|
3. 实现 `switch_dialog()`
|
||||||
|
4. 实现 `delete_dialog()` / `rename_dialog()`
|
||||||
|
5. 后台 TTL 清理任务
|
||||||
|
6. 集成测试
|
||||||
|
|
||||||
|
### Phase 4: 斜杠命令
|
||||||
|
1. 实现 `/sessions`
|
||||||
|
2. 实现 `/switch`
|
||||||
|
3. 实现 `/rename`
|
||||||
|
4. 实现 `/delete`
|
||||||
|
5. 端到端测试
|
||||||
|
|
||||||
|
## 配置项
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"session": {
|
||||||
|
"ttl_hours": 24,
|
||||||
|
"cleanup_interval_minutes": 60,
|
||||||
|
"auto_title_after_n_messages": 10,
|
||||||
|
"storage_retry_delays_ms": [100, 200, 300]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 与现有代码的冲突点
|
||||||
|
|
||||||
|
| 冲突 | 处理方式 |
|
||||||
|
|------|----------|
|
||||||
|
| `DialogInfo` 有 `archived_at` | 删除该字段,改用 `deleted_at` |
|
||||||
|
| `SessionCommand::ArchiveDialog` | 删除 |
|
||||||
|
| `/new` 现有行为 | 改为创建新 session(新 dialog_id) |
|
||||||
|
| 现有 `Session` 无 storage/routing_info | 扩展结构,新增 `from_storage()` |
|
||||||
|
| `SessionManager` 需注入 `Arc<Storage>` | 扩展构造方法 |
|
||||||
|
| stub 方法 | 实现 |
|
||||||
226
docs/plans/2026-05-07-memory-system-design.md
Normal file
226
docs/plans/2026-05-07-memory-system-design.md
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
# PicoBot Memory System Design
|
||||||
|
|
||||||
|
Date: 2026-05-07
|
||||||
|
|
||||||
|
## 1. Overview
|
||||||
|
|
||||||
|
Introduce a memory system that allows PicoBot agents to remember user preferences, project context, facts, and conversation history across sessions. The memory system is **unified with the existing context compression pipeline**: compression automatically produces `timeline` memory entries and advances a `last_consolidated_at` pointer to avoid redundant reprocessing.
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
|
||||||
|
- **Compression is memory** (inspired by nanobot): when old messages are compressed, the summary is persisted — not discarded
|
||||||
|
- **FTS5 only** (no vector embeddings): keyword search via SQLite FTS5, sufficient for current scale
|
||||||
|
- **Extend existing infrastructure**: reuse `Storage` connection pool, `ContextCompressor`, `SystemPromptBuilder`
|
||||||
|
- **YAGNI**: no knowledge graph, no response cache, no namespace isolation, no audit trail
|
||||||
|
|
||||||
|
## 2. Core Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
ContextCompressor (existing) MemoryManager (new)
|
||||||
|
│ │
|
||||||
|
│ compress_if_needed() │ store / recall / forget
|
||||||
|
│ ├─ LLM summary → inject │
|
||||||
|
│ └─ store(timeline entry) ──────┘
|
||||||
|
│ └─ advance last_consolidated_at
|
||||||
|
│
|
||||||
|
SystemPromptBuilder ── recall(knowledge, limit=5) ──→ inject into system prompt
|
||||||
|
AgentLoop ── after_turn ──→ memory_store / memory_recall / memory_forget tools
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. Memory Categories
|
||||||
|
|
||||||
|
| Category | Purpose | Written By | Retrieved By |
|
||||||
|
|----------|---------|-----------|--------------|
|
||||||
|
| `knowledge` | Long-term facts, preferences, patterns, insights | Agent via `memory_store` tool | FTS5 → injected into system prompt every turn |
|
||||||
|
| `timeline` | Compressed conversation summaries | ContextCompressor automatically | FTS5 + time-range queries |
|
||||||
|
|
||||||
|
## 4. Storage Schema
|
||||||
|
|
||||||
|
### New table: `memories`
|
||||||
|
|
||||||
|
Added to the existing `Storage` initialization in `src/storage/mod.rs`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE IF NOT EXISTS memories (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
key TEXT NOT NULL UNIQUE,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
category TEXT NOT NULL DEFAULT 'knowledge',
|
||||||
|
importance REAL NOT NULL DEFAULT 0.5,
|
||||||
|
session_id TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
|
||||||
|
key,
|
||||||
|
content,
|
||||||
|
content=memories,
|
||||||
|
content_rowid=rowid
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Modified table: `sessions`
|
||||||
|
|
||||||
|
```sql
|
||||||
|
ALTER TABLE sessions ADD COLUMN last_consolidated_at INTEGER;
|
||||||
|
```
|
||||||
|
|
||||||
|
## 5. Unified Compression-Memory Pipeline
|
||||||
|
|
||||||
|
### Trigger Conditions
|
||||||
|
|
||||||
|
Compression/consolidation fires when **any** of these conditions is met:
|
||||||
|
|
||||||
|
| Condition | Value | Rationale |
|
||||||
|
|-----------|-------|-----------|
|
||||||
|
| Token budget exceeds 50% threshold | `context_window / 2` | Primary trigger — context is getting full |
|
||||||
|
| Accumulated N turns without consolidation | 3 (configurable) | Catch-up for short messages that don't hit token threshold |
|
||||||
|
| Session idle | 10 minutes (configurable) | Important for async channels like Feishu |
|
||||||
|
|
||||||
|
### Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
compress_if_needed(history, session_id):
|
||||||
|
1. Read last_consolidated_at from session
|
||||||
|
→ Only compress messages after that timestamp
|
||||||
|
2. If no messages to compress → return history unchanged
|
||||||
|
3. FTS5 recall(user_input, limit=recall_limit, category=knowledge)
|
||||||
|
→ Inject relevant facts into system prompt
|
||||||
|
4. LLM summarization of old messages → [Context Summary]
|
||||||
|
→ Inject into current conversation
|
||||||
|
5. Store summary as timeline entry:
|
||||||
|
key: "ctx_{session_id}_{uuid}"
|
||||||
|
content: "[YYYY-MM-DD HH:MM] summary text..."
|
||||||
|
category: timeline
|
||||||
|
6. UPDATE sessions.last_consolidated_at = now()
|
||||||
|
7. Return compressed history
|
||||||
|
```
|
||||||
|
|
||||||
|
### timeline Entry Format
|
||||||
|
|
||||||
|
Each timeline entry follows nanobot's convention:
|
||||||
|
```
|
||||||
|
[2026-05-07 14:30] User asked about Rust async patterns. Discussed tokio::select!,
|
||||||
|
semaphore-based rate limiting, and backpressure strategies. No code was written.
|
||||||
|
```
|
||||||
|
|
||||||
|
This format is grep-friendly and human-readable.
|
||||||
|
|
||||||
|
## 6. Retrieval Strategy
|
||||||
|
|
||||||
|
### Automatic Retrieval (every turn)
|
||||||
|
|
||||||
|
`SystemPromptBuilder.build_system_prompt()` calls:
|
||||||
|
```rust
|
||||||
|
memory.recall(query=user_message, limit=recall_limit, category=knowledge)
|
||||||
|
```
|
||||||
|
|
||||||
|
Results sorted by FTS5 BM25 score, injected as:
|
||||||
|
```
|
||||||
|
## Memory Context
|
||||||
|
|
||||||
|
- user_prefers_rust: User prefers Rust for all backend projects
|
||||||
|
- project_picobot_stack: PicoBot uses Rust, axum, sqlx, ratatui, tokio
|
||||||
|
- user_workflow: User prefers TDD workflow with cargo test --lib
|
||||||
|
```
|
||||||
|
|
||||||
|
### Agent-Initiated Retrieval
|
||||||
|
|
||||||
|
Agent uses `memory_recall` tool with optional `category`, `since`, `until` parameters.
|
||||||
|
|
||||||
|
### Fallback
|
||||||
|
|
||||||
|
If FTS5 returns empty results, fallback to `LIKE '%keyword%'` on `key` and `content` columns.
|
||||||
|
|
||||||
|
## 7. Agent Tools
|
||||||
|
|
||||||
|
| Tool | Parameters | Description |
|
||||||
|
|------|-----------|-------------|
|
||||||
|
| `memory_store` | `key: str`, `content: str`, `category: str`, `importance?: f64` | Write or update a memory entry. Key is semantic identifier (e.g., "user_language_pref") |
|
||||||
|
| `memory_recall` | `query: str`, `category?: str`, `since?: i64`, `until?: i64`, `limit?: usize` | Search memories by keyword and optional filters |
|
||||||
|
| `memory_forget` | `key: str` | Delete a memory entry by key |
|
||||||
|
|
||||||
|
## 8. Error Handling & Degradation
|
||||||
|
|
||||||
|
| Scenario | Strategy |
|
||||||
|
|----------|----------|
|
||||||
|
| Consolidation LLM call fails | Log warning, increment failure counter, do NOT block main flow |
|
||||||
|
| Consecutive failures >= 3 | Degrade: append raw message dump to timeline with `[RAW]` prefix, reset counter |
|
||||||
|
| FTS5 recall returns empty | Fallback to `LIKE '%keyword%'` query |
|
||||||
|
| `memory.enabled = false` | ContextCompressor works normally, no memory writes |
|
||||||
|
| MemoryManager uninitialized | ContextCompressor works with feature-gated memory write path |
|
||||||
|
|
||||||
|
## 9. Configuration
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"memory": {
|
||||||
|
"enabled": true,
|
||||||
|
"consolidation_provider": "openai",
|
||||||
|
"consolidation_model": "gpt-4o-mini",
|
||||||
|
"recall_limit": 5,
|
||||||
|
"consolidation_turn_threshold": 3,
|
||||||
|
"idle_consolidation_minutes": 10,
|
||||||
|
"timeline_retention_days": 90,
|
||||||
|
"max_failures_before_degrade": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Key | Type | Default | Description |
|
||||||
|
|-----|------|---------|-------------|
|
||||||
|
| `enabled` | bool | `false` | Master switch for memory system |
|
||||||
|
| `consolidation_provider` | string | — | Provider name for consolidation LLM calls |
|
||||||
|
| `consolidation_model` | string | — | Model name for consolidation |
|
||||||
|
| `recall_limit` | usize | `5` | Max knowledge entries injected into system prompt |
|
||||||
|
| `consolidation_turn_threshold` | usize | `3` | Turns before forced consolidation |
|
||||||
|
| `idle_consolidation_minutes` | u64 | `10` | Idle time before consolidation trigger |
|
||||||
|
| `timeline_retention_days` | u64 | `90` | Auto-cleanup age for timeline entries |
|
||||||
|
| `max_failures_before_degrade` | usize | `3` | Consecutive failures before raw archive fallback |
|
||||||
|
|
||||||
|
## 10. New Module Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── memory/
|
||||||
|
│ ├── mod.rs # MemoryManager, MemoryConfig
|
||||||
|
│ ├── types.rs # MemoryEntry, MemoryCategory, ConsolidationResult
|
||||||
|
│ └── consolidation.rs # Consolidation prompt + LLM call logic
|
||||||
|
├── storage/
|
||||||
|
│ └── memory.rs # SQLite CRUD for memories table + FTS5
|
||||||
|
├── tools/
|
||||||
|
│ ├── memory_store.rs # memory_store tool
|
||||||
|
│ ├── memory_recall.rs # memory_recall tool
|
||||||
|
│ └── memory_forget.rs # memory_forget tool
|
||||||
|
```
|
||||||
|
|
||||||
|
## 11. Integration Points (Existing Files Modified)
|
||||||
|
|
||||||
|
| File | Change |
|
||||||
|
|------|--------|
|
||||||
|
| `src/lib.rs` | Add `pub mod memory;` |
|
||||||
|
| `src/config/mod.rs` | Add `MemoryConfig` struct and deserialization |
|
||||||
|
| `src/storage/mod.rs` | Add `pub mod memory;`, init `memories` table and FTS5 in `init_schema()` |
|
||||||
|
| `src/storage/session.rs` | Add `last_consolidated_at` column read/write |
|
||||||
|
| `src/session/session.rs` | Add `last_consolidated_at: Option<i64>` field to Session |
|
||||||
|
| `src/agent/context_compressor.rs` | Add `memory: Option<Arc<MemoryManager>>` field, write timeline on compress |
|
||||||
|
| `src/agent/system_prompt.rs` | Add `memory_context` section via `MemoryManager::recall()` |
|
||||||
|
| `src/agent/agent_loop.rs` | No changes (tools registered via ToolRegistry) |
|
||||||
|
| `src/tools/mod.rs` | Register `memory_store`, `memory_recall`, `memory_forget` in `create_default_tools()` |
|
||||||
|
| `src/gateway/mod.rs` | Initialize `MemoryManager` in `GatewayState::new()`, pass to ContextCompressor |
|
||||||
|
|
||||||
|
## 12. Implementation Order
|
||||||
|
|
||||||
|
| # | Task | Dependencies |
|
||||||
|
|---|------|-------------|
|
||||||
|
| 1 | Types: `MemoryEntry`, `MemoryCategory`, `ConsolidationResult` | — |
|
||||||
|
| 2 | Config: `MemoryConfig` + deserialization | — |
|
||||||
|
| 3 | Storage: `memories` table + FTS5 + CRUD + search | #1 |
|
||||||
|
| 4 | `MemoryManager` API | #1, #2, #3 |
|
||||||
|
| 5 | Session: `last_consolidated_at` field | — |
|
||||||
|
| 6 | `ContextCompressor` memory integration | #4, #5 |
|
||||||
|
| 7 | `SystemPromptBuilder` memory context injection | #4 |
|
||||||
|
| 8 | Agent tools: `memory_store`, `memory_recall`, `memory_forget` | #4 |
|
||||||
|
| 9 | `GatewayState` initialization wiring | #4, #5, #6 |
|
||||||
|
| 10 | Unit tests | #1-#9 |
|
||||||
1392
docs/plans/2026-05-07-memory-system-impl.md
Normal file
1392
docs/plans/2026-05-07-memory-system-impl.md
Normal file
File diff suppressed because it is too large
Load Diff
90
docs/plans/2026-05-10-incremental-session-recovery-design.md
Normal file
90
docs/plans/2026-05-10-incremental-session-recovery-design.md
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# 启动增量恢复设计
|
||||||
|
|
||||||
|
## 问题
|
||||||
|
|
||||||
|
PicoBot 重启后,`Session::from_storage()` 全量加载 `messages` 表,恢复的 history 可能直接超出上下文窗口,首次 LLM 调用即触发压缩,浪费 token。
|
||||||
|
|
||||||
|
## 设计
|
||||||
|
|
||||||
|
### 核心思路
|
||||||
|
|
||||||
|
用 `last_compressed_message_at` 标记最后压缩时刻。恢复时:
|
||||||
|
- 加载该标记之后的原始消息
|
||||||
|
- 用该 session 的 Timeline 条目替代已压缩部分
|
||||||
|
- `seq_counter` 统一从 SQLite 查 `MAX(seq) + 1`
|
||||||
|
|
||||||
|
```
|
||||||
|
messages 表 memories(timeline)
|
||||||
|
┌──────────────────────────┐ ┌───────────────────────────┐
|
||||||
|
│ created_at = T1..T5 │ ← 跳过 │ session = feishu:oc:dialog │
|
||||||
|
│ (压缩已覆盖,用Timeline替代)│ │ created_at 降序 │
|
||||||
|
├──────────────────────────┤ ├───────────────────────────┤
|
||||||
|
│ created_at > T6 │ ← 加载 │ 只取最近 3 条 │
|
||||||
|
└──────────────────────────┘ └───────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 数据变更
|
||||||
|
|
||||||
|
**`sessions` 表加列:**
|
||||||
|
```sql
|
||||||
|
last_compressed_message_at INTEGER
|
||||||
|
```
|
||||||
|
|
||||||
|
**`SessionMeta` / `Session` 加字段:** `last_compressed_message_at: Option<i64>`
|
||||||
|
|
||||||
|
### Storage 层新增方法
|
||||||
|
|
||||||
|
| 方法 | SQL |
|
||||||
|
|------|-----|
|
||||||
|
| `get_max_message_seq(session_id)` | `SELECT MAX(seq) FROM messages WHERE session_id = ?` |
|
||||||
|
| `load_messages_after_timestamp(session_id, after_ts)` | `WHERE created_at > ?` |
|
||||||
|
| `load_session_timelines(session_id, limit)` | `WHERE session_id = ? AND category = 'timeline' ORDER BY created_at DESC LIMIT ?` |
|
||||||
|
|
||||||
|
### 压缩跟踪
|
||||||
|
|
||||||
|
`compress_if_needed()` 返回值改为 `CompressionResult { history, created_timelines: bool }`。
|
||||||
|
`compress_once()` 中 LLM 摘要路径才置 `true`(Tier 2),Tier 1/3 不产生 Timeline。
|
||||||
|
|
||||||
|
**记录时机**(`handle_message` 正常流、溢出重试流、`/compact` 统一):
|
||||||
|
```rust
|
||||||
|
if result.created_timelines {
|
||||||
|
session.last_compressed_message_at = Some(now());
|
||||||
|
session.persist_session_meta().await;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Session::from_storage() 恢复逻辑
|
||||||
|
|
||||||
|
有压缩标记时:
|
||||||
|
1. `load_session_timelines(limit=4)` → 取 3 条给 LLM,第 4 条判"有更多"
|
||||||
|
2. 有更多 → 插入提示 user 消息
|
||||||
|
3. 逐条插入 Timeline 为 `[Previous Context]` user 消息
|
||||||
|
4. `load_messages_after_timestamp(after_ts)` → 原始尾消息
|
||||||
|
5. `repair_tool_call_chains`
|
||||||
|
|
||||||
|
无压缩标记 → 全量加载(现有行为)。
|
||||||
|
|
||||||
|
统一:`seq_counter = MAX(seq) + 1`
|
||||||
|
|
||||||
|
### 系统提示词
|
||||||
|
|
||||||
|
`Session.last_compressed_message_at` 非空时追加:
|
||||||
|
```
|
||||||
|
## 历史会话
|
||||||
|
之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。
|
||||||
|
```
|
||||||
|
|
||||||
|
## 改动清单
|
||||||
|
|
||||||
|
| # | 文件 | 改动 |
|
||||||
|
|---|------|------|
|
||||||
|
| 1 | `storage/session.rs` | `SessionMeta` 加 `last_compressed_message_at` |
|
||||||
|
| 2 | `storage/mod.rs` | DDL migration + upsert/get_session 加列 |
|
||||||
|
| 3 | `storage/mod.rs` | 新增 `get_max_message_seq`, `load_messages_after_timestamp` |
|
||||||
|
| 4 | `storage/memory.rs` | 新增 `load_session_timelines` |
|
||||||
|
| 5 | `agent/context_compressor.rs` | 返回值改为 `CompressionResult` 含 `created_timelines` |
|
||||||
|
| 6 | `session/session.rs` | `Session` 加字段,`persist_session_meta` 加字段 |
|
||||||
|
| 7 | `session/session.rs` | `from_storage()` 重写恢复逻辑 |
|
||||||
|
| 8 | `session/session.rs` | `handle_message()` 压缩后记录标记 |
|
||||||
|
| 9 | `session/session.rs` | `/compact` 命令压缩后记录标记 |
|
||||||
|
| 10 | `session/session.rs` | `build_system_prompt()` 注入 `last_compressed_message_at` |
|
||||||
674
docs/plans/2026-05-10-incremental-session-recovery.md
Normal file
674
docs/plans/2026-05-10-incremental-session-recovery.md
Normal file
@ -0,0 +1,674 @@
|
|||||||
|
# 启动增量恢复 Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** PicoBot 重启后不再全量加载 messages 表,而是基于 `last_compressed_message_at` 标记增量恢复,用 Timeline 替代已压缩部分。
|
||||||
|
|
||||||
|
**Architecture:** 在 `sessions` 表加 `last_compressed_message_at` 列,`compress_if_needed` 返回值增加 `created_timelines` 标志,恢复时按时间戳增量加载消息并用近 3 条 Timeline 前置,`seq_counter` 统一从 SQLite 查 MAX(seq)。
|
||||||
|
|
||||||
|
**Tech Stack:** Rust, sqlx (SQLite), tokio
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: SessionMeta 和数据库 DDL 加列
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/session.rs:15`
|
||||||
|
- Modify: `src/storage/mod.rs:44-45` (DDL), `:172-180` (migration)
|
||||||
|
- Modify: `src/storage/mod.rs:317-326` (upsert_session SQL + ON CONFLICT)
|
||||||
|
- Modify: `src/storage/mod.rs:345-369` (get_session SELECT + struct)
|
||||||
|
- Modify: `src/storage/mod.rs:380-406`, `:454-479`, `:564-588`, `:728`, `:754` (list_sessions 及测试 mock)
|
||||||
|
|
||||||
|
**Step 1: 在 `src/storage/session.rs` SessionMeta 加字段**
|
||||||
|
|
||||||
|
在 `last_consolidated_at: Option<i64>` 后加一行:
|
||||||
|
```rust
|
||||||
|
pub last_compressed_message_at: Option<i64>,
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: DDL schema 加列**
|
||||||
|
|
||||||
|
在 `src/storage/mod.rs` 的 CREATE TABLE sessions 中 (line 44),`last_consolidated_at INTEGER` 后加逗号和:
|
||||||
|
```sql
|
||||||
|
last_compressed_message_at INTEGER
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: migration 加列**
|
||||||
|
|
||||||
|
在 `src/storage/mod.rs` line 182 之后(现有 migration 的 `); .ok();` 之后),添加新 migration:
|
||||||
|
```rust
|
||||||
|
// Migration: add last_compressed_message_at column if not exists
|
||||||
|
sqlx::query(
|
||||||
|
r#"ALTER TABLE sessions ADD COLUMN last_compressed_message_at INTEGER"#,
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: upsert_session SQL 加列**
|
||||||
|
|
||||||
|
`src/storage/mod.rs` line 317: INSERT 列列表加 `last_compressed_message_at`,VALUES 加 `?`,ON CONFLICT DO UPDATE SET 加 `last_compressed_message_at = excluded.last_compressed_message_at`。line 338 后加 `.bind(meta.last_compressed_message_at)`。
|
||||||
|
|
||||||
|
**Step 5: get_session SELECT 加列**
|
||||||
|
|
||||||
|
`src/storage/mod.rs` line 348: SELECT 列加 `last_compressed_message_at`。line 368 后加:
|
||||||
|
```rust
|
||||||
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: 其他 SELECT sessions 的地方(list_sessions 多个变体)**
|
||||||
|
|
||||||
|
同样补 `last_compressed_message_at` 到 SELECT 列和 struct 构造。以及测试中的 mock SessionMeta 构造(line 728, 754 等)。
|
||||||
|
|
||||||
|
**Step 7: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 8: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/session.rs src/storage/mod.rs
|
||||||
|
git commit -m "feat(storage): add last_compressed_message_at column to sessions"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Storage 新增加载方法
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/storage/mod.rs` (在 load_messages 之后)
|
||||||
|
- Modify: `src/storage/memory.rs` (在 cleanup_old_timelines 之后)
|
||||||
|
|
||||||
|
**Step 1: `get_max_message_seq`**
|
||||||
|
|
||||||
|
在 `src/storage/mod.rs` 中 `load_messages` 函数后面添加:
|
||||||
|
```rust
|
||||||
|
pub async fn get_max_message_seq(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||||
|
let row = sqlx::query(
|
||||||
|
"SELECT COALESCE(MAX(seq), 0) as max_seq FROM messages WHERE session_id = ?",
|
||||||
|
)
|
||||||
|
.bind(session_id)
|
||||||
|
.fetch_one(self.pool())
|
||||||
|
.await?;
|
||||||
|
Ok(row.get::<i64, _>("max_seq"))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: `load_messages_after_timestamp`**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub async fn load_messages_after_timestamp(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
after_ts: i64,
|
||||||
|
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ? AND created_at > ?
|
||||||
|
ORDER BY seq ASC
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(after_ts)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(rows.into_iter().map(|row| crate::storage::message::MessageMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
session_id: row.get("session_id"),
|
||||||
|
seq: row.get("seq"),
|
||||||
|
role: row.get("role"),
|
||||||
|
content: row.get("content"),
|
||||||
|
media_refs: row.get("media_refs"),
|
||||||
|
tool_call_id: row.get("tool_call_id"),
|
||||||
|
tool_name: row.get("tool_name"),
|
||||||
|
tool_calls: row.get("tool_calls"),
|
||||||
|
source: row.get("source"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
}).collect())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: `load_session_timelines`**
|
||||||
|
|
||||||
|
在 `src/storage/memory.rs` 的 `cleanup_old_timelines` 之后(line 252 的 `}` 之前)添加:
|
||||||
|
```rust
|
||||||
|
pub async fn load_session_timelines(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
limit: usize,
|
||||||
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
r#"
|
||||||
|
SELECT id, key, content, category, importance,
|
||||||
|
session_id, created_at, updated_at
|
||||||
|
FROM memories
|
||||||
|
WHERE session_id = ? AND category = 'timeline'
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(session_id)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
parse_memory_rows(&rows)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/storage/mod.rs src/storage/memory.rs
|
||||||
|
git commit -m "feat(storage): add load_messages_after_timestamp, load_session_timelines, get_max_message_seq"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: ContextCompressor 引入 CompressionResult
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/agent/context_compressor.rs:172-274` (compress_if_needed)
|
||||||
|
- Modify: `src/agent/context_compressor.rs:320-402` (compress_once)
|
||||||
|
|
||||||
|
**Step 1: 定义 CompressionResult**
|
||||||
|
|
||||||
|
在 context_compressor.rs 中 `ContextCompressor` struct 定义之后添加:
|
||||||
|
```rust
|
||||||
|
pub struct CompressionResult {
|
||||||
|
pub history: Vec<ChatMessage>,
|
||||||
|
pub created_timelines: bool,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 修改 compress_if_needed 签名和返回**
|
||||||
|
|
||||||
|
将 `pub async fn compress_if_needed(&self, mut history: Vec<ChatMessage>) -> Result<Vec<ChatMessage>, AgentError>` 改为:
|
||||||
|
```rust
|
||||||
|
pub async fn compress_if_needed(
|
||||||
|
&self,
|
||||||
|
mut history: Vec<ChatMessage>,
|
||||||
|
) -> Result<CompressionResult, AgentError> {
|
||||||
|
```
|
||||||
|
|
||||||
|
内部的 `return Ok(history)` 改为 `return Ok(CompressionResult { history, created_timelines: false })`(Tier 1 fast trim 和不需要压缩时)。
|
||||||
|
|
||||||
|
**Step 3: 修改 LLM summarization pass 部分**
|
||||||
|
|
||||||
|
在压缩循环中维护一个 `created_timelines` 标志:
|
||||||
|
```rust
|
||||||
|
let mut created_timelines = false;
|
||||||
|
for pass in 0..self.config.max_passes {
|
||||||
|
// ...
|
||||||
|
match self.compress_once(...).await {
|
||||||
|
Ok(Some(compressed)) => {
|
||||||
|
current_history = compressed;
|
||||||
|
created_timelines = true;
|
||||||
|
}
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
最后返回:
|
||||||
|
```rust
|
||||||
|
Ok(CompressionResult { history: current_history, created_timelines })
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 更新所有 compress_if_needed 调用方**
|
||||||
|
|
||||||
|
所有 `compress_if_needed(history)` 改为 `compress_if_needed(history).await?.history`。在 `handle_message` 和 `/compact` 中还需要用到 `created_timelines`。
|
||||||
|
|
||||||
|
**Step 5: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/agent/context_compressor.rs src/session/session.rs
|
||||||
|
git commit -m "feat(compressor): return CompressionResult with created_timelines flag"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Session 结构体和持久化
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/session/session.rs:52-74` (Session struct)
|
||||||
|
- Modify: `src/session/session.rs:76-121` (Session::new)
|
||||||
|
- Modify: `src/session/session.rs:298-320` (persist_session_meta)
|
||||||
|
|
||||||
|
**Step 1: Session struct 加字段**
|
||||||
|
|
||||||
|
在 `pub last_consolidated_at: Option<i64>` 后加:
|
||||||
|
```rust
|
||||||
|
pub last_compressed_message_at: Option<i64>,
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Session::new 初始化**
|
||||||
|
|
||||||
|
在 `last_consolidated_at: None` 后加:
|
||||||
|
```rust
|
||||||
|
last_compressed_message_at: None,
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: persist_session_meta 加字段**
|
||||||
|
|
||||||
|
在 `last_consolidated_at: self.last_consolidated_at` 后加:
|
||||||
|
```rust
|
||||||
|
last_compressed_message_at: self.last_compressed_message_at,
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/session/session.rs
|
||||||
|
git commit -m "feat(session): add last_compressed_message_at field to Session and persist"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Session::from_storage() 增量恢复
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/session/session.rs:124-189` (from_storage)
|
||||||
|
|
||||||
|
**Step 1: 重写 from_storage**
|
||||||
|
|
||||||
|
替换现有实现为:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub async fn from_storage(
|
||||||
|
id: UnifiedSessionId,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
storage: StdArc<Storage>,
|
||||||
|
memory_manager: Arc<crate::memory::MemoryManager>,
|
||||||
|
) -> Result<Self, AgentError> {
|
||||||
|
let session_meta = storage.get_session(&id.to_string()).await
|
||||||
|
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
|
||||||
|
|
||||||
|
let mut provider_box = create_provider(provider_config.clone())
|
||||||
|
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||||
|
provider_box.set_storage(storage.clone());
|
||||||
|
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||||
|
|
||||||
|
let compressor_config = ContextCompressionConfig {
|
||||||
|
protect_first_n: 2,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut compressor = ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config, memory_manager.clone());
|
||||||
|
compressor.set_session_id(Some(id.to_string()));
|
||||||
|
|
||||||
|
// Determine recovery strategy
|
||||||
|
let mut chat_messages: Vec<ChatMessage> = Vec::new();
|
||||||
|
|
||||||
|
if let Some(after_ts) = session_meta.last_compressed_message_at {
|
||||||
|
// Load last 4 timelines to determine if there are > 3
|
||||||
|
let timelines = storage
|
||||||
|
.load_session_timelines(&id.to_string(), 4)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let has_more_timelines = timelines.len() > 3;
|
||||||
|
|
||||||
|
// Insert hint if more timelines exist
|
||||||
|
if has_more_timelines {
|
||||||
|
chat_messages.push(ChatMessage::user(
|
||||||
|
"[Earlier conversation summaries exist. \
|
||||||
|
Use `timeline_recall` to search if needed.]"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert latest 3 timelines as context (reversed: oldest first)
|
||||||
|
for tl in timelines.iter().take(3).rev() {
|
||||||
|
chat_messages.push(ChatMessage::user(format!(
|
||||||
|
"[Previous Context]\n{}", tl.content
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load raw messages after compressed timestamp
|
||||||
|
let tail = storage
|
||||||
|
.load_messages_after_timestamp(&id.to_string(), after_ts)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let mut tail_msgs: Vec<ChatMessage> = tail.into_iter().map(|m| {
|
||||||
|
ChatMessage {
|
||||||
|
id: m.id,
|
||||||
|
role: m.role,
|
||||||
|
content: m.content,
|
||||||
|
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||||
|
timestamp: m.created_at,
|
||||||
|
tool_call_id: m.tool_call_id,
|
||||||
|
tool_name: m.tool_name,
|
||||||
|
tool_calls: m.tool_calls
|
||||||
|
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||||
|
.filter(|v| !v.is_empty()),
|
||||||
|
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||||
|
}
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
repair_tool_call_chains(&mut tail_msgs);
|
||||||
|
chat_messages.extend(tail_msgs);
|
||||||
|
} else {
|
||||||
|
// No prior compression — load all messages (existing behavior)
|
||||||
|
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||||
|
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||||
|
|
||||||
|
chat_messages = messages.into_iter().map(|m| {
|
||||||
|
ChatMessage {
|
||||||
|
id: m.id,
|
||||||
|
role: m.role,
|
||||||
|
content: m.content,
|
||||||
|
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||||
|
timestamp: m.created_at,
|
||||||
|
tool_call_id: m.tool_call_id,
|
||||||
|
tool_name: m.tool_name,
|
||||||
|
tool_calls: m.tool_calls
|
||||||
|
.and_then(|tc| serde_json::from_str::<Vec<crate::providers::ToolCall>>(&tc).ok())
|
||||||
|
.filter(|v| !v.is_empty()),
|
||||||
|
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||||
|
}
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
repair_tool_call_chains(&mut chat_messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
// seq_counter from actual DB max
|
||||||
|
let max_seq = storage
|
||||||
|
.get_max_message_seq(&id.to_string())
|
||||||
|
.await
|
||||||
|
.unwrap_or(0);
|
||||||
|
let seq_counter = max_seq + 1;
|
||||||
|
let total_message_count = session_meta.message_count;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
id: id.clone(),
|
||||||
|
title: session_meta.title,
|
||||||
|
created_at: session_meta.created_at,
|
||||||
|
last_active_at: session_meta.last_active_at,
|
||||||
|
message_count: session_meta.message_count,
|
||||||
|
total_message_count,
|
||||||
|
messages: chat_messages,
|
||||||
|
seq_counter,
|
||||||
|
provider_config: provider_config.clone(),
|
||||||
|
provider: provider.clone(),
|
||||||
|
tools,
|
||||||
|
compressor,
|
||||||
|
storage: Some(storage),
|
||||||
|
routing_info: session_meta.routing_info.unwrap_or_default(),
|
||||||
|
last_consolidated_at: session_meta.last_consolidated_at,
|
||||||
|
last_compressed_message_at: session_meta.last_compressed_message_at,
|
||||||
|
memory_manager,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/session/session.rs
|
||||||
|
git commit -m "feat(session): incremental recovery from storage using compressed timeline"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: 系统提示词加历史会话提示
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/agent/system_prompt.rs:289-304` (MemorySection)
|
||||||
|
- Modify: `src/agent/system_prompt.rs:16-23` (PromptContext)
|
||||||
|
- Modify: `src/agent/system_prompt.rs:343-358` (build_system_prompt free function)
|
||||||
|
- Modify: `src/session/session.rs:411-426` (build_system_prompt)
|
||||||
|
|
||||||
|
**Step 1: PromptContext 加 has_compressed_history 字段**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct PromptContext<'a> {
|
||||||
|
pub workspace_dir: &'a Path,
|
||||||
|
pub model_name: &'a str,
|
||||||
|
pub tools: &'a ToolRegistry,
|
||||||
|
pub session_id: Option<&'a str>,
|
||||||
|
pub memory_context: Option<&'a str>,
|
||||||
|
pub has_compressed_history: bool,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 加 HistorySection**
|
||||||
|
|
||||||
|
在 MemorySection 后面添加:
|
||||||
|
```rust
|
||||||
|
pub struct HistorySection;
|
||||||
|
|
||||||
|
impl PromptSection for HistorySection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"history"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
||||||
|
if ctx.has_compressed_history {
|
||||||
|
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。".to_string()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 注册到 SystemPromptBuilder::with_defaults**
|
||||||
|
|
||||||
|
在 `with_defaults()` 的 sections vec 中 `Box::new(MemorySection)` 后加 `Box::new(HistorySection)`。
|
||||||
|
|
||||||
|
**Step 4: 更新 build_system_prompt 签名和调用**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub fn build_system_prompt(
|
||||||
|
workspace_dir: &Path,
|
||||||
|
model_name: &str,
|
||||||
|
tools: &ToolRegistry,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
memory_context: Option<&str>,
|
||||||
|
has_compressed_history: bool,
|
||||||
|
) -> String {
|
||||||
|
let ctx = PromptContext {
|
||||||
|
workspace_dir,
|
||||||
|
model_name,
|
||||||
|
tools,
|
||||||
|
session_id,
|
||||||
|
memory_context,
|
||||||
|
has_compressed_history,
|
||||||
|
};
|
||||||
|
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: 更新 Session::build_system_prompt**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
|
||||||
|
let base_prompt = build_system_prompt(
|
||||||
|
&self.provider_config.workspace_dir,
|
||||||
|
&self.provider_config.model_id,
|
||||||
|
&self.tools,
|
||||||
|
Some(&self.id.to_string()),
|
||||||
|
memory_context,
|
||||||
|
self.last_compressed_message_at.is_some(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if skills_prompt.trim().is_empty() {
|
||||||
|
base_prompt
|
||||||
|
} else {
|
||||||
|
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: 更新所有其他 build_system_prompt 调用方**
|
||||||
|
|
||||||
|
搜索 `build_system_prompt(` 的所有调用位置,每个都要加 `false` 参数。主要有 `agent/agent_loop.rs` 中的两个调用。
|
||||||
|
|
||||||
|
**Step 7: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 8: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/agent/system_prompt.rs src/session/session.rs src/agent/agent_loop.rs
|
||||||
|
git commit -m "feat(system-prompt): add history section for archived conversation context"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: handle_message 和 /compact 记录压缩标记
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/session/session.rs:1339-1355` (handle_message 压缩后)
|
||||||
|
- Modify: `src/session/session.rs:1372-1376` (handle_message 溢出重试)
|
||||||
|
- Modify: `src/session/session.rs:851-872` (/compact 命令)
|
||||||
|
|
||||||
|
**Step 1: handle_message 正常流**
|
||||||
|
|
||||||
|
在 `compress_if_needed(history).await?` 之后(line 1346),改为:
|
||||||
|
```rust
|
||||||
|
let result = session_guard.compressor
|
||||||
|
.compress_if_needed(history)
|
||||||
|
.await?;
|
||||||
|
if result.created_timelines {
|
||||||
|
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||||
|
if let Err(e) = session_guard.persist_session_meta().await {
|
||||||
|
tracing::warn!(error = %e, "Failed to persist compressed message marker");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut history = result.history;
|
||||||
|
```
|
||||||
|
|
||||||
|
同时删除后面(line 1350-1355)单独的 `persist_session_meta` 调用(现在已合入上面的逻辑)。
|
||||||
|
|
||||||
|
**Step 2: handle_message 溢出重试流**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let raw = session_guard.get_history().to_vec();
|
||||||
|
let result = session_guard.compressor.compress_if_needed(raw).await?;
|
||||||
|
if result.created_timelines {
|
||||||
|
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||||
|
let _ = session_guard.persist_session_meta().await;
|
||||||
|
}
|
||||||
|
let mut retry = result.history;
|
||||||
|
retry.insert(0, ChatMessage::system(system_prompt));
|
||||||
|
agent.process(retry).await?
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: /compact 命令**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
let result = session_guard.compressor
|
||||||
|
.compress_if_needed(history)
|
||||||
|
.await?;
|
||||||
|
let compressed_count = result.history.len();
|
||||||
|
if result.created_timelines {
|
||||||
|
session_guard.last_compressed_message_at = Some(chrono::Utc::now().timestamp_millis());
|
||||||
|
let _ = session_guard.persist_session_meta().await;
|
||||||
|
}
|
||||||
|
session_guard.clear_history();
|
||||||
|
for msg in result.history {
|
||||||
|
session_guard.add_message(msg, false).await
|
||||||
|
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
同时确认 `compress_if_needed` 的 import 正常(已在 scope 中)。
|
||||||
|
|
||||||
|
**Step 4: 编译检查**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add src/session/session.rs
|
||||||
|
git commit -m "feat(session): record last_compressed_message_at after compression"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 8: 全局编译和测试
|
||||||
|
|
||||||
|
**Step 1: 全局编译**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo check 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
修复所有编译错误,确保全部文件一致。
|
||||||
|
|
||||||
|
**Step 2: 运行单元测试**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo test --lib 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 测试通过后 commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add -A
|
||||||
|
git commit -m "chore: fix remaining compilation and test issues for incremental recovery"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 运行 lint**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo clippy --lib 2>&1 | head -50
|
||||||
|
```
|
||||||
|
|
||||||
|
修复任何 warning。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 9: 验证 & 提交设计文档
|
||||||
|
|
||||||
|
**Step 1: 最终验证**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo test --lib 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Commit 设计文档**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add docs/plans/2026-05-10-incremental-session-recovery-design.md
|
||||||
|
git commit -m "docs: add incremental session recovery design doc"
|
||||||
|
```
|
||||||
2356
docs/superpowers/plans/2026-05-04-scheduled-tasks.md
Normal file
2356
docs/superpowers/plans/2026-05-04-scheduled-tasks.md
Normal file
File diff suppressed because it is too large
Load Diff
@ -5,7 +5,7 @@ always: true
|
|||||||
---
|
---
|
||||||
# About PicoBot
|
# About PicoBot
|
||||||
|
|
||||||
PicoBot 是一个基于 Rust 的个人 AI 助手运行时,包含本地 Gateway、CLI TUI 客户端、飞书渠道、SQLite 会话持久化、长期记忆、定时任务、Skill 系统、MCP 工具接入和子 Agent 委托能力。
|
PicoBot 是一个基于 Rust 的个人 AI 助手,支持多渠道(飞书、CLI)、长记忆、定时任务、Skill 系统等。
|
||||||
|
|
||||||
## 目录索引
|
## 目录索引
|
||||||
|
|
||||||
@ -13,10 +13,10 @@ PicoBot 是一个基于 Rust 的个人 AI 助手运行时,包含本地 Gateway
|
|||||||
|
|
||||||
| 文件 | 内容 |
|
| 文件 | 内容 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| `references/config.md` | 配置字段详解:providers、models、agents、gateway、client、channels、memory、mcp、browser |
|
| `references/config.md` | 配置字段详解:providers、models、agents、gateway、memory、channels、mcp |
|
||||||
| `references/db-schema.md` | 数据库表结构:sessions、messages、memories、scheduled_jobs、llm_calls、background_tasks |
|
| `references/db-schema.md` | 数据库表结构:sessions、messages、memories、scheduled_jobs、llm_calls |
|
||||||
| `references/architecture.md` | 核心架构:数据流、会话系统、上下文压缩、记忆系统、Skill 优先级、MCP、子 Agent |
|
| `references/architecture.md` | 核心架构:数据流、会话系统、上下文压缩、记忆系统、Skill 优先级机制 |
|
||||||
| `references/faq.md` | 常见问题:模型切换、渠道添加、Skill 安装、历史查询、定时任务、MCP 等 |
|
| `references/faq.md` | 常见问题:模型切换、渠道添加、Skill 安装、历史查询、定时任务等 |
|
||||||
| `references/commands.md` | 常用命令:编译、启动网关、启动客户端、运行测试 |
|
| `references/commands.md` | 常用命令:编译、启动网关、启动客户端、运行测试 |
|
||||||
| `assets/config.example.json` | config.json 完整示例 |
|
| `assets/config.example.json` | config.json 完整示例 |
|
||||||
|
|
||||||
|
|||||||
@ -72,15 +72,5 @@
|
|||||||
"timeline_retention_days": 90,
|
"timeline_retention_days": 90,
|
||||||
"max_failures_before_degrade": 3
|
"max_failures_before_degrade": 3
|
||||||
},
|
},
|
||||||
"mcp": {
|
|
||||||
"servers": [],
|
|
||||||
"tool_timeout_secs": 180
|
|
||||||
},
|
|
||||||
"browser": {
|
|
||||||
"enabled": false,
|
|
||||||
"webdriver_url": "http://127.0.0.1:9515",
|
|
||||||
"headless": true,
|
|
||||||
"chrome_path": null
|
|
||||||
},
|
|
||||||
"workspace_dir": "~/.picobot/workspace"
|
"workspace_dir": "~/.picobot/workspace"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,9 +17,9 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
|||||||
| `channels` | 外部集成(飞书、CLI),仅收发消息 |
|
| `channels` | 外部集成(飞书、CLI),仅收发消息 |
|
||||||
| `bus` | 异步消息队列,纯队列不路由 |
|
| `bus` | 异步消息队列,纯队列不路由 |
|
||||||
| `session` | 会话生命周期管理、dialog 操作 |
|
| `session` | 会话生命周期管理、dialog 操作 |
|
||||||
| `agent` | LLM 调用循环、工具执行、上下文压缩、媒体处理、子 Agent |
|
| `agent` | LLM 调用循环、工具执行、上下文压缩 |
|
||||||
| `providers` | LLM API 客户端(OpenAI 兼容、Anthropic) |
|
| `providers` | LLM API 客户端(OpenAI 兼容、Anthropic) |
|
||||||
| `tools` | Agent 工具(bash、文件操作、搜索、HTTP、web、browser、memory、delegate 等) |
|
| `tools` | Agent 工具(bash、文件操作、HTTP、web、get_skill 等) |
|
||||||
| `skills` | Skill 加载、管理和 prompt 构建 |
|
| `skills` | Skill 加载、管理和 prompt 构建 |
|
||||||
| `storage` | SQLite 持久化 |
|
| `storage` | SQLite 持久化 |
|
||||||
| `scheduler` | Cron 作业调度 |
|
| `scheduler` | Cron 作业调度 |
|
||||||
@ -37,8 +37,6 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
|||||||
- AgentLoop 无状态,接收 dialog 事件调用 LLM、执行工具
|
- AgentLoop 无状态,接收 dialog 事件调用 LLM、执行工具
|
||||||
- Providers 是纯 HTTP 客户端,无 bus/session/channel 感知
|
- Providers 是纯 HTTP 客户端,无 bus/session/channel 感知
|
||||||
- Tools 接收原始参数,返回字符串结果
|
- Tools 接收原始参数,返回字符串结果
|
||||||
- MCP 工具在 Gateway 初始化时连接服务器、发现工具,并包装成普通 Tool 注册到 ToolRegistry
|
|
||||||
- 子 Agent 由 `delegate` 工具创建,复用 provider 配置和按需过滤后的工具集;后台任务结果通过 MessageBus 发回原会话
|
|
||||||
|
|
||||||
## 关键约束
|
## 关键约束
|
||||||
|
|
||||||
@ -47,7 +45,6 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
|||||||
- ChannelManager 持有 MessageBus 和所有 channel
|
- ChannelManager 持有 MessageBus 和所有 channel
|
||||||
- OutboundDispatcher 通过 ChannelManager 路由出站消息
|
- OutboundDispatcher 通过 ChannelManager 路由出站消息
|
||||||
- Config `.env` 加载使用 `unsafe { env::set_var(...) }`
|
- Config `.env` 加载使用 `unsafe { env::set_var(...) }`
|
||||||
- `browser` 工具只有在 `browser.enabled=true` 时注册,依赖 Chrome/Chromium 与 WebDriver
|
|
||||||
|
|
||||||
## 上下文压缩
|
## 上下文压缩
|
||||||
|
|
||||||
@ -195,48 +192,3 @@ LLM 对话上下文接近 token 限制 (默认 128K × 70%) 时自动触发压
|
|||||||
| 有压缩历史时 | `HistorySection` 提示 LLM 使用 `timeline_recall` |
|
| 有压缩历史时 | `HistorySection` 提示 LLM 使用 `timeline_recall` |
|
||||||
| 压缩完成后 | 摘要自动存储为 Timeline 记忆 |
|
| 压缩完成后 | 摘要自动存储为 Timeline 记忆 |
|
||||||
| 空闲时 | 可配置自动 consolidation(`idle_consolidation_minutes`) |
|
| 空闲时 | 可配置自动 consolidation(`idle_consolidation_minutes`) |
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## MCP 工具集成
|
|
||||||
|
|
||||||
Gateway 初始化时读取 `config.mcp.servers`:
|
|
||||||
|
|
||||||
1. 按服务器配置连接 `stdio`、`sse` 或 `streamable-http` 传输
|
|
||||||
2. 调用 MCP `list_tools`
|
|
||||||
3. 将每个 MCP tool 包装为 `McpToolWrapper`
|
|
||||||
4. 注册到当前 session 的 `ToolRegistry`
|
|
||||||
|
|
||||||
`/mcp` 斜杠命令会显示 MCP 服务器连接状态和工具列表。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 子 Agent / delegate
|
|
||||||
|
|
||||||
`delegate` 工具用于把独立任务交给子 Agent:
|
|
||||||
|
|
||||||
| 模式 | 行为 |
|
|
||||||
|------|------|
|
|
||||||
| `inline` | 当前轮阻塞等待子 Agent 返回 |
|
|
||||||
| `background` | 后台运行,完成后通过原 channel/chat 通知 |
|
|
||||||
| `parallel` | 多个子 Agent 并发执行并聚合结果 |
|
|
||||||
|
|
||||||
默认工具集是只读工具:`file_read`、`file_search`、`content_search`、`web_fetch`、`http_request`、`calculator`。调用时可通过 `allowed_tools` 显式放开其他工具。后台任务会写入 `background_tasks` 表,默认 24 小时后清理。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 当前斜杠命令
|
|
||||||
|
|
||||||
| 命令 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `/new` | 创建新对话 |
|
|
||||||
| `/sessions` | 列出最近对话 |
|
|
||||||
| `/switch <dialog_id>` | 切换到指定对话 |
|
|
||||||
| `/rename <title>` | 重命名当前对话 |
|
|
||||||
| `/delete` | 删除当前对话 |
|
|
||||||
| `/compact` | 手动触发上下文压缩 |
|
|
||||||
| `/info` | 显示当前对话信息 |
|
|
||||||
| `/dump` | 保存当前对话为 markdown |
|
|
||||||
| `/?`, `/help` | 显示帮助 |
|
|
||||||
| `/mcp` | 显示 MCP 状态 |
|
|
||||||
| `/stop` | 停止当前任务并清空消息队列 |
|
|
||||||
|
|||||||
@ -14,9 +14,8 @@
|
|||||||
"client": {}, // 客户端配置
|
"client": {}, // 客户端配置
|
||||||
"channels": {}, // 渠道配置
|
"channels": {}, // 渠道配置
|
||||||
"memory": {}, // 记忆系统配置
|
"memory": {}, // 记忆系统配置
|
||||||
"workspace_dir": "", // 工作目录,默认 ~/.picobot/workspace
|
"workspace_dir": // 工作目录,默认 ~/.picobot/workspace
|
||||||
"mcp": {}, // MCP 服务器配置
|
"mcp": {} // MCP 服务器配置
|
||||||
"browser": {} // 可选浏览器自动化配置
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -58,17 +57,8 @@
|
|||||||
| `session_ttl_hours` | int | - | 会话过期小时数 |
|
| `session_ttl_hours` | int | - | 会话过期小时数 |
|
||||||
| `session_db_path` | string | - | SQLite 数据库路径,默认在 workspace 下 |
|
| `session_db_path` | string | - | SQLite 数据库路径,默认在 workspace 下 |
|
||||||
| `cleanup_interval_minutes` | int | - | 清理间隔 |
|
| `cleanup_interval_minutes` | int | - | 清理间隔 |
|
||||||
| `max_concurrent_background_tasks` | int | 10 | delegate 后台子任务最大并发数 |
|
|
||||||
| `scheduler` | object | - | 调度器配置 |
|
| `scheduler` | object | - | 调度器配置 |
|
||||||
|
|
||||||
### gateway.scheduler 字段
|
|
||||||
|
|
||||||
| 字段 | 类型 | 默认 | 说明 |
|
|
||||||
|------|------|------|------|
|
|
||||||
| `enabled` | bool | true | 是否启动调度器并注册 cron 工具 |
|
|
||||||
| `poll_interval_secs` | int | 60 | 检查到期任务的轮询间隔 |
|
|
||||||
| `max_concurrent` | int | 1 | 最大并发任务数,当前实现预留 |
|
|
||||||
|
|
||||||
## memory 字段
|
## memory 字段
|
||||||
|
|
||||||
| 字段 | 类型 | 默认 | 说明 |
|
| 字段 | 类型 | 默认 | 说明 |
|
||||||
@ -104,21 +94,8 @@ MCP 服务器单条配置:
|
|||||||
| 字段 | 说明 |
|
| 字段 | 说明 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| `name` | 服务器名称 |
|
| `name` | 服务器名称 |
|
||||||
| `transport` | 传输方式: `stdio`、`sse`、`streamable-http` |
|
| `transport` | 传输方式: `Stdio`、`Sse`、`streamable-http` |
|
||||||
| `command` | 启动命令(stdio 模式) |
|
| `command` | 启动命令(Stdio 模式) |
|
||||||
| `args` | 命令参数 |
|
| `args` | 命令参数 |
|
||||||
| `env` | 子进程环境变量 |
|
| `url` | URL(Sse / streamable-http 模式) |
|
||||||
| `url` | URL(sse / streamable-http 模式) |
|
|
||||||
| `headers` | HTTP 传输额外请求头 |
|
|
||||||
| `tool_timeout_secs` | 单独的超时设置 |
|
| `tool_timeout_secs` | 单独的超时设置 |
|
||||||
|
|
||||||
## browser 字段
|
|
||||||
|
|
||||||
浏览器工具默认关闭,开启后注册 `browser` 工具。依赖 Chrome/Chromium 与 chromedriver/WebDriver。
|
|
||||||
|
|
||||||
| 字段 | 类型 | 默认 | 说明 |
|
|
||||||
|------|------|------|------|
|
|
||||||
| `enabled` | bool | false | 是否启用浏览器工具 |
|
|
||||||
| `webdriver_url` | string | http://127.0.0.1:9515 | WebDriver 服务地址 |
|
|
||||||
| `headless` | bool | true | 是否无头运行 |
|
|
||||||
| `chrome_path` | string | - | 自定义 Chrome/Chromium 路径 |
|
|
||||||
|
|||||||
@ -36,28 +36,6 @@
|
|||||||
| `tool_calls` | TEXT | 工具调用参数 JSON |
|
| `tool_calls` | TEXT | 工具调用参数 JSON |
|
||||||
| `source` | TEXT | 消息来源(跨会话消息时标记来源 session_id) |
|
| `source` | TEXT | 消息来源(跨会话消息时标记来源 session_id) |
|
||||||
| `created_at` | INTEGER | 创建时间(unix 秒) |
|
| `created_at` | INTEGER | 创建时间(unix 秒) |
|
||||||
| `reasoning_content` | TEXT | provider 返回的推理内容(如有) |
|
|
||||||
|
|
||||||
## background_tasks 表
|
|
||||||
|
|
||||||
delegate 后台子任务表。`session_id` 不使用数据库外键,因为 session 使用软删除,关联关系由应用层维护。
|
|
||||||
|
|
||||||
| 字段 | 类型 | 说明 |
|
|
||||||
|------|------|------|
|
|
||||||
| `id` | TEXT PK | 后台任务 ID |
|
|
||||||
| `session_id` | TEXT | 所属会话 |
|
|
||||||
| `channel` | TEXT | 回传渠道 |
|
|
||||||
| `chat_id` | TEXT | 回传目标对话 |
|
|
||||||
| `prompt` | TEXT | 子任务提示 |
|
|
||||||
| `allowed_tools` | TEXT | 允许工具 JSON |
|
|
||||||
| `status` | TEXT | pending / running / completed / failed / cancelled |
|
|
||||||
| `result` | TEXT | 执行结果 |
|
|
||||||
| `error` | TEXT | 错误信息 |
|
|
||||||
| `tool_calls_count` | INTEGER | 工具调用次数 |
|
|
||||||
| `iterations` | INTEGER | Agent 迭代次数 |
|
|
||||||
| `started_at` | INTEGER | 开始时间 |
|
|
||||||
| `finished_at` | INTEGER | 结束时间 |
|
|
||||||
| `created_at` | INTEGER | 创建时间 |
|
|
||||||
|
|
||||||
## memories 表
|
## memories 表
|
||||||
|
|
||||||
|
|||||||
@ -124,51 +124,9 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## delegate — 子 Agent 委托
|
## file_read / file_write / file_edit / file_search — 文件操作
|
||||||
|
|
||||||
创建子 Agent 处理独立任务。
|
工作目录内的文件读写编辑和搜索。详细的参数定义见各工具的 parameters_schema。
|
||||||
|
|
||||||
| 参数 | 必填 | 说明 |
|
|
||||||
|------|------|------|
|
|
||||||
| `action` | 是 | `run`, `check_task`, `cancel_task`, `list_tasks` |
|
|
||||||
| `prompt` | run 必填 | 子任务描述 |
|
|
||||||
| `mode` | 否 | `inline`, `background`, `parallel`,默认 `inline` |
|
|
||||||
| `allowed_tools` | 否 | 子 Agent 可用工具列表;默认只读工具集 |
|
|
||||||
| `max_iterations` | 否 | 最大迭代次数,默认 99 |
|
|
||||||
| `timeout_secs` | 否 | 超时秒数,默认 3600 |
|
|
||||||
| `tasks` | parallel 必填 | 并行子任务数组 |
|
|
||||||
| `task_id` | 查询/取消必填 | 后台任务 ID |
|
|
||||||
|
|
||||||
默认只读工具集:`file_read`、`file_search`、`content_search`、`web_fetch`、`http_request`、`calculator`。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## browser — 浏览器自动化
|
|
||||||
|
|
||||||
仅在 `browser.enabled=true` 时注册。底层使用 WebDriver/Chrome。
|
|
||||||
|
|
||||||
| action | 说明 |
|
|
||||||
|--------|------|
|
|
||||||
| `open` | 打开 URL |
|
|
||||||
| `snapshot` | 获取页面结构快照 |
|
|
||||||
| `click`, `click_at` | 点击元素或坐标 |
|
|
||||||
| `fill`, `type`, `press` | 输入文本或按键 |
|
|
||||||
| `get_text`, `get_title`, `get_url` | 读取页面信息 |
|
|
||||||
| `screenshot` | 截图,可写入文件或返回 base64 |
|
|
||||||
| `focus`, `hover`, `scroll`, `wait` | 常见交互和等待 |
|
|
||||||
| `close` | 关闭浏览器会话 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## MCP 工具
|
|
||||||
|
|
||||||
如果 `config.mcp.servers` 配置了 MCP 服务器,Gateway 启动时会连接服务器、发现工具,并把 MCP 工具包装后注册到 ToolRegistry。使用 `/mcp` 查看当前连接状态和工具列表。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## file_read / file_write / file_edit / file_search / content_search — 文件操作和搜索
|
|
||||||
|
|
||||||
工作目录内的文件读写编辑、文件名搜索和内容搜索。详细的参数定义见各工具的 parameters_schema。
|
|
||||||
|
|
||||||
## bash — 执行命令
|
## bash — 执行命令
|
||||||
|
|
||||||
|
|||||||
@ -72,15 +72,5 @@
|
|||||||
"timeline_retention_days": 90,
|
"timeline_retention_days": 90,
|
||||||
"max_failures_before_degrade": 3
|
"max_failures_before_degrade": 3
|
||||||
},
|
},
|
||||||
"mcp": {
|
|
||||||
"servers": [],
|
|
||||||
"tool_timeout_secs": 180
|
|
||||||
},
|
|
||||||
"browser": {
|
|
||||||
"enabled": false,
|
|
||||||
"webdriver_url": "http://127.0.0.1:9515",
|
|
||||||
"headless": true,
|
|
||||||
"chrome_path": null
|
|
||||||
},
|
|
||||||
"workspace_dir": "~/.picobot/workspace"
|
"workspace_dir": "~/.picobot/workspace"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,8 +4,10 @@ use crate::agent::system_prompt::build_system_prompt;
|
|||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::{ChatMessage, MediaRef};
|
use crate::bus::{ChatMessage, MediaRef};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
|
use crate::observability::{
|
||||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
|
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
||||||
|
};
|
||||||
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
@ -226,7 +228,6 @@ pub struct AgentLoop {
|
|||||||
pub struct AgentProcessResult {
|
pub struct AgentProcessResult {
|
||||||
pub final_response: ChatMessage,
|
pub final_response: ChatMessage,
|
||||||
pub emitted_messages: Vec<ChatMessage>,
|
pub emitted_messages: Vec<ChatMessage>,
|
||||||
pub total_tokens: Option<u32>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
@ -254,10 +255,7 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new AgentLoop with provider created from config and given tools.
|
/// Create a new AgentLoop with provider created from config and given tools.
|
||||||
pub fn with_tools(
|
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
) -> Result<Self, AgentError> {
|
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
let model_name = provider_config.model_id.clone();
|
let model_name = provider_config.model_id.clone();
|
||||||
let workspace_dir = provider_config.workspace_dir.clone();
|
let workspace_dir = provider_config.workspace_dir.clone();
|
||||||
@ -280,13 +278,7 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new AgentLoop with an existing shared provider.
|
/// Create a new AgentLoop with an existing shared provider.
|
||||||
pub fn with_provider(
|
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
|
||||||
provider: Arc<dyn LLMProvider>,
|
|
||||||
max_iterations: usize,
|
|
||||||
model_name: String,
|
|
||||||
workspace_dir: PathBuf,
|
|
||||||
input_types: Vec<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
provider,
|
provider,
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
@ -348,9 +340,8 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Preemptive trim: truncate old tool results in-place when history is
|
/// Preemptive trim: truncate old tool results in-place when history is
|
||||||
/// approaching the context window limit. Old results (outside of `keep_recent`
|
/// approaching the context window limit. Only trims tool messages with
|
||||||
/// zone) are replaced with a short placeholder; recent results are truncated
|
/// content > TRIM_CHARS, preserving the most recent KEEP messages.
|
||||||
/// to `max_chars`.
|
|
||||||
fn preemptive_trim_old_tool_results(
|
fn preemptive_trim_old_tool_results(
|
||||||
&self,
|
&self,
|
||||||
messages: &mut [ChatMessage],
|
messages: &mut [ChatMessage],
|
||||||
@ -367,11 +358,11 @@ impl AgentLoop {
|
|||||||
if messages[i].content.len() <= max_chars {
|
if messages[i].content.len() <= max_chars {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let tool_name = messages[i].tool_name.as_deref().unwrap_or("unknown");
|
let removed = messages[i].content.len() - max_chars;
|
||||||
let chars = messages[i].content.len();
|
|
||||||
messages[i].content = format!(
|
messages[i].content = format!(
|
||||||
"[Tool output ({}) — {} chars, omitted from context]",
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
tool_name, chars
|
&messages[i].content[..messages[i].content.ceil_char_boundary(max_chars)],
|
||||||
|
removed
|
||||||
);
|
);
|
||||||
modified += 1;
|
modified += 1;
|
||||||
}
|
}
|
||||||
@ -386,12 +377,7 @@ impl AgentLoop {
|
|||||||
let content = if m.media_refs.is_empty() {
|
let content = if m.media_refs.is_empty() {
|
||||||
vec![ContentBlock::text(&m.content)]
|
vec![ContentBlock::text(&m.content)]
|
||||||
} else {
|
} else {
|
||||||
build_content_blocks(
|
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
|
||||||
&m.content,
|
|
||||||
&m.media_refs,
|
|
||||||
&self.input_types,
|
|
||||||
&self.media_registry,
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Message {
|
Message {
|
||||||
@ -411,28 +397,14 @@ impl AgentLoop {
|
|||||||
/// it loops back to the LLM with the tool results until either:
|
/// it loops back to the LLM with the tool results until either:
|
||||||
/// - The LLM returns no more tool calls (final response)
|
/// - The LLM returns no more tool calls (final response)
|
||||||
/// - Maximum iterations are reached
|
/// - Maximum iterations are reached
|
||||||
pub async fn process(
|
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
||||||
&self,
|
|
||||||
mut messages: Vec<ChatMessage>,
|
|
||||||
) -> Result<AgentProcessResult, AgentError> {
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||||
history_len = messages.len(),
|
|
||||||
max_iterations = self.max_iterations,
|
|
||||||
"Starting agent process"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build and inject system prompt if not present
|
// Build and inject system prompt if not present
|
||||||
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
||||||
if !has_system {
|
if !has_system {
|
||||||
let system_prompt = build_system_prompt(
|
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
|
||||||
&self.workspace_dir,
|
|
||||||
&self.model_name,
|
|
||||||
&self.tools,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||||
messages.insert(0, ChatMessage::system(system_prompt));
|
messages.insert(0, ChatMessage::system(system_prompt));
|
||||||
@ -441,7 +413,6 @@ impl AgentLoop {
|
|||||||
// Track tool calls for loop detection
|
// Track tool calls for loop detection
|
||||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||||
let mut emitted_messages = Vec::new();
|
let mut emitted_messages = Vec::new();
|
||||||
let mut accumulated_tokens: u32 = 0;
|
|
||||||
|
|
||||||
for iteration in 0..self.max_iterations {
|
for iteration in 0..self.max_iterations {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -453,7 +424,9 @@ impl AgentLoop {
|
|||||||
let estimated = estimate_tokens(&messages);
|
let estimated = estimate_tokens(&messages);
|
||||||
let danger = (self.context_window as f64 * 0.8) as usize;
|
let danger = (self.context_window as f64 * 0.8) as usize;
|
||||||
if estimated > danger {
|
if estimated > danger {
|
||||||
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
|
let trimmed = self.preemptive_trim_old_tool_results(
|
||||||
|
&mut messages, 2000, 4,
|
||||||
|
);
|
||||||
if trimmed > 0 {
|
if trimmed > 0 {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
@ -487,13 +460,12 @@ impl AgentLoop {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Call LLM
|
// Call LLM
|
||||||
let response = (*self.provider).chat(request).await.map_err(|e| {
|
let response = (*self.provider).chat(request).await
|
||||||
|
.map_err(|e| {
|
||||||
tracing::error!(error = %e, "LLM request failed");
|
tracing::error!(error = %e, "LLM request failed");
|
||||||
AgentError::LlmError(e.to_string())
|
AgentError::LlmError(e.to_string())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
accumulated_tokens += response.usage.total_tokens;
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
iteration,
|
iteration,
|
||||||
@ -510,15 +482,12 @@ impl AgentLoop {
|
|||||||
return Ok(AgentProcessResult {
|
return Ok(AgentProcessResult {
|
||||||
final_response: assistant_message,
|
final_response: assistant_message,
|
||||||
emitted_messages,
|
emitted_messages,
|
||||||
total_tokens: Some(accumulated_tokens),
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tool calls — log and notify immediately
|
// Execute tool calls — log and notify immediately
|
||||||
{
|
{
|
||||||
let tools_info: Vec<String> = response
|
let tools_info: Vec<String> = response.tool_calls.iter()
|
||||||
.tool_calls
|
|
||||||
.iter()
|
|
||||||
.map(|tc| {
|
.map(|tc| {
|
||||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||||
let s = format!("{}:{}", tc.name, args);
|
let s = format!("{}:{}", tc.name, args);
|
||||||
@ -547,9 +516,7 @@ impl AgentLoop {
|
|||||||
// Log function call with name and arguments
|
// Log function call with name and arguments
|
||||||
let args_str = match &tool_call.arguments {
|
let args_str = match &tool_call.arguments {
|
||||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||||
other => {
|
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||||
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||||
|
|
||||||
@ -589,11 +556,7 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// Loop continues to next iteration with updated messages
|
// Loop continues to next iteration with updated messages
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
||||||
iteration,
|
|
||||||
message_count = messages.len(),
|
|
||||||
"Tool execution complete, continuing to next iteration"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max iterations reached - ask LLM for a summary based on completed work
|
// Max iterations reached - ask LLM for a summary based on completed work
|
||||||
@ -602,7 +565,7 @@ impl AgentLoop {
|
|||||||
// Add a message asking for summary
|
// Add a message asking for summary
|
||||||
let summary_request = ChatMessage::user(
|
let summary_request = ChatMessage::user(
|
||||||
"You have reached the maximum number of tool call iterations. \
|
"You have reached the maximum number of tool call iterations. \
|
||||||
Please provide your best answer based on the work completed so far.",
|
Please provide your best answer based on the work completed so far."
|
||||||
);
|
);
|
||||||
messages.push(summary_request);
|
messages.push(summary_request);
|
||||||
|
|
||||||
@ -621,32 +584,24 @@ impl AgentLoop {
|
|||||||
|
|
||||||
match (*self.provider).chat(request).await {
|
match (*self.provider).chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
accumulated_tokens += response.usage.total_tokens;
|
|
||||||
let mut assistant_message = ChatMessage::assistant(response.content);
|
let mut assistant_message = ChatMessage::assistant(response.content);
|
||||||
assistant_message.reasoning_content = response.reasoning_content;
|
assistant_message.reasoning_content = response.reasoning_content;
|
||||||
emitted_messages.push(assistant_message.clone());
|
emitted_messages.push(assistant_message.clone());
|
||||||
Ok(AgentProcessResult {
|
Ok(AgentProcessResult {
|
||||||
final_response: assistant_message,
|
final_response: assistant_message,
|
||||||
emitted_messages,
|
emitted_messages,
|
||||||
total_tokens: Some(accumulated_tokens),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback if summary call fails
|
// Fallback if summary call fails
|
||||||
tracing::error!(error = %e, "Failed to get summary from LLM");
|
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||||
let final_message = ChatMessage::assistant(format!(
|
let final_message = ChatMessage::assistant(
|
||||||
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
|
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||||
self.max_iterations
|
);
|
||||||
));
|
|
||||||
emitted_messages.push(final_message.clone());
|
emitted_messages.push(final_message.clone());
|
||||||
Ok(AgentProcessResult {
|
Ok(AgentProcessResult {
|
||||||
final_response: final_message,
|
final_response: final_message,
|
||||||
emitted_messages,
|
emitted_messages,
|
||||||
total_tokens: if accumulated_tokens > 0 {
|
|
||||||
Some(accumulated_tokens)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -734,7 +689,10 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply duration
|
// Apply duration
|
||||||
ToolExecutionOutcome { duration, ..result }
|
ToolExecutionOutcome {
|
||||||
|
duration,
|
||||||
|
..result
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal tool execution without event tracking.
|
/// Internal tool execution without event tracking.
|
||||||
@ -756,12 +714,18 @@ impl AgentLoop {
|
|||||||
ToolExecutionOutcome::success(result.output)
|
ToolExecutionOutcome::success(result.output)
|
||||||
} else {
|
} else {
|
||||||
let error = result.error.unwrap_or_default();
|
let error = result.error.unwrap_or_default();
|
||||||
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
ToolExecutionOutcome::failure(
|
||||||
|
format!("Error: {}", error),
|
||||||
|
Some(error),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
||||||
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
ToolExecutionOutcome::failure(
|
||||||
|
format!("Error: {}", e),
|
||||||
|
Some(e.to_string()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -849,14 +813,8 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(provider_message.role, "assistant");
|
assert_eq!(provider_message.role, "assistant");
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||||
assert_eq!(
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||||
provider_message.tool_calls.as_ref().unwrap()[0].id,
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||||
"call_1"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
provider_message.tool_calls.as_ref().unwrap()[0].name,
|
|
||||||
"calculator"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -68,10 +68,6 @@ pub struct ContextCompressor {
|
|||||||
memory: Arc<MemoryManager>,
|
memory: Arc<MemoryManager>,
|
||||||
/// Current session ID for timeline memory writes.
|
/// Current session ID for timeline memory writes.
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
/// Message count sent in the last LLM call (used to split known/new history).
|
|
||||||
last_sent_message_count: Option<usize>,
|
|
||||||
/// Real total_tokens from the last API response.
|
|
||||||
last_api_total_tokens: Option<u32>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of context compression.
|
/// Result of context compression.
|
||||||
@ -80,15 +76,6 @@ pub struct CompressionResult {
|
|||||||
pub created_timelines: bool,
|
pub created_timelines: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Token budget state snapshot for diagnostics.
|
|
||||||
pub struct TokenInfo {
|
|
||||||
pub context_window: usize,
|
|
||||||
pub threshold: usize,
|
|
||||||
pub estimated_tokens: usize,
|
|
||||||
pub last_api_tokens: Option<u32>,
|
|
||||||
pub cache_active: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ContextCompressor {
|
impl ContextCompressor {
|
||||||
/// Create a new compressor with the given provider, context window size, and memory manager.
|
/// Create a new compressor with the given provider, context window size, and memory manager.
|
||||||
pub fn new(
|
pub fn new(
|
||||||
@ -103,8 +90,6 @@ impl ContextCompressor {
|
|||||||
provider,
|
provider,
|
||||||
memory,
|
memory,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
last_sent_message_count: None,
|
|
||||||
last_api_total_tokens: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,8 +107,6 @@ impl ContextCompressor {
|
|||||||
provider,
|
provider,
|
||||||
memory,
|
memory,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
last_sent_message_count: None,
|
|
||||||
last_api_total_tokens: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,91 +120,39 @@ impl ContextCompressor {
|
|||||||
self.context_window = window;
|
self.context_window = window;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Record the API's reported token usage from the last completed turn.
|
|
||||||
/// `msg_count`: number of messages sent to LLM in that call.
|
|
||||||
/// `tokens`: `total_tokens` from the API response.
|
|
||||||
pub fn set_last_api_info(&mut self, msg_count: usize, tokens: Option<u32>) {
|
|
||||||
self.last_sent_message_count = Some(msg_count);
|
|
||||||
self.last_api_total_tokens = tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invalidate the cached API token info — called after compression modifies messages.
|
|
||||||
fn invalidate_token_cache(&mut self) {
|
|
||||||
self.last_sent_message_count = None;
|
|
||||||
self.last_api_total_tokens = None;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hybrid token estimation: API-reported tokens for known history +
|
|
||||||
/// char/4 estimate for new messages since last API call.
|
|
||||||
fn token_estimate_with_history(&self, messages: &[ChatMessage]) -> usize {
|
|
||||||
match (self.last_api_total_tokens, self.last_sent_message_count) {
|
|
||||||
(Some(known), Some(known_count)) if messages.len() > known_count => {
|
|
||||||
let delta = &messages[known_count..];
|
|
||||||
known as usize + estimate_tokens(delta)
|
|
||||||
}
|
|
||||||
(Some(known), _) => known as usize,
|
|
||||||
_ => estimate_tokens(messages),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Always true — memory is always available (memory system is always on).
|
/// Always true — memory is always available (memory system is always on).
|
||||||
pub fn has_memory(&self) -> bool {
|
pub fn has_memory(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a snapshot of the current token budget state for diagnostics.
|
|
||||||
pub fn token_info(&self, messages: &[ChatMessage]) -> TokenInfo {
|
|
||||||
TokenInfo {
|
|
||||||
context_window: self.context_window,
|
|
||||||
threshold: self.threshold(),
|
|
||||||
estimated_tokens: self.token_estimate_with_history(messages),
|
|
||||||
last_api_tokens: self.last_api_total_tokens,
|
|
||||||
cache_active: self.last_api_total_tokens.is_some(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the compression threshold in tokens.
|
/// Get the compression threshold in tokens.
|
||||||
pub fn threshold(&self) -> usize {
|
pub fn threshold(&self) -> usize {
|
||||||
(self.context_window as f64 * self.threshold_ratio) as usize
|
(self.context_window as f64 * self.threshold_ratio) as usize
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fast-path: trim oversized tool results without LLM call.
|
/// Fast-path: trim oversized tool results without LLM call.
|
||||||
/// Old tool results (outside of `protect_tail` zone) are replaced with a
|
|
||||||
/// concise placeholder; recent results are truncated to `tool_result_trim_chars`.
|
|
||||||
/// Returns the number of messages modified.
|
/// Returns the number of messages modified.
|
||||||
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage], protect_tail: usize) -> usize {
|
fn fast_trim_tool_results(&self, messages: &mut [ChatMessage]) -> usize {
|
||||||
let limit = self.config.tool_result_trim_chars;
|
let limit = self.config.tool_result_trim_chars;
|
||||||
let tail_start = messages.len().saturating_sub(protect_tail);
|
|
||||||
let mut modified = 0;
|
let mut modified = 0;
|
||||||
|
|
||||||
for (i, msg) in messages.iter_mut().enumerate() {
|
for msg in messages.iter_mut() {
|
||||||
if msg.role != "tool" || msg.content.len() <= limit {
|
if msg.role == "tool" && msg.content.len() > limit {
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if i < tail_start {
|
|
||||||
let tool_name = msg.tool_name.as_deref().unwrap_or("unknown");
|
|
||||||
let chars = msg.content.len();
|
|
||||||
msg.content = format!(
|
|
||||||
"[Tool output ({}) — {} chars, omitted from context]",
|
|
||||||
tool_name, chars
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
let removed = msg.content.len() - limit;
|
let removed = msg.content.len() - limit;
|
||||||
msg.content = format!(
|
msg.content = format!(
|
||||||
"{}...\n\n[Output truncated - {} characters removed]",
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
&msg.content[..msg.content.ceil_char_boundary(limit)],
|
&msg.content[..msg.content.ceil_char_boundary(limit)],
|
||||||
removed
|
removed
|
||||||
);
|
);
|
||||||
}
|
|
||||||
modified += 1;
|
modified += 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
modified
|
modified
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Repair tool call chains after compression.
|
/// Remove orphan tool results whose declaring tool_calls have been compressed away.
|
||||||
/// Phase 1: remove orphan tool results whose declaring tool_calls are missing.
|
/// Scans for tool messages with no preceding assistant tool_call, and removes them.
|
||||||
/// Phase 2: strip tool_calls from assistants whose results are missing.
|
|
||||||
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
|
pub fn repair_tool_pairs(messages: &mut Vec<ChatMessage>) {
|
||||||
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
|
let mut declared: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
@ -234,58 +165,23 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
} else if messages[i].role == "tool"
|
} else if messages[i].role == "tool"
|
||||||
&& let Some(ref tid) = messages[i].tool_call_id
|
&& let Some(ref tid) = messages[i].tool_call_id
|
||||||
&& !declared.contains(tid.as_str())
|
&& !declared.contains(tid.as_str()) {
|
||||||
{
|
|
||||||
messages.remove(i);
|
messages.remove(i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let broken: Vec<usize> = messages
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.filter_map(|(idx, msg)| {
|
|
||||||
if msg.role == "assistant"
|
|
||||||
&& let Some(ref tcs) = msg.tool_calls
|
|
||||||
&& !tcs.is_empty()
|
|
||||||
{
|
|
||||||
let all_present = tcs.iter().all(|tc| {
|
|
||||||
messages.iter().any(|m| {
|
|
||||||
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
|
||||||
})
|
|
||||||
});
|
|
||||||
if !all_present { Some(idx) } else { None }
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
for idx in broken {
|
|
||||||
let msg = &mut messages[idx];
|
|
||||||
let tcs = msg.tool_calls.take().unwrap_or_default();
|
|
||||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
|
||||||
msg.content = format!(
|
|
||||||
"{}\n\n[Tool calls ({}) — results are no longer available]",
|
|
||||||
msg.content,
|
|
||||||
names.join(", ")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Main entry point - compresses history if over threshold.
|
/// Main entry point - compresses history if over threshold.
|
||||||
pub async fn compress_if_needed(
|
pub async fn compress_if_needed(
|
||||||
&mut self,
|
&self,
|
||||||
mut history: Vec<ChatMessage>,
|
mut history: Vec<ChatMessage>,
|
||||||
) -> Result<CompressionResult, AgentError> {
|
) -> Result<CompressionResult, AgentError> {
|
||||||
// Check if compression is needed
|
// Check if compression is needed
|
||||||
let tokens = self.token_estimate_with_history(&history);
|
let tokens = estimate_tokens(&history);
|
||||||
if tokens <= self.threshold() {
|
if tokens <= self.threshold() {
|
||||||
return Ok(CompressionResult {
|
return Ok(CompressionResult { history, created_timelines: false });
|
||||||
history,
|
|
||||||
created_timelines: false,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -297,8 +193,8 @@ impl ContextCompressor {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Fast trim pass first — modify history in place
|
// Fast trim pass first — modify history in place
|
||||||
let trimmed = self.fast_trim_tool_results(&mut history, self.config.protect_last_n);
|
let trimmed = self.fast_trim_tool_results(&mut history);
|
||||||
let tokens_after = self.token_estimate_with_history(&history);
|
let tokens_after = estimate_tokens(&history);
|
||||||
if trimmed > 0 {
|
if trimmed > 0 {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
@ -308,24 +204,24 @@ impl ContextCompressor {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
if tokens_after <= self.threshold() {
|
if tokens_after <= self.threshold() {
|
||||||
self.invalidate_token_cache();
|
return Ok(CompressionResult { history, created_timelines: false });
|
||||||
return Ok(CompressionResult {
|
|
||||||
history,
|
|
||||||
created_timelines: false,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLM summarization pass
|
// LLM summarization pass
|
||||||
let mut current_history = history;
|
let mut current_history = history;
|
||||||
let mut created_timelines = false;
|
let mut created_timelines = false;
|
||||||
for pass in 0..self.config.max_passes {
|
for pass in 0..self.config.max_passes {
|
||||||
let tokens = self.token_estimate_with_history(¤t_history);
|
let tokens = estimate_tokens(¤t_history);
|
||||||
if tokens <= self.threshold() {
|
if tokens <= self.threshold() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
|
tracing::debug!(
|
||||||
|
pass = pass + 1,
|
||||||
|
tokens = tokens,
|
||||||
|
"Compression pass"
|
||||||
|
);
|
||||||
|
|
||||||
match self.compress_once(¤t_history).await {
|
match self.compress_once(¤t_history).await {
|
||||||
Ok(Some(compressed)) => {
|
Ok(Some(compressed)) => {
|
||||||
@ -345,52 +241,15 @@ impl ContextCompressor {
|
|||||||
|
|
||||||
// Hard safety net: if still dangerously high after all passes,
|
// Hard safety net: if still dangerously high after all passes,
|
||||||
// fall back to head+tail truncation so the LLM call doesn't overflow.
|
// fall back to head+tail truncation so the LLM call doesn't overflow.
|
||||||
let final_tokens = self.token_estimate_with_history(¤t_history);
|
let final_tokens = estimate_tokens(¤t_history);
|
||||||
let danger_threshold = (self.context_window as f64 * 0.9) as usize;
|
let danger_threshold = (self.context_window as f64 * 0.9) as usize;
|
||||||
if final_tokens > danger_threshold
|
if final_tokens > danger_threshold
|
||||||
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n
|
&& current_history.len() > self.config.protect_first_n + self.config.protect_last_n
|
||||||
{
|
{
|
||||||
let mut tail_start = current_history.len() - self.config.protect_last_n;
|
|
||||||
|
|
||||||
// Align tail_start backwards to preserve tool chain boundaries:
|
|
||||||
// if an assistant with tool_calls has results spanning the cut,
|
|
||||||
// include the assistant in the tail.
|
|
||||||
if tail_start > 0 && tail_start < current_history.len() {
|
|
||||||
let mut scan = tail_start.saturating_sub(1);
|
|
||||||
loop {
|
|
||||||
let m = ¤t_history[scan];
|
|
||||||
if m.role == "assistant" {
|
|
||||||
if let Some(tcs) = &m.tool_calls
|
|
||||||
&& !tcs.is_empty()
|
|
||||||
{
|
|
||||||
let has_post = current_history[scan + 1..]
|
|
||||||
.iter()
|
|
||||||
.filter(|r| r.role == "tool")
|
|
||||||
.any(|r| {
|
|
||||||
tcs.iter()
|
|
||||||
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
|
|
||||||
});
|
|
||||||
if has_post {
|
|
||||||
tail_start = scan;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if scan == 0 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
scan -= 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip orphan tool messages at the new head-tail boundary
|
|
||||||
while tail_start < current_history.len() && current_history[tail_start].role == "tool" {
|
|
||||||
tail_start += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
|
let head: Vec<_> = current_history[..self.config.protect_first_n].to_vec();
|
||||||
|
let tail_start = current_history.len() - self.config.protect_last_n;
|
||||||
let tail: Vec<_> = current_history[tail_start..].to_vec();
|
let tail: Vec<_> = current_history[tail_start..].to_vec();
|
||||||
let dropped = current_history.len() - self.config.protect_first_n - tail.len();
|
let dropped = current_history.len() - self.config.protect_first_n - self.config.protect_last_n;
|
||||||
|
|
||||||
let mut truncated = head;
|
let mut truncated = head;
|
||||||
truncated.push(ChatMessage::user(format!(
|
truncated.push(ChatMessage::user(format!(
|
||||||
@ -400,26 +259,6 @@ impl ContextCompressor {
|
|||||||
)));
|
)));
|
||||||
truncated.extend(tail);
|
truncated.extend(tail);
|
||||||
|
|
||||||
// Strip tool_calls from any assistant in the head whose results
|
|
||||||
// were dropped (previously in the middle section).
|
|
||||||
for msg in &mut truncated[..self.config.protect_first_n] {
|
|
||||||
if msg.role == "assistant" {
|
|
||||||
if let Some(ref tcs) = msg.tool_calls
|
|
||||||
&& !tcs.is_empty()
|
|
||||||
{
|
|
||||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
|
||||||
msg.content = format!(
|
|
||||||
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
|
|
||||||
msg.content,
|
|
||||||
names.join(", ")
|
|
||||||
);
|
|
||||||
msg.tool_calls = None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Self::repair_tool_pairs(&mut truncated);
|
|
||||||
|
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
final_tokens = final_tokens,
|
final_tokens = final_tokens,
|
||||||
danger = danger_threshold,
|
danger = danger_threshold,
|
||||||
@ -430,21 +269,14 @@ impl ContextCompressor {
|
|||||||
current_history = truncated;
|
current_history = truncated;
|
||||||
}
|
}
|
||||||
|
|
||||||
if created_timelines {
|
|
||||||
self.invalidate_token_cache();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
final_tokens = self.token_estimate_with_history(¤t_history),
|
final_tokens = estimate_tokens(¤t_history),
|
||||||
final_msg_count = current_history.len(),
|
final_msg_count = current_history.len(),
|
||||||
"Context compression completed"
|
"Context compression completed"
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(CompressionResult {
|
Ok(CompressionResult { history: current_history, created_timelines })
|
||||||
history: current_history,
|
|
||||||
created_timelines,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to extract the actual context token limit from an LLM error message.
|
/// Try to extract the actual context token limit from an LLM error message.
|
||||||
@ -467,8 +299,7 @@ impl ContextCompressor {
|
|||||||
// Look for a number in the vicinity (up to 10 chars after marker)
|
// Look for a number in the vicinity (up to 10 chars after marker)
|
||||||
if let Some(num_str) = find_number_nearby(after, 50)
|
if let Some(num_str) = find_number_nearby(after, 50)
|
||||||
&& let Ok(n) = num_str.parse::<usize>()
|
&& let Ok(n) = num_str.parse::<usize>()
|
||||||
&& (1024..=10_000_000).contains(&n)
|
&& (1024..=10_000_000).contains(&n) {
|
||||||
{
|
|
||||||
return Some(n);
|
return Some(n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -530,26 +361,19 @@ impl ContextCompressor {
|
|||||||
|
|
||||||
// Persist compressed summary as timeline memory entry
|
// Persist compressed summary as timeline memory entry
|
||||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||||
let timeline_content = format!(
|
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||||
"[{}] Compressed {} conversation segments:\n{}",
|
ts, between.len(), summary);
|
||||||
ts,
|
|
||||||
between.len(),
|
|
||||||
summary
|
|
||||||
);
|
|
||||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||||
let mm = self.memory.clone();
|
let mm = self.memory.clone();
|
||||||
let sid = self.session_id.clone();
|
let sid = self.session_id.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = mm
|
if let Err(e) = mm.store(
|
||||||
.store(
|
|
||||||
&key,
|
&key,
|
||||||
&timeline_content,
|
&timeline_content,
|
||||||
crate::memory::MemoryCategory::Timeline,
|
crate::memory::MemoryCategory::Timeline,
|
||||||
sid.as_deref(),
|
sid.as_deref(),
|
||||||
Some(0.3),
|
Some(0.3),
|
||||||
)
|
).await {
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -580,7 +404,10 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Summarize a segment of messages using LLM.
|
/// Summarize a segment of messages using LLM.
|
||||||
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
|
async fn summarize_segment(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
) -> Result<String, AgentError> {
|
||||||
if messages.is_empty() {
|
if messages.is_empty() {
|
||||||
return Ok(String::new());
|
return Ok(String::new());
|
||||||
}
|
}
|
||||||
@ -594,8 +421,7 @@ impl ContextCompressor {
|
|||||||
"tool" => "Tool",
|
"tool" => "Tool",
|
||||||
_ => m.role.as_str(),
|
_ => m.role.as_str(),
|
||||||
};
|
};
|
||||||
let name = m
|
let name = m.tool_name
|
||||||
.tool_name
|
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|n| format!(" ({})", n))
|
.map(|n| format!(" ({})", n))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
@ -640,10 +466,7 @@ Be concise, aim for {} characters or less.
|
|||||||
);
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
||||||
Message::system("You are a helpful assistant."),
|
|
||||||
Message::user(&prompt),
|
|
||||||
],
|
|
||||||
temperature: Some(0.3),
|
temperature: Some(0.3),
|
||||||
max_tokens: Some(1000),
|
max_tokens: Some(1000),
|
||||||
tools: None,
|
tools: None,
|
||||||
@ -715,23 +538,13 @@ mod tests {
|
|||||||
content: "[summarized]".into(),
|
content: "[summarized]".into(),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
tool_calls: vec![],
|
tool_calls: vec![],
|
||||||
usage: Usage {
|
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
||||||
prompt_tokens: 0,
|
|
||||||
completion_tokens: 0,
|
|
||||||
total_tokens: 0,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptype(&self) -> &str {
|
fn ptype(&self) -> &str { "mock" }
|
||||||
"mock"
|
fn name(&self) -> &str { "mock" }
|
||||||
}
|
fn model_id(&self) -> &str { "mock" }
|
||||||
fn name(&self) -> &str {
|
|
||||||
"mock"
|
|
||||||
}
|
|
||||||
fn model_id(&self) -> &str {
|
|
||||||
"mock"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
||||||
@ -743,13 +556,11 @@ mod tests {
|
|||||||
MM.get_or_init(|| {
|
MM.get_or_init(|| {
|
||||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||||
rt.block_on(async {
|
rt.block_on(async {
|
||||||
let tmp = std::env::temp_dir()
|
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||||
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
|
||||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
||||||
})
|
})
|
||||||
})
|
}).clone()
|
||||||
.clone()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -765,11 +576,7 @@ mod tests {
|
|||||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||||
// raw = 19, with 1.2x = ~23
|
// raw = 19, with 1.2x = ~23
|
||||||
assert!(
|
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
||||||
tokens > 18 && tokens < 30,
|
|
||||||
"Expected ~23 tokens, got {}",
|
|
||||||
tokens
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -778,15 +585,14 @@ mod tests {
|
|||||||
tool_result_trim_chars: 50,
|
tool_result_trim_chars: 50,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let compressor =
|
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||||
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
|
||||||
|
|
||||||
let mut messages = vec![
|
let mut messages = vec![
|
||||||
ChatMessage::user("Hello"),
|
ChatMessage::user("Hello"),
|
||||||
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
|
ChatMessage::tool("call1", "bash", &"x".repeat(200)),
|
||||||
];
|
];
|
||||||
|
|
||||||
let modified = compressor.fast_trim_tool_results(&mut messages, 2);
|
let modified = compressor.fast_trim_tool_results(&mut messages);
|
||||||
assert_eq!(modified, 1);
|
assert_eq!(modified, 1);
|
||||||
assert!(messages[1].content.len() < 100);
|
assert!(messages[1].content.len() < 100);
|
||||||
}
|
}
|
||||||
@ -813,18 +619,14 @@ mod tests {
|
|||||||
max_passes: 0,
|
max_passes: 0,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
|
let compressor = ContextCompressor::with_config(mock_provider(), 200, config, mm);
|
||||||
|
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
ChatMessage::user("Hi"),
|
ChatMessage::user("Hi"),
|
||||||
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = compressor
|
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||||
.compress_if_needed(messages)
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.history;
|
|
||||||
|
|
||||||
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
@ -848,8 +650,7 @@ mod tests {
|
|||||||
// - B2B (L275): last user message lost when it is the final history message
|
// - B2B (L275): last user message lost when it is the final history message
|
||||||
//
|
//
|
||||||
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
||||||
let tmp =
|
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||||
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
|
||||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||||
|
|
||||||
@ -860,7 +661,7 @@ mod tests {
|
|||||||
max_passes: 1,
|
max_passes: 1,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
|
let compressor = ContextCompressor::with_config(mock_summarizer(), 200, config, mm);
|
||||||
|
|
||||||
// History: 9 messages, last message is user Q4.
|
// History: 9 messages, last message is user Q4.
|
||||||
// user_indices (skip 1) = [1, 3, 6, 8]
|
// user_indices (skip 1) = [1, 3, 6, 8]
|
||||||
@ -879,33 +680,15 @@ mod tests {
|
|||||||
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = compressor
|
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||||
.compress_if_needed(messages)
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.history;
|
|
||||||
|
|
||||||
// B2A: "Q1" must appear exactly once
|
// B2A: "Q1" must appear exactly once
|
||||||
let q1_count = result
|
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
|
||||||
.iter()
|
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
|
||||||
.filter(|m| m.role == "user" && m.content == "Q1")
|
|
||||||
.count();
|
|
||||||
assert_eq!(
|
|
||||||
q1_count, 1,
|
|
||||||
"Q1 should appear exactly once, got {}",
|
|
||||||
q1_count
|
|
||||||
);
|
|
||||||
|
|
||||||
// B2B: "Q4" must NOT be lost
|
// B2B: "Q4" must NOT be lost
|
||||||
let q4_count = result
|
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
|
||||||
.iter()
|
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
|
||||||
.filter(|m| m.role == "user" && m.content == "Q4")
|
|
||||||
.count();
|
|
||||||
assert_eq!(
|
|
||||||
q4_count, 1,
|
|
||||||
"Q4 should appear exactly once (not lost), got {}",
|
|
||||||
q4_count
|
|
||||||
);
|
|
||||||
|
|
||||||
let _ = std::fs::remove_file(&tmp);
|
let _ = std::fs::remove_file(&tmp);
|
||||||
}
|
}
|
||||||
@ -928,7 +711,7 @@ mod tests {
|
|||||||
// context_window=100, danger_threshold=90.
|
// context_window=100, danger_threshold=90.
|
||||||
// Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
|
// Each trimmed tool (~500 chars): ceil(500/4)+4 = 129 raw. 3 tools = 387.
|
||||||
// Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
|
// Plus users (~5 each) + system (~15) = ~417 raw * 1.2 = 500 > 90.
|
||||||
let mut compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
|
let compressor = ContextCompressor::with_config(mock_provider(), 100, config, mm);
|
||||||
|
|
||||||
let big = "x".repeat(3000);
|
let big = "x".repeat(3000);
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
@ -941,23 +724,13 @@ mod tests {
|
|||||||
ChatMessage::tool("t3", "bash", &big),
|
ChatMessage::tool("t3", "bash", &big),
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = compressor
|
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||||
.compress_if_needed(messages)
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.history;
|
|
||||||
|
|
||||||
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
||||||
assert!(
|
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
|
||||||
result.len() < 7,
|
|
||||||
"expected truncation reduction, got {} messages",
|
|
||||||
result.len()
|
|
||||||
);
|
|
||||||
|
|
||||||
// Truncation notice should be present
|
// Truncation notice should be present
|
||||||
let has_notice = result
|
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
|
||||||
.iter()
|
|
||||||
.any(|m| m.content.contains("Context truncation"));
|
|
||||||
assert!(has_notice, "hard truncation notice missing");
|
assert!(has_notice, "hard truncation notice missing");
|
||||||
|
|
||||||
let _ = std::fs::remove_file(&tmp);
|
let _ = std::fs::remove_file(&tmp);
|
||||||
@ -989,16 +762,8 @@ mod tests {
|
|||||||
|
|
||||||
// orphan should be removed; legitimate should stay
|
// orphan should be removed; legitimate should stay
|
||||||
assert_eq!(messages.len(), 4);
|
assert_eq!(messages.len(), 4);
|
||||||
assert!(
|
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
|
||||||
messages
|
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
|
||||||
.iter()
|
|
||||||
.all(|m| m.tool_call_id != Some("tc1".into()))
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
messages
|
|
||||||
.iter()
|
|
||||||
.any(|m| m.tool_call_id == Some("tc2".into()))
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||||
|
|
||||||
let mut file = std::fs::File::open(path)?;
|
let mut file = std::fs::File::open(path)?;
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
|
|||||||
@ -1,16 +1,8 @@
|
|||||||
pub mod agent_loop;
|
pub mod agent_loop;
|
||||||
pub mod context_compressor;
|
pub mod context_compressor;
|
||||||
pub mod media_handler;
|
pub mod media_handler;
|
||||||
pub mod sub_agent;
|
|
||||||
pub mod system_prompt;
|
pub mod system_prompt;
|
||||||
|
|
||||||
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
|
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||||
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
||||||
pub use sub_agent::{
|
pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder};
|
||||||
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
|
|
||||||
TaskNotification, TaskStatus,
|
|
||||||
};
|
|
||||||
pub use system_prompt::{
|
|
||||||
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
|
|
||||||
build_system_prompt,
|
|
||||||
};
|
|
||||||
|
|||||||
@ -1,623 +0,0 @@
|
|||||||
use std::collections::HashSet;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
use dashmap::DashMap;
|
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
|
||||||
use crate::agent::AgentLoop;
|
|
||||||
use crate::agent::system_prompt::build_sub_agent_system_prompt;
|
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
use crate::providers::{LLMProvider, create_provider};
|
|
||||||
use crate::skills::SkillsLoader;
|
|
||||||
use crate::tools::ToolRegistry;
|
|
||||||
|
|
||||||
tokio::task_local! {
|
|
||||||
pub(crate) static DELEGATE_CONTEXT: DelegateContext;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read the delegate context from the current task. Returns an error if not set.
|
|
||||||
pub fn get_delegate_context() -> Result<DelegateContext, String> {
|
|
||||||
DELEGATE_CONTEXT
|
|
||||||
.try_with(|ctx| ctx.clone())
|
|
||||||
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
const DEFAULT_MAX_ITERATIONS: usize = 99;
|
|
||||||
const DEFAULT_TIMEOUT_SECS: u64 = 3600;
|
|
||||||
const MAX_INLINE_RESULT_CHARS: usize = 8000;
|
|
||||||
|
|
||||||
const DEFAULT_READONLY_TOOLS: &[&str] = &[
|
|
||||||
"file_read",
|
|
||||||
"file_search",
|
|
||||||
"content_search",
|
|
||||||
"web_fetch",
|
|
||||||
"http_request",
|
|
||||||
"calculator",
|
|
||||||
];
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SubAgentConfig {
|
|
||||||
pub prompt: String,
|
|
||||||
pub mode: ExecutionMode,
|
|
||||||
pub allowed_tools: Option<Vec<String>>,
|
|
||||||
pub max_iterations: Option<usize>,
|
|
||||||
pub timeout_secs: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum ExecutionMode {
|
|
||||||
Inline,
|
|
||||||
Background,
|
|
||||||
Parallel,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SubAgentResult {
|
|
||||||
pub task_id: String,
|
|
||||||
pub content: String,
|
|
||||||
pub content_truncated: bool,
|
|
||||||
pub status: TaskStatus,
|
|
||||||
pub tool_calls_count: usize,
|
|
||||||
pub iterations: usize,
|
|
||||||
pub duration_ms: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum TaskStatus {
|
|
||||||
Completed,
|
|
||||||
Failed(String),
|
|
||||||
Cancelled,
|
|
||||||
TimedOut,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TaskNotification {
|
|
||||||
pub task_id: String,
|
|
||||||
pub session_id: String,
|
|
||||||
pub channel: String,
|
|
||||||
pub chat_id: String,
|
|
||||||
pub status: TaskStatus,
|
|
||||||
pub result_summary: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct DelegateContext {
|
|
||||||
pub session_id: String,
|
|
||||||
pub channel: String,
|
|
||||||
pub chat_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum SubAgentError {
|
|
||||||
TooManyTasks(usize),
|
|
||||||
ProviderCreation(String),
|
|
||||||
Storage(String),
|
|
||||||
Other(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for SubAgentError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::TooManyTasks(max) => write!(f, "后台任务已达上限({}),请稍后重试", max),
|
|
||||||
Self::ProviderCreation(e) => write!(f, "provider creation failed: {}", e),
|
|
||||||
Self::Storage(e) => write!(f, "storage error: {}", e),
|
|
||||||
Self::Other(e) => write!(f, "{}", e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for SubAgentError {}
|
|
||||||
|
|
||||||
pub struct SubAgentManager {
|
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
full_tools: Arc<ToolRegistry>,
|
|
||||||
storage: Option<Arc<crate::storage::Storage>>,
|
|
||||||
active_tasks: Arc<DashMap<String, CancellationToken>>,
|
|
||||||
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
|
|
||||||
max_concurrent_background_tasks: usize,
|
|
||||||
skills_loader: Option<Arc<SkillsLoader>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SubAgentManager {
|
|
||||||
pub fn new(
|
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
full_tools: Arc<ToolRegistry>,
|
|
||||||
storage: Option<Arc<crate::storage::Storage>>,
|
|
||||||
notify_tx: tokio::sync::mpsc::UnboundedSender<TaskNotification>,
|
|
||||||
max_concurrent_background_tasks: usize,
|
|
||||||
skills_loader: Option<Arc<SkillsLoader>>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
provider_config,
|
|
||||||
full_tools,
|
|
||||||
storage,
|
|
||||||
active_tasks: Arc::new(DashMap::new()),
|
|
||||||
notify_tx,
|
|
||||||
max_concurrent_background_tasks,
|
|
||||||
skills_loader,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn filter_tools(&self, allowed: &Option<Vec<String>>) -> Arc<ToolRegistry> {
|
|
||||||
let allowed_set: HashSet<&str> = match allowed {
|
|
||||||
Some(list) => list.iter().map(|s| s.as_str()).collect(),
|
|
||||||
None => DEFAULT_READONLY_TOOLS.iter().copied().collect(),
|
|
||||||
};
|
|
||||||
let filtered = ToolRegistry::new();
|
|
||||||
for (name, tool) in self.full_tools.iter() {
|
|
||||||
if allowed_set.contains(name.as_str()) && name != "delegate" {
|
|
||||||
filtered.register_raw(name, tool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Arc::new(filtered)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_skills_prompt(&self, tools: &ToolRegistry) -> Option<String> {
|
|
||||||
let has_get_skill = tools.iter().iter().any(|(name, _)| name == "get_skill");
|
|
||||||
if has_get_skill {
|
|
||||||
if let Some(ref loader) = self.skills_loader {
|
|
||||||
let prompt = loader.build_skills_prompt();
|
|
||||||
if !prompt.is_empty() {
|
|
||||||
return Some(prompt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build_sub_agent(
|
|
||||||
&self,
|
|
||||||
config: &SubAgentConfig,
|
|
||||||
tools: Arc<ToolRegistry>,
|
|
||||||
) -> Result<AgentLoop, AgentError> {
|
|
||||||
let mut provider = create_provider(self.provider_config.clone())
|
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
if let Some(ref s) = self.storage {
|
|
||||||
provider.set_storage(s.clone());
|
|
||||||
}
|
|
||||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider);
|
|
||||||
|
|
||||||
let max_iterations = config.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
|
|
||||||
let workspace_dir = self.provider_config.workspace_dir.clone();
|
|
||||||
let model_name = self.provider_config.model_id.clone();
|
|
||||||
let input_types = self.provider_config.input_types.clone();
|
|
||||||
|
|
||||||
let agent = AgentLoop::with_provider_and_tools(
|
|
||||||
provider,
|
|
||||||
tools,
|
|
||||||
max_iterations,
|
|
||||||
model_name,
|
|
||||||
workspace_dir,
|
|
||||||
input_types,
|
|
||||||
)
|
|
||||||
.with_context_window(self.provider_config.token_limit);
|
|
||||||
|
|
||||||
Ok(agent)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_inline(
|
|
||||||
&self,
|
|
||||||
config: SubAgentConfig,
|
|
||||||
) -> Result<SubAgentResult, SubAgentError> {
|
|
||||||
let task_id = generate_task_id();
|
|
||||||
let tools = self.filter_tools(&config.allowed_tools);
|
|
||||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
|
||||||
let timeout_human = format_duration(timeout_secs);
|
|
||||||
let http_get_only = config.allowed_tools.is_none()
|
|
||||||
|| config
|
|
||||||
.allowed_tools
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
|
||||||
let skills_prompt = self.get_skills_prompt(&tools);
|
|
||||||
let system_prompt = build_sub_agent_system_prompt(
|
|
||||||
&config.prompt,
|
|
||||||
&timeout_human,
|
|
||||||
&tools,
|
|
||||||
&self.provider_config.workspace_dir,
|
|
||||||
&self.provider_config.model_id,
|
|
||||||
skills_prompt,
|
|
||||||
http_get_only,
|
|
||||||
);
|
|
||||||
|
|
||||||
let agent = self
|
|
||||||
.build_sub_agent(&config, tools)
|
|
||||||
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
|
|
||||||
|
|
||||||
let history = vec![
|
|
||||||
ChatMessage::system(system_prompt),
|
|
||||||
ChatMessage::user(&config.prompt),
|
|
||||||
];
|
|
||||||
|
|
||||||
let start = Instant::now();
|
|
||||||
|
|
||||||
let result = tokio::time::timeout(
|
|
||||||
std::time::Duration::from_secs(timeout_secs),
|
|
||||||
agent.process(history),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let duration_ms = start.elapsed().as_millis() as u64;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(Ok(agent_result)) => {
|
|
||||||
let (content, truncated) =
|
|
||||||
truncate_sub_agent_result(&agent_result.final_response.content);
|
|
||||||
let tool_calls_count = agent_result
|
|
||||||
.emitted_messages
|
|
||||||
.iter()
|
|
||||||
.filter(|m| m.tool_calls.is_some())
|
|
||||||
.count();
|
|
||||||
let iterations = agent_result
|
|
||||||
.emitted_messages
|
|
||||||
.iter()
|
|
||||||
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
|
|
||||||
.count();
|
|
||||||
Ok(SubAgentResult {
|
|
||||||
task_id,
|
|
||||||
content,
|
|
||||||
content_truncated: truncated,
|
|
||||||
status: TaskStatus::Completed,
|
|
||||||
tool_calls_count,
|
|
||||||
iterations,
|
|
||||||
duration_ms,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Ok(Err(e)) => Ok(SubAgentResult {
|
|
||||||
task_id,
|
|
||||||
content: String::new(),
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::Failed(e.to_string()),
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms,
|
|
||||||
}),
|
|
||||||
Err(_elapsed) => Ok(SubAgentResult {
|
|
||||||
task_id,
|
|
||||||
content: String::new(),
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::TimedOut,
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_parallel(
|
|
||||||
&self,
|
|
||||||
configs: Vec<SubAgentConfig>,
|
|
||||||
) -> Result<Vec<SubAgentResult>, SubAgentError> {
|
|
||||||
let futures: Vec<_> = configs
|
|
||||||
.into_iter()
|
|
||||||
.map(|config| {
|
|
||||||
let mgr = self; // &self borrow, all tasks share the same manager
|
|
||||||
async move { mgr.run_inline(config).await }
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let results = futures_util::future::join_all(futures).await;
|
|
||||||
Ok(results.into_iter().collect::<Result<Vec<_>, _>>()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_background(
|
|
||||||
&self,
|
|
||||||
config: SubAgentConfig,
|
|
||||||
ctx: DelegateContext,
|
|
||||||
) -> Result<String, SubAgentError> {
|
|
||||||
if self.active_tasks.len() >= self.max_concurrent_background_tasks {
|
|
||||||
return Err(SubAgentError::TooManyTasks(
|
|
||||||
self.max_concurrent_background_tasks,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let task_id = generate_task_id();
|
|
||||||
let cancel_token = CancellationToken::new();
|
|
||||||
|
|
||||||
self.active_tasks
|
|
||||||
.insert(task_id.clone(), cancel_token.clone());
|
|
||||||
|
|
||||||
// Write DB: pending
|
|
||||||
if let Some(ref storage) = self.storage {
|
|
||||||
let allowed_tools_json = config
|
|
||||||
.allowed_tools
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|v| serde_json::to_string(v).ok());
|
|
||||||
let record = crate::storage::BackgroundTask {
|
|
||||||
id: task_id.clone(),
|
|
||||||
session_id: ctx.session_id.clone(),
|
|
||||||
channel: ctx.channel.clone(),
|
|
||||||
chat_id: ctx.chat_id.clone(),
|
|
||||||
prompt: config.prompt.clone(),
|
|
||||||
allowed_tools: allowed_tools_json,
|
|
||||||
status: "pending".to_string(),
|
|
||||||
result: None,
|
|
||||||
error: None,
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
started_at: None,
|
|
||||||
finished_at: None,
|
|
||||||
created_at: chrono::Utc::now().timestamp_millis(),
|
|
||||||
};
|
|
||||||
storage
|
|
||||||
.create_background_task(&record)
|
|
||||||
.await
|
|
||||||
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let tools = self.filter_tools(&config.allowed_tools);
|
|
||||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
|
||||||
let timeout_human = format_duration(timeout_secs);
|
|
||||||
let http_get_only = config.allowed_tools.is_none()
|
|
||||||
|| config
|
|
||||||
.allowed_tools
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
|
||||||
let skills_prompt = self.get_skills_prompt(&tools);
|
|
||||||
let system_prompt = build_sub_agent_system_prompt(
|
|
||||||
&config.prompt,
|
|
||||||
&timeout_human,
|
|
||||||
&tools,
|
|
||||||
&self.provider_config.workspace_dir,
|
|
||||||
&self.provider_config.model_id,
|
|
||||||
skills_prompt,
|
|
||||||
http_get_only,
|
|
||||||
);
|
|
||||||
let provider_config = self.provider_config.clone();
|
|
||||||
let storage = self.storage.clone();
|
|
||||||
let notify_tx = self.notify_tx.clone();
|
|
||||||
let active_tasks = Arc::clone(&self.active_tasks);
|
|
||||||
|
|
||||||
let tid = task_id.clone();
|
|
||||||
let sess_id = ctx.session_id.clone();
|
|
||||||
let ch = ctx.channel.clone();
|
|
||||||
let cid = ctx.chat_id.clone();
|
|
||||||
let prompt = config.prompt.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let started_at = chrono::Utc::now().timestamp_millis();
|
|
||||||
|
|
||||||
// Update DB: running
|
|
||||||
if let Some(ref s) = storage {
|
|
||||||
let _ = s
|
|
||||||
.update_background_task_status(
|
|
||||||
&tid,
|
|
||||||
"running",
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
Some(started_at),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut provider = create_provider(provider_config.clone()).ok();
|
|
||||||
if let Some(ref mut p) = provider {
|
|
||||||
if let Some(ref s) = storage {
|
|
||||||
p.set_storage(s.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
|
|
||||||
|
|
||||||
let result = match provider_result {
|
|
||||||
Some(provider) => {
|
|
||||||
let agent = AgentLoop::with_provider_and_tools(
|
|
||||||
provider,
|
|
||||||
tools,
|
|
||||||
DEFAULT_MAX_ITERATIONS,
|
|
||||||
provider_config.model_id.clone(),
|
|
||||||
provider_config.workspace_dir.clone(),
|
|
||||||
provider_config.input_types.clone(),
|
|
||||||
)
|
|
||||||
.with_context_window(provider_config.token_limit);
|
|
||||||
|
|
||||||
let history = vec![
|
|
||||||
ChatMessage::system(system_prompt),
|
|
||||||
ChatMessage::user(&prompt),
|
|
||||||
];
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
r = tokio::time::timeout(
|
|
||||||
std::time::Duration::from_secs(timeout_secs),
|
|
||||||
agent.process(history),
|
|
||||||
) => {
|
|
||||||
match r {
|
|
||||||
Ok(Ok(agent_result)) => SubAgentResult {
|
|
||||||
task_id: tid.clone(),
|
|
||||||
content: agent_result.final_response.content,
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::Completed,
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms: 0,
|
|
||||||
},
|
|
||||||
Ok(Err(e)) => SubAgentResult {
|
|
||||||
task_id: tid.clone(),
|
|
||||||
content: String::new(),
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::Failed(e.to_string()),
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms: 0,
|
|
||||||
},
|
|
||||||
Err(_) => SubAgentResult {
|
|
||||||
task_id: tid.clone(),
|
|
||||||
content: String::new(),
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::TimedOut,
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = cancel_token.cancelled() => SubAgentResult {
|
|
||||||
task_id: tid.clone(),
|
|
||||||
content: String::new(),
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::Cancelled,
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => SubAgentResult {
|
|
||||||
task_id: tid.clone(),
|
|
||||||
content: String::new(),
|
|
||||||
content_truncated: false,
|
|
||||||
status: TaskStatus::Failed("provider creation failed".into()),
|
|
||||||
tool_calls_count: 0,
|
|
||||||
iterations: 0,
|
|
||||||
duration_ms: 0,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let finished_at = chrono::Utc::now().timestamp_millis();
|
|
||||||
let duration_ms = (finished_at - started_at) as u64;
|
|
||||||
|
|
||||||
let (status_str, error_val) = match &result.status {
|
|
||||||
TaskStatus::Completed => ("completed".to_string(), None),
|
|
||||||
TaskStatus::Failed(e) => ("failed".to_string(), Some(e.clone())),
|
|
||||||
TaskStatus::Cancelled => ("cancelled".to_string(), None),
|
|
||||||
TaskStatus::TimedOut => ("failed".to_string(), Some("timeout".to_string())),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(ref s) = storage {
|
|
||||||
let _ = s
|
|
||||||
.update_background_task_status(
|
|
||||||
&tid,
|
|
||||||
&status_str,
|
|
||||||
Some(&result.content),
|
|
||||||
error_val.as_deref(),
|
|
||||||
Some(started_at),
|
|
||||||
Some(finished_at),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = notify_tx.send(TaskNotification {
|
|
||||||
task_id: tid.clone(),
|
|
||||||
session_id: sess_id,
|
|
||||||
channel: ch,
|
|
||||||
chat_id: cid,
|
|
||||||
status: result.status,
|
|
||||||
result_summary: summarize_for_notification(&result.content, duration_ms),
|
|
||||||
});
|
|
||||||
|
|
||||||
active_tasks.remove(&tid);
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(task_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn cancel_task(&self, task_id: &str) -> Result<bool, SubAgentError> {
|
|
||||||
if let Some((_, token)) = self.active_tasks.remove(task_id) {
|
|
||||||
token.cancel();
|
|
||||||
if let Some(ref s) = self.storage {
|
|
||||||
s.update_background_task_status(
|
|
||||||
task_id,
|
|
||||||
"cancelled",
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
Some(chrono::Utc::now().timestamp_millis()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(|e| SubAgentError::Storage(e.to_string()))?;
|
|
||||||
}
|
|
||||||
Ok(true)
|
|
||||||
} else if let Some(ref s) = self.storage {
|
|
||||||
match s.get_background_task(task_id).await {
|
|
||||||
Ok(task) => match task.status.as_str() {
|
|
||||||
"pending" | "running" => {
|
|
||||||
tracing::warn!(task_id, "task in DB but not in active_tasks");
|
|
||||||
Ok(false)
|
|
||||||
}
|
|
||||||
_ => Ok(false),
|
|
||||||
},
|
|
||||||
Err(_) => Ok(false),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Ok(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
|
|
||||||
if let Some(ref s) = self.storage {
|
|
||||||
s.get_background_task(task_id).await.ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
|
|
||||||
if let Some(ref s) = self.storage {
|
|
||||||
s.list_background_tasks(session_id)
|
|
||||||
.await
|
|
||||||
.unwrap_or_default()
|
|
||||||
} else {
|
|
||||||
vec![]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn cancel_by_session(&self, session_id: &str) {
|
|
||||||
// Cancel all running tasks for a session by checking DB
|
|
||||||
if let Some(ref s) = self.storage {
|
|
||||||
if let Ok(tasks) = s.list_background_tasks(session_id).await {
|
|
||||||
for task in &tasks {
|
|
||||||
if task.status == "pending" || task.status == "running" {
|
|
||||||
let _ = self.cancel_task(&task.id).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn active_task_count(&self) -> usize {
|
|
||||||
self.active_tasks.len()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn generate_task_id() -> String {
|
|
||||||
Uuid::new_v4().to_string()[..8].to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_duration(seconds: u64) -> String {
|
|
||||||
if seconds < 60 {
|
|
||||||
format!("{}s", seconds)
|
|
||||||
} else if seconds < 3600 {
|
|
||||||
format!("{}m", seconds / 60)
|
|
||||||
} else {
|
|
||||||
format!("{}h", seconds / 3600)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn truncate_sub_agent_result(content: &str) -> (String, bool) {
|
|
||||||
if content.len() <= MAX_INLINE_RESULT_CHARS {
|
|
||||||
(content.to_string(), false)
|
|
||||||
} else {
|
|
||||||
let truncate_at = content.floor_char_boundary(MAX_INLINE_RESULT_CHARS);
|
|
||||||
(
|
|
||||||
format!(
|
|
||||||
"{}\n\n[... 结果已截断,共 {} 字符,完整结果请使用 check_task 查看 ...]",
|
|
||||||
&content[..truncate_at],
|
|
||||||
content.len()
|
|
||||||
),
|
|
||||||
true,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn summarize_for_notification(content: &str, _duration_ms: u64) -> String {
|
|
||||||
const MAX_SUMMARY_BYTES: usize = 500;
|
|
||||||
if content.len() <= MAX_SUMMARY_BYTES {
|
|
||||||
content.to_string()
|
|
||||||
} else {
|
|
||||||
let truncate_at = content.floor_char_boundary(MAX_SUMMARY_BYTES);
|
|
||||||
format!("{}...", &content[..truncate_at])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -3,7 +3,11 @@
|
|||||||
//! This module provides a modular framework for building system prompts
|
//! This module provides a modular framework for building system prompts
|
||||||
//! using the SystemPromptBuilder pattern.
|
//! using the SystemPromptBuilder pattern.
|
||||||
//!
|
//!
|
||||||
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic → Delegation
|
//! Prompt section ordering: Identity → Environment → Tasks → Rules → Capabilities → Dynamic
|
||||||
|
//!
|
||||||
|
//! Configuration files loaded from ~/.picobot/:
|
||||||
|
//! - AGENTS.md — agent identity and behavior
|
||||||
|
//! - USER.md — user preferences and profile
|
||||||
|
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
@ -51,35 +55,10 @@ impl SystemPromptBuilder {
|
|||||||
Box::new(CrossChannelSection),
|
Box::new(CrossChannelSection),
|
||||||
Box::new(MemorySection),
|
Box::new(MemorySection),
|
||||||
Box::new(HistorySection),
|
Box::new(HistorySection),
|
||||||
Box::new(DelegationSection),
|
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a builder with sub-agent specific sections.
|
|
||||||
pub fn with_sub_agent_defaults(
|
|
||||||
task: &str,
|
|
||||||
timeout: &str,
|
|
||||||
skills_prompt: Option<String>,
|
|
||||||
http_get_only: bool,
|
|
||||||
) -> Self {
|
|
||||||
let mut sections: Vec<Box<dyn PromptSection>> = vec![
|
|
||||||
Box::new(SubAgentIdentitySection {
|
|
||||||
task: task.to_string(),
|
|
||||||
timeout: timeout.to_string(),
|
|
||||||
}),
|
|
||||||
Box::new(ToolHonestySection),
|
|
||||||
Box::new(SafetySection),
|
|
||||||
Box::new(SubAgentToolsSection { http_get_only }),
|
|
||||||
Box::new(WorkspaceSection),
|
|
||||||
Box::new(DateTimeSection),
|
|
||||||
];
|
|
||||||
if let Some(sp) = skills_prompt {
|
|
||||||
sections.push(Box::new(SubAgentSkillsSection { skills_prompt: sp }));
|
|
||||||
}
|
|
||||||
Self { sections }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a custom section to the builder.
|
/// Add a custom section to the builder.
|
||||||
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
|
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
|
||||||
self.sections.push(section);
|
self.sections.push(section);
|
||||||
@ -374,120 +353,6 @@ impl PromptSection for HistorySection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sub-agent delegation principles.
|
|
||||||
pub struct DelegationSection;
|
|
||||||
|
|
||||||
impl PromptSection for DelegationSection {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"delegation"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
|
||||||
"## 子 Agent 委托原则\n\n\
|
|
||||||
当任务复杂需要拆解时,使用 delegate 工具创建子 Agent:\n\
|
|
||||||
\n\
|
|
||||||
### 何时委托\n\
|
|
||||||
- 多个独立子任务可以并行处理时(使用 mode=\"parallel\")\n\
|
|
||||||
- 长时间运行的任务需要后台执行时(使用 mode=\"background\")\n\
|
|
||||||
- 需要以不同权限(受限工具集)执行时\n\
|
|
||||||
\n\
|
|
||||||
### 工具分配原则\n\
|
|
||||||
- **最小权限**:只给子 Agent 完成其任务所需的最少工具\n\
|
|
||||||
- **只读优先**:如果可以只用 file_read、file_search、web_fetch 完成,不要给写权限(bash、file_write、file_edit)\n\
|
|
||||||
- **禁止递归**:永远不要把 delegate 工具分配给子 Agent\n\
|
|
||||||
- **明确边界**:每个子 Agent 只负责一个清晰、独立的子任务\n\
|
|
||||||
\n\
|
|
||||||
### Skill 分配原则\n\
|
|
||||||
- 如果子任务的领域有对应的 skill,在 allowed_tools 中加入 get_skill\n\
|
|
||||||
- 在任务 prompt 中明确告诉子 Agent 使用 get_skill 加载哪个技能\n\
|
|
||||||
- 例如:\"使用 get_skill action='get' skill_name='pdf' 加载 PDF 处理技能后完成任务\"\n\
|
|
||||||
\n\
|
|
||||||
### 任务描述\n\
|
|
||||||
- 任务 prompt 要清晰、具体、有明确输出要求\n\
|
|
||||||
- 如需额外约束,直接写在 prompt 中(例如:\"跳过 .tmp 文件\")\n\
|
|
||||||
- 明确说明期望的输出格式\n\
|
|
||||||
\n\
|
|
||||||
### 并行模式\n\
|
|
||||||
- 多个无依赖的子任务使用 mode=\"parallel\",任务定义在 tasks 数组中\n\
|
|
||||||
- 并行任务之间不应有数据依赖\n\
|
|
||||||
- 并行任务数建议不超过 5 个\n\
|
|
||||||
\n\
|
|
||||||
### 后台模式\n\
|
|
||||||
- 预计执行时间超过 30s 的任务使用 mode=\"background\"\n\
|
|
||||||
- 后台任务有全局并发上限,如果失败提示用户稍后重试".to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// === Sub-Agent Prompt Sections ===
|
|
||||||
|
|
||||||
/// Sub-agent identity and task instructions.
|
|
||||||
pub struct SubAgentIdentitySection {
|
|
||||||
pub task: String,
|
|
||||||
pub timeout: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PromptSection for SubAgentIdentitySection {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"sub_agent_identity"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
|
||||||
format!(
|
|
||||||
"## 子 Agent\n\n\
|
|
||||||
你是主 Agent 派出的子 Agent,负责完成一个具体任务。你的最终回复将汇报给主 Agent。\n\
|
|
||||||
\n\
|
|
||||||
## 任务\n\n\
|
|
||||||
{}\n\
|
|
||||||
\n\
|
|
||||||
## 规则\n\
|
|
||||||
- 只专注于上述任务,不要探索无关话题\n\
|
|
||||||
- 只在必要时使用工具\n\
|
|
||||||
- 不要使用 delegate 工具(禁止递归委托)\n\
|
|
||||||
- 如果任务无法完成,清楚说明原因\n\
|
|
||||||
- 只返回最终结果,不要描述过程\n\
|
|
||||||
- 超时:{},接近时限时返回部分结果",
|
|
||||||
self.task, self.timeout,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sub-agent available tools description.
|
|
||||||
pub struct SubAgentToolsSection {
|
|
||||||
pub http_get_only: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PromptSection for SubAgentToolsSection {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"sub_agent_tools"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
|
||||||
let mut s = String::from("## 可用工具\n\n");
|
|
||||||
s.push_str(&ctx.tools.describe_for_prompt());
|
|
||||||
if self.http_get_only {
|
|
||||||
s.push_str(
|
|
||||||
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sub-agent skills information, injected when get_skill tool is available.
|
|
||||||
pub struct SubAgentSkillsSection {
|
|
||||||
pub skills_prompt: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PromptSection for SubAgentSkillsSection {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"sub_agent_skills"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
|
||||||
self.skills_prompt.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// === Helper Functions ===
|
// === Helper Functions ===
|
||||||
|
|
||||||
/// Get user config directory (~/.picobot/).
|
/// Get user config directory (~/.picobot/).
|
||||||
@ -544,28 +409,6 @@ pub fn build_system_prompt(
|
|||||||
SystemPromptBuilder::with_defaults().build(&ctx)
|
SystemPromptBuilder::with_defaults().build(&ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a system prompt for a sub-agent with all relevant operational sections.
|
|
||||||
pub fn build_sub_agent_system_prompt(
|
|
||||||
task: &str,
|
|
||||||
timeout_human: &str,
|
|
||||||
tools: &ToolRegistry,
|
|
||||||
workspace_dir: &Path,
|
|
||||||
model_name: &str,
|
|
||||||
skills_prompt: Option<String>,
|
|
||||||
http_get_only: bool,
|
|
||||||
) -> String {
|
|
||||||
let ctx = PromptContext {
|
|
||||||
workspace_dir,
|
|
||||||
model_name,
|
|
||||||
tools,
|
|
||||||
session_id: None,
|
|
||||||
memory_context: None,
|
|
||||||
has_compressed_history: false,
|
|
||||||
};
|
|
||||||
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
|
|
||||||
.build(&ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundMessage};
|
use crate::bus::{MessageBus, OutboundMessage};
|
||||||
use crate::channels::ChannelManager;
|
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
use crate::channels::ChannelManager;
|
||||||
|
|
||||||
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
||||||
/// and dispatches them to the appropriate Channel
|
/// and dispatches them to the appropriate Channel
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
@ -23,9 +23,7 @@ pub struct ImageUrlBlock {
|
|||||||
|
|
||||||
impl ContentBlock {
|
impl ContentBlock {
|
||||||
pub fn text(content: impl Into<String>) -> Self {
|
pub fn text(content: impl Into<String>) -> Self {
|
||||||
Self::Text {
|
Self::Text { text: content.into() }
|
||||||
text: content.into(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn image_url(url: impl Into<String>) -> Self {
|
pub fn image_url(url: impl Into<String>) -> Self {
|
||||||
@ -163,10 +161,7 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn assistant_with_tool_calls(
|
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||||
content: impl Into<String>,
|
|
||||||
tool_calls: Vec<ToolCall>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -211,11 +206,7 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(
|
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
tool_call_id: impl Into<String>,
|
|
||||||
tool_name: impl Into<String>,
|
|
||||||
content: impl Into<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
|
|||||||
@ -2,13 +2,10 @@ pub mod dispatcher;
|
|||||||
pub mod message;
|
pub mod message;
|
||||||
|
|
||||||
pub use dispatcher::OutboundDispatcher;
|
pub use dispatcher::OutboundDispatcher;
|
||||||
pub use message::{
|
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
|
||||||
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
|
|
||||||
OutboundMessage, SourceKind,
|
|
||||||
};
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||||
@ -52,8 +49,7 @@ impl MessageBus {
|
|||||||
|
|
||||||
/// Consume an inbound message (Agent -> Bus)
|
/// Consume an inbound message (Agent -> Bus)
|
||||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||||
let msg = self
|
let msg = self.inbound_rx
|
||||||
.inbound_rx
|
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.recv()
|
.recv()
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use async_trait::async_trait;
|
||||||
|
use tokio::sync::{mpsc, Mutex};
|
||||||
|
|
||||||
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
||||||
use crate::protocol::{SlashCommandInfo, WsInbound, WsOutbound, parse_inbound};
|
|
||||||
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
||||||
|
use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
|
||||||
|
|
||||||
use super::base::{Channel, ChannelError};
|
use super::base::{Channel, ChannelError};
|
||||||
|
|
||||||
@ -14,7 +14,6 @@ use super::base::{Channel, ChannelError};
|
|||||||
|
|
||||||
pub(crate) struct Client {
|
pub(crate) struct Client {
|
||||||
sender: mpsc::Sender<WsOutbound>,
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
chat_id: String,
|
|
||||||
current_session_id: Mutex<Option<String>>,
|
current_session_id: Mutex<Option<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,28 +41,23 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Register a new client connection, returns (session_id, client)
|
/// Register a new client connection, returns (session_id, client)
|
||||||
pub(crate) async fn register_client(
|
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
|
||||||
&self,
|
// Generate connection ID (used as chat_id) - use short ID
|
||||||
sender: mpsc::Sender<WsOutbound>,
|
let connection_id = crate::util::short_id();
|
||||||
) -> (String, Arc<Client>) {
|
|
||||||
// Each WebSocket connection gets a stable chat scope. All user input and
|
|
||||||
// dialog controls for this client stay inside that scope unless the
|
|
||||||
// protocol explicitly carries a full session id.
|
|
||||||
let chat_id = crate::util::short_id();
|
|
||||||
|
|
||||||
let client = Arc::new(Client {
|
let client = Arc::new(Client {
|
||||||
sender,
|
sender,
|
||||||
chat_id: chat_id.clone(),
|
|
||||||
current_session_id: Mutex::new(None),
|
current_session_id: Mutex::new(None),
|
||||||
});
|
});
|
||||||
self.clients.lock().await.push(client.clone());
|
self.clients.lock().await.push(client.clone());
|
||||||
|
|
||||||
// Create initial session via control message
|
// Create initial session via control message
|
||||||
let session_id = match self.create_session_via_control(&chat_id, None).await {
|
let session_id = match self.create_session_via_control(&connection_id, None).await {
|
||||||
Ok((id, _title)) => id,
|
Ok(id) => id,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "Failed to create initial session");
|
tracing::error!(error = %e, "Failed to create initial session");
|
||||||
UnifiedSessionId::new("cli_chat", &chat_id, &crate::util::short_id()).to_string()
|
// Fall back to old format for backward compatibility
|
||||||
|
connection_id.clone()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -79,7 +73,8 @@ impl CliChatChannel {
|
|||||||
/// Handle an inbound message from a client
|
/// Handle an inbound message from a client
|
||||||
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
|
pub(crate) async fn handle_inbound(&self, client: Arc<Client>, raw_msg: &str) {
|
||||||
match parse_inbound(raw_msg) {
|
match parse_inbound(raw_msg) {
|
||||||
Ok(inbound) => match self.handle_ws_inbound(client.clone(), inbound).await {
|
Ok(inbound) => {
|
||||||
|
match self.handle_ws_inbound(client.clone(), inbound).await {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to handle inbound message");
|
tracing::warn!(error = %e, "Failed to handle inbound message");
|
||||||
@ -91,7 +86,8 @@ impl CliChatChannel {
|
|||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||||
let _ = client
|
let _ = client
|
||||||
@ -105,30 +101,22 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_ws_inbound(
|
async fn handle_ws_inbound(&self, client: Arc<Client>, inbound: WsInbound) -> Result<(), ChannelError> {
|
||||||
&self,
|
|
||||||
client: Arc<Client>,
|
|
||||||
inbound: WsInbound,
|
|
||||||
) -> Result<(), ChannelError> {
|
|
||||||
let bus = {
|
let bus = {
|
||||||
let guard = self.bus.lock().unwrap();
|
let guard = self.bus.lock().unwrap();
|
||||||
guard
|
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut current_session_guard = client.current_session_id.lock().await;
|
let mut current_session_guard = client.current_session_id.lock().await;
|
||||||
|
|
||||||
match inbound {
|
match inbound {
|
||||||
WsInbound::UserInput {
|
WsInbound::UserInput { content, chat_id, .. } => {
|
||||||
content, chat_id, ..
|
|
||||||
} => {
|
|
||||||
// All messages (including slash commands) go through the normal inbound flow
|
// All messages (including slash commands) go through the normal inbound flow
|
||||||
// SessionManager handles session creation/reuse internally
|
// SessionManager handles session creation/reuse internally
|
||||||
let msg = InboundMessage {
|
let msg = InboundMessage {
|
||||||
channel: self.name().to_string(),
|
channel: self.name().to_string(),
|
||||||
sender_id: "cli".to_string(),
|
sender_id: "cli".to_string(),
|
||||||
chat_id: chat_id.unwrap_or_else(|| client.chat_id.clone()),
|
chat_id: chat_id.unwrap_or_else(crate::util::short_id),
|
||||||
content,
|
content,
|
||||||
timestamp: crate::bus::message::current_timestamp(),
|
timestamp: crate::bus::message::current_timestamp(),
|
||||||
media: Vec::new(),
|
media: Vec::new(),
|
||||||
@ -137,56 +125,19 @@ impl CliChatChannel {
|
|||||||
};
|
};
|
||||||
bus.publish_inbound(msg).await?;
|
bus.publish_inbound(msg).await?;
|
||||||
}
|
}
|
||||||
WsInbound::ClearHistory {
|
WsInbound::ClearHistory { chat_id, session_id } => {
|
||||||
chat_id,
|
let target = session_id
|
||||||
session_id,
|
.or(chat_id)
|
||||||
} => {
|
.or(current_session_guard.clone())
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
|
||||||
let session_id = if let Some(session_id) = session_id {
|
|
||||||
UnifiedSessionId::parse(&session_id).ok_or_else(|| {
|
|
||||||
ChannelError::Other("Invalid session ID format".to_string())
|
|
||||||
})?
|
|
||||||
} else if let Some(chat_id) = chat_id {
|
|
||||||
let (current_tx, mut current_rx) = mpsc::channel(1);
|
|
||||||
bus.publish_control(ControlMessage {
|
|
||||||
op: SessionCommand::GetCurrentDialog {
|
|
||||||
channel: "cli_chat".to_string(),
|
|
||||||
chat_id,
|
|
||||||
},
|
|
||||||
reply_tx: current_tx,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
match current_rx.recv().await {
|
|
||||||
Some(Ok(SessionEvent::CurrentDialog {
|
|
||||||
session_id: Some(session_id),
|
|
||||||
})) => session_id,
|
|
||||||
Some(Ok(SessionEvent::CurrentDialog { session_id: None })) => {
|
|
||||||
return Err(ChannelError::Other("No active session".to_string()));
|
|
||||||
}
|
|
||||||
Some(Ok(_)) => {
|
|
||||||
return Err(ChannelError::Other(
|
|
||||||
"Unexpected response type".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Some(Err(e)) => return Err(e),
|
|
||||||
None => {
|
|
||||||
return Err(ChannelError::Other("Control channel closed".to_string()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let target = current_session_guard
|
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
||||||
UnifiedSessionId::parse(&target).ok_or_else(|| {
|
|
||||||
ChannelError::Other("Invalid session ID format".to_string())
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
})?
|
let session_id = UnifiedSessionId::parse(&target)
|
||||||
};
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
let target = session_id.to_string();
|
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::ClearHistory { session_id },
|
op: SessionCommand::ClearHistory { session_id },
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
|
Some(Ok(SessionEvent::HistoryCleared { .. })) => {
|
||||||
@ -207,21 +158,24 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::CreateSession { title } => {
|
WsInbound::CreateSession { title } => {
|
||||||
let (new_id, created_title) = self
|
// Use current session's chat_id if available, otherwise generate new one
|
||||||
.create_session_via_control(&client.chat_id, title.as_deref())
|
let chat_id = current_session_guard.clone()
|
||||||
.await?;
|
.unwrap_or_else(crate::util::short_id);
|
||||||
|
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
|
||||||
*current_session_guard = Some(new_id.clone());
|
*current_session_guard = Some(new_id.clone());
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionCreated {
|
.send(WsOutbound::SessionCreated {
|
||||||
session_id: new_id,
|
session_id: new_id,
|
||||||
title: created_title,
|
title: title.unwrap_or_default(),
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
WsInbound::ListSessions { include_archived } => {
|
WsInbound::ListSessions { include_archived } => {
|
||||||
// List dialogs for the current chat
|
// List dialogs for the current chat
|
||||||
let chat_id = client.chat_id.clone();
|
let chat_id = current_session_guard.clone()
|
||||||
|
.unwrap_or_else(|| "".to_string());
|
||||||
|
let chat_id_for_response = chat_id.clone();
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::ListDialogs {
|
op: SessionCommand::ListDialogs {
|
||||||
@ -230,18 +184,13 @@ impl CliChatChannel {
|
|||||||
include_archived,
|
include_archived,
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogList {
|
Some(Ok(SessionEvent::DialogList { dialogs, current_dialog_id })) => {
|
||||||
dialogs,
|
|
||||||
current_dialog_id,
|
|
||||||
})) => {
|
|
||||||
// Convert DialogInfo to SessionSummary for backward compatibility
|
// Convert DialogInfo to SessionSummary for backward compatibility
|
||||||
let sessions: Vec<crate::protocol::SessionSummary> = dialogs
|
let sessions: Vec<crate::protocol::SessionSummary> = dialogs.into_iter().map(|d| {
|
||||||
.into_iter()
|
crate::protocol::SessionSummary {
|
||||||
.map(|d| crate::protocol::SessionSummary {
|
|
||||||
session_id: d.session_id.to_string(),
|
session_id: d.session_id.to_string(),
|
||||||
title: d.title,
|
title: d.title,
|
||||||
channel_name: d.session_id.channel.clone(),
|
channel_name: d.session_id.channel.clone(),
|
||||||
@ -249,14 +198,11 @@ impl CliChatChannel {
|
|||||||
message_count: d.message_count,
|
message_count: d.message_count,
|
||||||
last_active_at: d.last_active_at,
|
last_active_at: d.last_active_at,
|
||||||
archived_at: d.archived_at,
|
archived_at: d.archived_at,
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let current_session_id = current_dialog_id.map(|did| {
|
|
||||||
UnifiedSessionId::new("cli_chat", &client.chat_id, &did).to_string()
|
|
||||||
});
|
|
||||||
if let Some(ref session_id) = current_session_id {
|
|
||||||
*current_session_guard = Some(session_id.clone());
|
|
||||||
}
|
}
|
||||||
|
}).collect();
|
||||||
|
let current_session_id = current_dialog_id.map(|did| {
|
||||||
|
UnifiedSessionId::new("cli_chat", chat_id_for_response.clone(), did).to_string()
|
||||||
|
});
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionList {
|
.send(WsOutbound::SessionList {
|
||||||
@ -277,35 +223,39 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::LoadSession { session_id } => {
|
WsInbound::LoadSession { session_id } => {
|
||||||
|
// LoadSession: parse the session_id and get current dialog info
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&session_id)
|
let unified_id = UnifiedSessionId::parse(&session_id)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
if unified_id.channel != "cli_chat" || unified_id.chat_id != client.chat_id {
|
|
||||||
return Err(ChannelError::Other(
|
|
||||||
"Session does not belong to this client".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::SwitchDialog {
|
op: SessionCommand::GetCurrentDialog {
|
||||||
channel: unified_id.channel.clone(),
|
channel: unified_id.channel.clone(),
|
||||||
chat_id: unified_id.chat_id.clone(),
|
chat_id: unified_id.chat_id.clone(),
|
||||||
dialog_id: unified_id.dialog_id.clone(),
|
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogSwitched { session_id })) => {
|
Some(Ok(SessionEvent::CurrentDialog { session_id: current_session_id_opt })) => {
|
||||||
*current_session_guard = Some(session_id.to_string());
|
if let Some(current_session_id) = current_session_id_opt {
|
||||||
|
*current_session_guard = Some(current_session_id.to_string());
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionLoaded {
|
.send(WsOutbound::SessionLoaded {
|
||||||
session_id: session_id.to_string(),
|
session_id: current_session_id.to_string(),
|
||||||
title: "Session".to_string(),
|
title: "Session".to_string(), // TODO: get actual title
|
||||||
message_count: 0,
|
message_count: 0, // TODO: get actual count
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
} else {
|
||||||
|
let _ = client
|
||||||
|
.sender
|
||||||
|
.send(WsOutbound::Error {
|
||||||
|
code: "NO_CURRENT_DIALOG".to_string(),
|
||||||
|
message: "No current dialog".to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {
|
Some(Ok(_)) => {
|
||||||
// Unexpected response type
|
// Unexpected response type
|
||||||
@ -325,30 +275,23 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::RenameSession { session_id, title } => {
|
WsInbound::RenameSession { session_id, title } => {
|
||||||
let target = session_id
|
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
||||||
.or(current_session_guard.clone())
|
ChannelError::Other("No active session".to_string())
|
||||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
})?;
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&target)
|
let unified_id = UnifiedSessionId::parse(&target)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::RenameDialog {
|
op: SessionCommand::RenameDialog { session_id: unified_id, title: title.clone() },
|
||||||
session_id: unified_id,
|
|
||||||
title: title.clone(),
|
|
||||||
},
|
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
|
Some(Ok(SessionEvent::DialogRenamed { session_id, title })) => {
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionRenamed {
|
.send(WsOutbound::SessionRenamed { session_id: session_id.to_string(), title })
|
||||||
session_id: session_id.to_string(),
|
|
||||||
title,
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {
|
Some(Ok(_)) => {
|
||||||
@ -363,43 +306,24 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::ArchiveSession { session_id } => {
|
WsInbound::ArchiveSession { session_id } => {
|
||||||
let target = session_id
|
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
||||||
.or(current_session_guard.clone())
|
ChannelError::Other("No active session".to_string())
|
||||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
})?;
|
||||||
let was_current = current_session_guard.as_deref() == Some(&target);
|
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&target)
|
let unified_id = UnifiedSessionId::parse(&target)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::ArchiveDialog {
|
op: SessionCommand::ArchiveDialog { session_id: unified_id },
|
||||||
session_id: unified_id,
|
|
||||||
},
|
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
|
Some(Ok(SessionEvent::DialogArchived { session_id })) => {
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionArchived {
|
.send(WsOutbound::SessionArchived { session_id: session_id.to_string() })
|
||||||
session_id: session_id.to_string(),
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
if was_current {
|
|
||||||
let (new_id, title) = self
|
|
||||||
.create_session_via_control(&client.chat_id, None)
|
|
||||||
.await?;
|
|
||||||
*current_session_guard = Some(new_id.clone());
|
|
||||||
let _ = client
|
|
||||||
.sender
|
|
||||||
.send(WsOutbound::SessionCreated {
|
|
||||||
session_id: new_id,
|
|
||||||
title,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {
|
Some(Ok(_)) => {
|
||||||
// Unexpected response type
|
// Unexpected response type
|
||||||
@ -413,42 +337,35 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsInbound::DeleteSession { session_id } => {
|
WsInbound::DeleteSession { session_id } => {
|
||||||
let target = session_id
|
let target = session_id.or(current_session_guard.clone()).ok_or_else(|| {
|
||||||
.or(current_session_guard.clone())
|
ChannelError::Other("No active session".to_string())
|
||||||
.ok_or_else(|| ChannelError::Other("No active session".to_string()))?;
|
})?;
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
let unified_id = UnifiedSessionId::parse(&target)
|
let unified_id = UnifiedSessionId::parse(&target)
|
||||||
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
.ok_or_else(|| ChannelError::Other("Invalid session ID format".to_string()))?;
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::DeleteDialog {
|
op: SessionCommand::DeleteDialog { session_id: unified_id },
|
||||||
session_id: unified_id,
|
|
||||||
},
|
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
|
Some(Ok(SessionEvent::DialogDeleted { session_id })) => {
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionDeleted {
|
.send(WsOutbound::SessionDeleted { session_id: session_id.to_string() })
|
||||||
session_id: session_id.to_string(),
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// If deleting current session, create a new one
|
// If deleting current session, create a new one
|
||||||
if current_session_guard.as_deref() == Some(&target) {
|
if current_session_guard.as_deref() == Some(&target) {
|
||||||
drop(reply_rx);
|
drop(reply_rx);
|
||||||
if let Ok((new_id, title)) =
|
if let Ok(new_id) = self.create_session_via_control(&target, None).await {
|
||||||
self.create_session_via_control(&client.chat_id, None).await
|
|
||||||
{
|
|
||||||
*current_session_guard = Some(new_id.clone());
|
*current_session_guard = Some(new_id.clone());
|
||||||
let _ = client
|
let _ = client
|
||||||
.sender
|
.sender
|
||||||
.send(WsOutbound::SessionCreated {
|
.send(WsOutbound::SessionCreated {
|
||||||
session_id: new_id,
|
session_id: new_id,
|
||||||
title,
|
title: String::new(),
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
@ -471,45 +388,32 @@ impl CliChatChannel {
|
|||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::GetSlashCommands {
|
op: SessionCommand::GetSlashCommands {
|
||||||
channel: "cli_chat".to_string(),
|
channel: "cli_chat".to_string(),
|
||||||
chat_id: client.chat_id.clone(),
|
chat_id: "".to_string(),
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
if let Some(result) = reply_rx.recv().await {
|
if let Some(result) = reply_rx.recv().await {
|
||||||
match result {
|
match result {
|
||||||
Ok(SessionEvent::SlashCommandsList { commands }) => {
|
Ok(SessionEvent::SlashCommandsList { commands }) => {
|
||||||
// Convert to SlashCommand to SlashCommandInfo
|
// Convert to SlashCommand to SlashCommandInfo
|
||||||
let command_infos: Vec<SlashCommandInfo> = commands
|
let command_infos: Vec<SlashCommandInfo> = commands.into_iter().map(|cmd| {
|
||||||
.into_iter()
|
SlashCommandInfo {
|
||||||
.map(|cmd| SlashCommandInfo {
|
|
||||||
name: cmd.name.to_string(),
|
name: cmd.name.to_string(),
|
||||||
description: cmd.description.to_string(),
|
description: cmd.description.to_string(),
|
||||||
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
|
aliases: cmd.aliases.iter().map(|&a| a.to_string()).collect(),
|
||||||
})
|
}
|
||||||
.collect();
|
}).collect();
|
||||||
let _ = client
|
let _ = client.sender.send(WsOutbound::SlashCommandsList { commands: command_infos }).await;
|
||||||
.sender
|
|
||||||
.send(WsOutbound::SlashCommandsList {
|
|
||||||
commands: command_infos,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
Ok(SessionEvent::Error { code, message }) => {
|
Ok(SessionEvent::Error { code, message }) => {
|
||||||
let _ = client
|
let _ = client.sender.send(WsOutbound::Error { code, message }).await;
|
||||||
.sender
|
|
||||||
.send(WsOutbound::Error { code, message })
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let _ = client
|
let _ = client.sender.send(WsOutbound::Error {
|
||||||
.sender
|
|
||||||
.send(WsOutbound::Error {
|
|
||||||
code: "GET_COMMANDS_ERROR".to_string(),
|
code: "GET_COMMANDS_ERROR".to_string(),
|
||||||
message: e.to_string(),
|
message: e.to_string()
|
||||||
})
|
}).await;
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
@ -523,34 +427,29 @@ impl CliChatChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a session via control message and return the session_id
|
/// Create a session via control message and return the session_id
|
||||||
async fn create_session_via_control(
|
async fn create_session_via_control(&self, connection_id: &str, title: Option<&str>) -> Result<String, ChannelError> {
|
||||||
&self,
|
|
||||||
chat_id: &str,
|
|
||||||
title: Option<&str>,
|
|
||||||
) -> Result<(String, String), ChannelError> {
|
|
||||||
let bus = {
|
let bus = {
|
||||||
let guard = self.bus.lock().unwrap();
|
let guard = self.bus.lock().unwrap();
|
||||||
guard
|
guard.clone().ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| ChannelError::Other("Channel not started".to_string()))?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
let (reply_tx, mut reply_rx) = mpsc::channel(1);
|
||||||
bus.publish_control(ControlMessage {
|
bus.publish_control(ControlMessage {
|
||||||
op: SessionCommand::CreateDialog {
|
op: SessionCommand::CreateDialog {
|
||||||
channel: "cli_chat".to_string(),
|
channel: "cli_chat".to_string(),
|
||||||
chat_id: chat_id.to_string(),
|
chat_id: connection_id.to_string(),
|
||||||
title: title.map(String::from),
|
title: title.map(String::from),
|
||||||
},
|
},
|
||||||
reply_tx,
|
reply_tx,
|
||||||
})
|
}).await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
match reply_rx.recv().await {
|
match reply_rx.recv().await {
|
||||||
Some(Ok(SessionEvent::DialogCreated { session_id, title })) => {
|
Some(Ok(SessionEvent::DialogCreated { session_id, .. })) => {
|
||||||
Ok((session_id.to_string(), title))
|
Ok(session_id.to_string())
|
||||||
|
}
|
||||||
|
Some(Ok(_)) => {
|
||||||
|
Err(ChannelError::Other("Unexpected response type".to_string()))
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => Err(ChannelError::Other("Unexpected response type".to_string())),
|
|
||||||
Some(Err(e)) => Err(e),
|
Some(Err(e)) => Err(e),
|
||||||
None => Err(ChannelError::Other("Control channel closed".to_string())),
|
None => Err(ChannelError::Other("Control channel closed".to_string())),
|
||||||
}
|
}
|
||||||
@ -580,11 +479,7 @@ impl Channel for CliChatChannel {
|
|||||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
let clients = self.clients.lock().await.clone();
|
let clients = self.clients.lock().await.clone();
|
||||||
for client in clients {
|
for client in clients {
|
||||||
if client.chat_id != msg.chat_id {
|
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification")
|
|
||||||
{
|
|
||||||
WsOutbound::SystemNotification {
|
WsOutbound::SystemNotification {
|
||||||
content: msg.content.clone(),
|
content: msg.content.clone(),
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -24,10 +24,7 @@ impl ChannelManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_bus(
|
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
|
||||||
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
|
|
||||||
bus: Arc<MessageBus>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||||
cli_chat_channel,
|
cli_chat_channel,
|
||||||
@ -42,10 +39,7 @@ impl ChannelManager {
|
|||||||
|
|
||||||
/// Register a channel with the manager
|
/// Register a channel with the manager
|
||||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||||
self.channels
|
self.channels.write().await.insert(name.to_string(), channel);
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.insert(name.to_string(), channel);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get CLI chat channel
|
/// Get CLI chat channel
|
||||||
@ -62,19 +56,14 @@ impl ChannelManager {
|
|||||||
// Initialize Feishu channel if enabled
|
// Initialize Feishu channel if enabled
|
||||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||||
if feishu_config.enabled {
|
if feishu_config.enabled {
|
||||||
let channel =
|
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir)
|
||||||
FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
|
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||||
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
self.channels
|
self.channels
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert("feishu".to_string(), Arc::new(channel));
|
.insert("feishu".to_string(), Arc::new(channel));
|
||||||
tracing::info!(
|
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display());
|
||||||
"Feishu channel registered (media_dir: {}/media/feishu)",
|
|
||||||
workspace_dir.display()
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
tracing::info!("Feishu channel disabled in config");
|
tracing::info!("Feishu channel disabled in config");
|
||||||
}
|
}
|
||||||
@ -129,10 +118,7 @@ impl ChannelManager {
|
|||||||
if let Some(channel) = self.get_channel(channel_name).await {
|
if let Some(channel) = self.get_channel(channel_name).await {
|
||||||
channel.send(msg).await
|
channel.send(msg).await
|
||||||
} else {
|
} else {
|
||||||
Err(ChannelError::Other(format!(
|
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
||||||
"Channel not found: {}",
|
|
||||||
channel_name
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
pub mod base;
|
pub mod base;
|
||||||
pub mod cli_chat;
|
|
||||||
pub mod feishu;
|
pub mod feishu;
|
||||||
|
pub mod cli_chat;
|
||||||
pub mod manager;
|
pub mod manager;
|
||||||
pub mod slash_command;
|
pub mod slash_command;
|
||||||
|
|
||||||
pub use base::{Channel, ChannelError};
|
pub use base::{Channel, ChannelError};
|
||||||
pub use cli_chat::CliChatChannel;
|
|
||||||
pub use feishu::FeishuChannel;
|
|
||||||
pub use manager::ChannelManager;
|
pub use manager::ChannelManager;
|
||||||
pub use slash_command::{command_matches, parse_slash_command};
|
pub use feishu::FeishuChannel;
|
||||||
|
pub use cli_chat::CliChatChannel;
|
||||||
|
pub use slash_command::{parse_slash_command, command_matches};
|
||||||
|
|||||||
@ -16,9 +16,7 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
|
|||||||
/// 检查内容是否匹配指定命令
|
/// 检查内容是否匹配指定命令
|
||||||
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
|
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
|
||||||
let trimmed = content.trim();
|
let trimmed = content.trim();
|
||||||
aliases
|
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
||||||
.iter()
|
|
||||||
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -29,10 +27,7 @@ mod tests {
|
|||||||
fn test_parse_slash_command() {
|
fn test_parse_slash_command() {
|
||||||
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
|
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
|
||||||
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
|
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
|
||||||
assert_eq!(
|
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world")));
|
||||||
parse_slash_command("/new hello world"),
|
|
||||||
Some(("new", "hello world"))
|
|
||||||
);
|
|
||||||
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
|
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
|
||||||
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
|
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
|
||||||
assert_eq!(parse_slash_command("/?"), Some(("?", "")));
|
assert_eq!(parse_slash_command("/?"), Some(("?", "")));
|
||||||
|
|||||||
@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
|
|||||||
use crossterm::{
|
use crossterm::{
|
||||||
event::{self, Event},
|
event::{self, Event},
|
||||||
execute,
|
execute,
|
||||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
||||||
};
|
};
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
use ratatui::{prelude::CrosstermBackend, Terminal};
|
||||||
use std::io;
|
use std::io;
|
||||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||||
|
|
||||||
@ -104,10 +104,7 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
|
|||||||
WsOutbound::SessionCreated { session_id, .. } => {
|
WsOutbound::SessionCreated { session_id, .. } => {
|
||||||
app.set_current_session(Some(session_id));
|
app.set_current_session(Some(session_id));
|
||||||
}
|
}
|
||||||
WsOutbound::SessionList {
|
WsOutbound::SessionList { sessions, current_session_id } => {
|
||||||
sessions,
|
|
||||||
current_session_id,
|
|
||||||
} => {
|
|
||||||
app.set_sessions(sessions);
|
app.set_sessions(sessions);
|
||||||
if let Some(id) = current_session_id {
|
if let Some(id) = current_session_id {
|
||||||
app.set_current_session(Some(id));
|
app.set_current_session(Some(id));
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use crate::client::tui::app::{App, MessageRole};
|
use crate::client::tui::app::{App, MessageRole};
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
text::Line,
|
text::Line,
|
||||||
widgets::{Block, Borders, List, ListItem},
|
widgets::{Block, Borders, List, ListItem},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
text::{Line, Span},
|
text::{Line, Span},
|
||||||
widgets::{Block, Borders, List, ListItem},
|
widgets::{Block, Borders, List, ListItem},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
widgets::{Block, Borders, Clear, List, ListItem},
|
widgets::{Block, Borders, Clear, List, ListItem},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect) {
|
pub fn render(f: &mut Frame, area: Rect) {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Style},
|
style::{Color, Style},
|
||||||
widgets::{Block, Borders, Paragraph},
|
widgets::{Block, Borders, Paragraph},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
widgets::{Block, Borders, List, ListItem},
|
widgets::{Block, Borders, List, ListItem},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
@ -11,7 +11,9 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
|||||||
.sessions
|
.sessions
|
||||||
.iter()
|
.iter()
|
||||||
.map(|session| {
|
.map(|session| {
|
||||||
let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
|
let is_current = app
|
||||||
|
.current_session_id
|
||||||
|
.as_ref() == Some(&session.session_id);
|
||||||
let archived = session.archived_at.is_some();
|
let archived = session.archived_at.is_some();
|
||||||
|
|
||||||
let mut content = if is_current {
|
let mut content = if is_current {
|
||||||
|
|||||||
@ -1,18 +1,15 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
widgets::{Block, Borders, Paragraph},
|
widgets::{Block, Borders, Paragraph},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
let (title, style) = if app.pending_quit {
|
let (title, style) = if app.pending_quit {
|
||||||
let msg = if let Some(session_id) = &app.current_session_id {
|
let msg = if let Some(session_id) = &app.current_session_id {
|
||||||
format!(
|
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id)
|
||||||
"PicoBot | Session: {} | Press Ctrl+C again to quit",
|
|
||||||
session_id
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
"PicoBot | Press Ctrl+C again to quit".to_string()
|
"PicoBot | Press Ctrl+C again to quit".to_string()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use crate::client::tui::app::{App, MessageRole};
|
use crate::client::tui::app::{App, MessageRole};
|
||||||
use crate::protocol::WsInbound;
|
|
||||||
use crate::protocol::serialize_inbound;
|
use crate::protocol::serialize_inbound;
|
||||||
|
use crate::protocol::WsInbound;
|
||||||
use crossterm::event::{KeyCode, KeyEvent};
|
use crossterm::event::{KeyCode, KeyEvent};
|
||||||
use futures_util::SinkExt;
|
use futures_util::SinkExt;
|
||||||
|
|
||||||
@ -48,10 +48,7 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
|
|||||||
|
|
||||||
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||||
// Handle Ctrl+C for quit (double press to exit)
|
// Handle Ctrl+C for quit (double press to exit)
|
||||||
let is_ctrl_c = key.code == KeyCode::Char('c')
|
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL);
|
||||||
&& key
|
|
||||||
.modifiers
|
|
||||||
.contains(crossterm::event::KeyModifiers::CONTROL);
|
|
||||||
if is_ctrl_c {
|
if is_ctrl_c {
|
||||||
if app.handle_ctrl_c_for_quit() {
|
if app.handle_ctrl_c_for_quit() {
|
||||||
return;
|
return;
|
||||||
@ -68,9 +65,7 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
|||||||
app.input_insert_char(c);
|
app.input_insert_char(c);
|
||||||
|
|
||||||
// Show command menu when input starts with /
|
// Show command menu when input starts with /
|
||||||
if !app.show_command_menu
|
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) {
|
||||||
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
|
|
||||||
{
|
|
||||||
app.show_command_menu = true;
|
app.show_command_menu = true;
|
||||||
app.selected_command_idx = 0;
|
app.selected_command_idx = 0;
|
||||||
} else if app.show_command_menu && !app.input.starts_with('/') {
|
} else if app.show_command_menu && !app.input.starts_with('/') {
|
||||||
@ -126,9 +121,7 @@ async fn process_input(app: &mut App, input: String) {
|
|||||||
sender_id: None,
|
sender_id: None,
|
||||||
};
|
};
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
let _ = sender
|
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await;
|
||||||
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use crate::client::tui::components::*;
|
use crate::client::tui::components::*;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
Frame,
|
|
||||||
layout::{Constraint, Direction, Layout, Rect},
|
layout::{Constraint, Direction, Layout, Rect},
|
||||||
|
Frame,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render_ui(f: &mut Frame, app: &App) {
|
pub fn render_ui(f: &mut Frame, app: &App) {
|
||||||
|
|||||||
@ -152,26 +152,10 @@ pub struct GatewayConfig {
|
|||||||
pub cleanup_interval_minutes: Option<u64>,
|
pub cleanup_interval_minutes: Option<u64>,
|
||||||
#[serde(default, rename = "session_db_path")]
|
#[serde(default, rename = "session_db_path")]
|
||||||
pub session_db_path: Option<String>,
|
pub session_db_path: Option<String>,
|
||||||
#[serde(default, rename = "max_concurrent_background_tasks")]
|
|
||||||
pub max_concurrent_background_tasks: usize,
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub scheduler: Option<SchedulerConfig>,
|
pub scheduler: Option<SchedulerConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for GatewayConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
host: default_gateway_host(),
|
|
||||||
port: default_gateway_port(),
|
|
||||||
session_ttl_hours: None,
|
|
||||||
cleanup_interval_minutes: None,
|
|
||||||
session_db_path: None,
|
|
||||||
max_concurrent_background_tasks: 10,
|
|
||||||
scheduler: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct SchedulerConfig {
|
pub struct SchedulerConfig {
|
||||||
/// Whether the scheduler is enabled
|
/// Whether the scheduler is enabled
|
||||||
@ -225,6 +209,19 @@ fn default_gateway_url() -> String {
|
|||||||
"ws://127.0.0.1:19876/ws".to_string()
|
"ws://127.0.0.1:19876/ws".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for GatewayConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
host: default_gateway_host(),
|
||||||
|
port: default_gateway_port(),
|
||||||
|
session_ttl_hours: None,
|
||||||
|
cleanup_interval_minutes: None,
|
||||||
|
session_db_path: None,
|
||||||
|
scheduler: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for ClientConfig {
|
impl Default for ClientConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -273,16 +270,12 @@ impl Default for MemoryConfig {
|
|||||||
impl MemoryConfig {
|
impl MemoryConfig {
|
||||||
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
||||||
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
||||||
self.consolidation_provider
|
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| default.to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve consolidation model name, falling back to the main agent's model.
|
/// Resolve consolidation model name, falling back to the main agent's model.
|
||||||
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
||||||
self.consolidation_model
|
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| default.to_string())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -370,18 +363,10 @@ impl Default for BrowserConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_recall_limit() -> usize {
|
fn default_recall_limit() -> usize { 5 }
|
||||||
5
|
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
||||||
}
|
fn default_timeline_retention_days() -> u64 { 90 }
|
||||||
fn default_idle_consolidation_minutes() -> u64 {
|
fn default_max_failures_before_degrade() -> usize { 3 }
|
||||||
10
|
|
||||||
}
|
|
||||||
fn default_timeline_retention_days() -> u64 {
|
|
||||||
90
|
|
||||||
}
|
|
||||||
fn default_max_failures_before_degrade() -> usize {
|
|
||||||
3
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LLMProviderConfig {
|
pub struct LLMProviderConfig {
|
||||||
@ -481,11 +466,7 @@ pub enum ConfigError {
|
|||||||
impl std::fmt::Display for ConfigError {
|
impl std::fmt::Display for ConfigError {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ConfigError::ConfigNotFound(path) => write!(
|
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
||||||
f,
|
|
||||||
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
|
|
||||||
path
|
|
||||||
),
|
|
||||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||||
|
|||||||
@ -1,19 +1,19 @@
|
|||||||
pub mod http;
|
pub mod http;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
use axum::{Router, routing};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use axum::{routing, Router};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
|
||||||
use crate::channels::{ChannelManager, CliChatChannel};
|
use crate::channels::{ChannelManager, CliChatChannel};
|
||||||
use crate::config::{Config, ensure_workspace_dir, expand_path};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::mcp;
|
use crate::mcp;
|
||||||
use crate::memory::MemoryManager;
|
use crate::memory::MemoryManager;
|
||||||
use crate::scheduler::Scheduler;
|
|
||||||
use crate::session::SessionManager;
|
use crate::session::SessionManager;
|
||||||
|
use crate::scheduler::Scheduler;
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
@ -32,13 +32,8 @@ impl GatewayState {
|
|||||||
let workspace_path = ensure_workspace_dir(&workspace_path)?;
|
let workspace_path = ensure_workspace_dir(&workspace_path)?;
|
||||||
|
|
||||||
// Switch current working directory to workspace
|
// Switch current working directory to workspace
|
||||||
std::env::set_current_dir(&workspace_path).map_err(|e| {
|
std::env::set_current_dir(&workspace_path)
|
||||||
format!(
|
.map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?;
|
||||||
"Failed to switch to workspace directory {}: {}",
|
|
||||||
workspace_path.display(),
|
|
||||||
e
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
tracing::info!("Using workspace directory: {}", workspace_path.display());
|
tracing::info!("Using workspace directory: {}", workspace_path.display());
|
||||||
|
|
||||||
@ -57,9 +52,8 @@ impl GatewayState {
|
|||||||
workspace_path.join("picobot.db")
|
workspace_path.join("picobot.db")
|
||||||
};
|
};
|
||||||
let storage = Arc::new(
|
let storage = Arc::new(
|
||||||
crate::storage::Storage::new(&db_path)
|
crate::storage::Storage::new(&db_path).await
|
||||||
.await
|
.map_err(|e| format!("failed to initialize session storage: {}", e))?
|
||||||
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
|
|
||||||
);
|
);
|
||||||
tracing::info!("Session storage: {}", db_path.display());
|
tracing::info!("Session storage: {}", db_path.display());
|
||||||
|
|
||||||
@ -97,16 +91,13 @@ impl GatewayState {
|
|||||||
bus.clone(),
|
bus.clone(),
|
||||||
memory_manager,
|
memory_manager,
|
||||||
browser_config,
|
browser_config,
|
||||||
config.gateway.max_concurrent_background_tasks,
|
|
||||||
)?;
|
)?;
|
||||||
let session_manager = Arc::new(session_manager);
|
let session_manager = Arc::new(session_manager);
|
||||||
|
|
||||||
// Create ChannelManager and init channels
|
// Create ChannelManager and init channels
|
||||||
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
||||||
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
||||||
channel_manager
|
channel_manager.init(&config, workspace_path.clone()).await
|
||||||
.init(&config, workspace_path.clone())
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
||||||
|
|
||||||
// Register send_message tool with available channel names
|
// Register send_message tool with available channel names
|
||||||
@ -115,12 +106,9 @@ impl GatewayState {
|
|||||||
session_manager.register_outbound_tool(available_channels);
|
session_manager.register_outbound_tool(available_channels);
|
||||||
|
|
||||||
// Register chat_manager tool
|
// Register chat_manager tool
|
||||||
session_manager
|
session_manager.tools().register(
|
||||||
.tools()
|
crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()),
|
||||||
.register(crate::tools::ChatManagerTool::new(
|
);
|
||||||
storage.clone(),
|
|
||||||
valid_channels.clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
// Initialize MCP servers — connect and register discovered tools
|
// Initialize MCP servers — connect and register discovered tools
|
||||||
if !config.mcp.servers.is_empty() {
|
if !config.mcp.servers.is_empty() {
|
||||||
@ -141,27 +129,24 @@ impl GatewayState {
|
|||||||
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
||||||
if scheduler_config.enabled {
|
if scheduler_config.enabled {
|
||||||
// Register cron tools
|
// Register cron tools
|
||||||
session_manager
|
session_manager.tools().register(
|
||||||
.tools()
|
crate::tools::cron::CronAddTool::new(storage.clone(), valid_channels),
|
||||||
.register(crate::tools::cron::CronAddTool::new(
|
);
|
||||||
storage.clone(),
|
session_manager.tools().register(
|
||||||
valid_channels,
|
crate::tools::cron::CronListTool::new(storage.clone()),
|
||||||
));
|
);
|
||||||
session_manager
|
session_manager.tools().register(
|
||||||
.tools()
|
crate::tools::cron::CronRemoveTool::new(storage.clone()),
|
||||||
.register(crate::tools::cron::CronListTool::new(storage.clone()));
|
);
|
||||||
session_manager
|
session_manager.tools().register(
|
||||||
.tools()
|
crate::tools::cron::CronEnableTool::new(storage.clone()),
|
||||||
.register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
|
);
|
||||||
session_manager
|
session_manager.tools().register(
|
||||||
.tools()
|
crate::tools::cron::CronDisableTool::new(storage.clone()),
|
||||||
.register(crate::tools::cron::CronEnableTool::new(storage.clone()));
|
);
|
||||||
session_manager
|
session_manager.tools().register(
|
||||||
.tools()
|
crate::tools::cron::CronUpdateTool::new(storage.clone()),
|
||||||
.register(crate::tools::cron::CronDisableTool::new(storage.clone()));
|
);
|
||||||
session_manager
|
|
||||||
.tools()
|
|
||||||
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
|
|
||||||
tracing::info!("Cron tools registered");
|
tracing::info!("Cron tools registered");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,103 +267,71 @@ impl GatewayState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle control messages (session management operations)
|
/// Handle control messages (session management operations)
|
||||||
async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
|
async fn handle_control_message(
|
||||||
|
session_manager: &SessionManager,
|
||||||
|
msg: ControlMessage,
|
||||||
|
) {
|
||||||
use crate::session::{SessionCommand::*, SessionEvent};
|
use crate::session::{SessionCommand::*, SessionEvent};
|
||||||
|
|
||||||
let reply_tx = msg.reply_tx;
|
let reply_tx = msg.reply_tx;
|
||||||
let result: Result<SessionEvent, ChannelError> = match msg.op {
|
let result: Result<SessionEvent, ChannelError> = match msg.op {
|
||||||
CreateDialog {
|
CreateDialog { channel, chat_id, title } => {
|
||||||
channel,
|
session_manager.create_dialog(&channel, &chat_id, title.as_deref()).await
|
||||||
chat_id,
|
|
||||||
title,
|
|
||||||
} => session_manager
|
|
||||||
.create_dialog(&channel, &chat_id, title.as_deref())
|
|
||||||
.await
|
|
||||||
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
ListDialogs {
|
}
|
||||||
channel,
|
ListDialogs { channel, chat_id, include_archived } => {
|
||||||
chat_id,
|
session_manager.list_dialogs(&channel, &chat_id, include_archived).await
|
||||||
include_archived,
|
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList { dialogs, current_dialog_id })
|
||||||
} => session_manager
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
.list_dialogs(&channel, &chat_id, include_archived)
|
}
|
||||||
.await
|
GetCurrentDialog { channel, chat_id } => {
|
||||||
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
|
session_manager.get_current_dialog(&channel, &chat_id).await
|
||||||
dialogs,
|
|
||||||
current_dialog_id,
|
|
||||||
})
|
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
||||||
GetCurrentDialog { channel, chat_id } => session_manager
|
|
||||||
.get_current_dialog(&channel, &chat_id)
|
|
||||||
.await
|
|
||||||
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
SwitchDialog {
|
}
|
||||||
channel,
|
SwitchDialog { channel, chat_id, dialog_id } => {
|
||||||
chat_id,
|
session_manager.switch_dialog(&channel, &chat_id, &dialog_id).await
|
||||||
dialog_id,
|
|
||||||
} => session_manager
|
|
||||||
.switch_dialog(&channel, &chat_id, &dialog_id)
|
|
||||||
.await
|
|
||||||
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
RenameDialog { session_id, title } => session_manager
|
}
|
||||||
.rename_dialog(&session_id, &title)
|
RenameDialog { session_id, title } => {
|
||||||
.await
|
session_manager.rename_dialog(&session_id, &title).await
|
||||||
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
ArchiveDialog { session_id } => session_manager
|
}
|
||||||
.archive_dialog(&session_id)
|
ArchiveDialog { session_id } => {
|
||||||
.await
|
session_manager.archive_dialog(&session_id)
|
||||||
.map(|()| SessionEvent::DialogArchived { session_id })
|
.map(|()| SessionEvent::DialogArchived { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
DeleteDialog { session_id } => session_manager
|
}
|
||||||
.delete_dialog(&session_id)
|
DeleteDialog { session_id } => {
|
||||||
.await
|
session_manager.delete_dialog(&session_id).await
|
||||||
.map(|()| SessionEvent::DialogDeleted { session_id })
|
.map(|()| SessionEvent::DialogDeleted { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
ClearHistory { session_id } => session_manager
|
}
|
||||||
.clear_dialog_history(&session_id)
|
ClearHistory { session_id } => {
|
||||||
.await
|
session_manager.clear_dialog_history(&session_id)
|
||||||
.map(|()| SessionEvent::HistoryCleared { session_id })
|
.map(|()| SessionEvent::HistoryCleared { session_id })
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
GetSlashCommands {
|
}
|
||||||
channel: _,
|
GetSlashCommands { channel: _, chat_id: _ } => {
|
||||||
chat_id: _,
|
|
||||||
} => {
|
|
||||||
let commands = session_manager.get_slash_commands().to_vec();
|
let commands = session_manager.get_slash_commands().to_vec();
|
||||||
Ok(SessionEvent::SlashCommandsList { commands })
|
Ok(SessionEvent::SlashCommandsList { commands })
|
||||||
}
|
}
|
||||||
ExecuteSlashCommand {
|
ExecuteSlashCommand { command, args, channel, chat_id, current_session_id } => {
|
||||||
command,
|
session_manager.execute_slash_command(&command, args.as_deref(), &channel, &chat_id, current_session_id.as_ref())
|
||||||
args,
|
|
||||||
channel,
|
|
||||||
chat_id,
|
|
||||||
current_session_id,
|
|
||||||
} => session_manager
|
|
||||||
.execute_slash_command(
|
|
||||||
&command,
|
|
||||||
args.as_deref(),
|
|
||||||
&channel,
|
|
||||||
&chat_id,
|
|
||||||
current_session_id.as_ref(),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
|
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted { new_session_id: new_id, message: msg })
|
||||||
new_session_id: new_id,
|
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||||
message: msg,
|
}
|
||||||
})
|
|
||||||
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let _ = reply_tx.send(result).await;
|
let _ = reply_tx.send(result).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
host: Option<String>,
|
|
||||||
port: Option<u16>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
// Initialize logging
|
// Initialize logging
|
||||||
logging::init_logging();
|
logging::init_logging();
|
||||||
tracing::info!("Starting PicoBot Gateway");
|
tracing::info!("Starting PicoBot Gateway");
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
use super::GatewayState;
|
use std::sync::Arc;
|
||||||
use crate::protocol::WsOutbound;
|
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||||
use crate::protocol::serialize_outbound;
|
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use crate::protocol::serialize_outbound;
|
||||||
|
use crate::protocol::WsOutbound;
|
||||||
|
use super::GatewayState;
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||||
ws.on_upgrade(|socket| async move {
|
ws.on_upgrade(|socket| async move {
|
||||||
@ -25,11 +25,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
|
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
|
||||||
|
|
||||||
// Send session established message
|
// Send session established message
|
||||||
let _ = sender
|
let _ = sender.send(WsOutbound::SessionEstablished {
|
||||||
.send(WsOutbound::SessionEstablished {
|
|
||||||
session_id: session_id.clone(),
|
session_id: session_id.clone(),
|
||||||
})
|
}).await;
|
||||||
.await;
|
|
||||||
|
|
||||||
tracing::info!(session_id = %session_id, "CLI session established");
|
tracing::info!(session_id = %session_id, "CLI session established");
|
||||||
|
|
||||||
@ -39,8 +37,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
if let Ok(text) = serialize_outbound(&msg)
|
if let Ok(text) = serialize_outbound(&msg)
|
||||||
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err()
|
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||||
{
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
14
src/lib.rs
14
src/lib.rs
@ -1,17 +1,17 @@
|
|||||||
pub mod agent;
|
|
||||||
pub mod bus;
|
|
||||||
pub mod channels;
|
|
||||||
pub mod client;
|
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod providers;
|
||||||
|
pub mod bus;
|
||||||
|
pub mod agent;
|
||||||
pub mod gateway;
|
pub mod gateway;
|
||||||
|
pub mod session;
|
||||||
|
pub mod client;
|
||||||
|
pub mod protocol;
|
||||||
|
pub mod channels;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod mcp;
|
pub mod mcp;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
pub mod protocol;
|
|
||||||
pub mod providers;
|
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
pub mod session;
|
|
||||||
pub mod skills;
|
pub mod skills;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||||
use tracing_subscriber::{
|
use tracing_subscriber::{
|
||||||
EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
|
fmt,
|
||||||
|
layer::SubscriberExt,
|
||||||
|
util::SubscriberInitExt,
|
||||||
|
fmt::time::LocalTime,
|
||||||
|
EnvFilter,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Get the default log directory path: ~/.picobot/logs
|
/// Get the default log directory path: ~/.picobot/logs
|
||||||
@ -23,20 +27,20 @@ pub fn init_logging() {
|
|||||||
|
|
||||||
// Create log directory if it doesn't exist
|
// Create log directory if it doesn't exist
|
||||||
if !log_dir.exists()
|
if !log_dir.exists()
|
||||||
&& let Err(e) = std::fs::create_dir_all(&log_dir)
|
&& let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||||
{
|
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
||||||
eprintln!(
|
|
||||||
"Warning: Failed to create log directory {}: {}",
|
|
||||||
log_dir.display(),
|
|
||||||
e
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create file appender with daily rotation
|
// Create file appender with daily rotation
|
||||||
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
|
let file_appender = RollingFileAppender::new(
|
||||||
|
Rotation::DAILY,
|
||||||
|
&log_dir,
|
||||||
|
"picobot.log",
|
||||||
|
);
|
||||||
|
|
||||||
// Build subscriber with both console and file output
|
// Build subscriber with both console and file output
|
||||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
let env_filter = EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
|
|
||||||
let file_layer = fmt::layer()
|
let file_layer = fmt::layer()
|
||||||
.with_writer(file_appender)
|
.with_writer(file_appender)
|
||||||
@ -62,7 +66,8 @@ pub fn init_logging() {
|
|||||||
|
|
||||||
/// Initialize logging without file output (console only)
|
/// Initialize logging without file output (console only)
|
||||||
pub fn init_logging_console_only() {
|
pub fn init_logging_console_only() {
|
||||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
let env_filter = EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
|
|
||||||
let console_layer = fmt::layer()
|
let console_layer = fmt::layer()
|
||||||
.with_timer(LocalTime::rfc_3339())
|
.with_timer(LocalTime::rfc_3339())
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
use clap::{CommandFactory, Parser};
|
use clap::{Parser, CommandFactory};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "picobot")]
|
#[command(name = "picobot")]
|
||||||
#[command(about = "A CLI chatbot", long_about = None)]
|
#[command(about = "A CLI chatbot", long_about = None)]
|
||||||
#[command(version = "1.1.0")]
|
|
||||||
enum Command {
|
enum Command {
|
||||||
/// Connect to gateway
|
/// Connect to gateway
|
||||||
Chat {
|
Chat {
|
||||||
|
|||||||
@ -92,9 +92,13 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
|||||||
parts.push(text.text.clone());
|
parts.push(text.text.clone());
|
||||||
}
|
}
|
||||||
RawContent::Image(image) => {
|
RawContent::Image(image) => {
|
||||||
parts.push(format!("[image: {}]", image.mime_type,));
|
parts.push(format!(
|
||||||
|
"[image: {}]",
|
||||||
|
image.mime_type,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
RawContent::Resource(resource) => match &resource.resource {
|
RawContent::Resource(resource) => {
|
||||||
|
match &resource.resource {
|
||||||
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
||||||
parts.push(format!(
|
parts.push(format!(
|
||||||
"[resource text: {}]",
|
"[resource text: {}]",
|
||||||
@ -104,7 +108,8 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
|||||||
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
||||||
parts.push(format!("[resource blob: {}]", uri));
|
parts.push(format!("[resource blob: {}]", uri));
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
parts.push("[unsupported content]".to_string());
|
parts.push("[unsupported content]".to_string());
|
||||||
}
|
}
|
||||||
@ -220,8 +225,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
|||||||
cmd.env(k, v);
|
cmd.env(k, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
let service =
|
let service = ()
|
||||||
().serve(
|
.serve(
|
||||||
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
|
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -256,12 +261,12 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
|||||||
} else {
|
} else {
|
||||||
StreamableHttpClientTransport::from_config(
|
StreamableHttpClientTransport::from_config(
|
||||||
StreamableHttpClientTransportConfig::with_uri(url.to_string())
|
StreamableHttpClientTransportConfig::with_uri(url.to_string())
|
||||||
.custom_headers(headers_map),
|
.custom_headers(headers_map)
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let service =
|
let service = ()
|
||||||
().serve(transport)
|
.serve(transport)
|
||||||
.await
|
.await
|
||||||
.context("failed to connect to HTTP/SSE MCP server")?;
|
.context("failed to connect to HTTP/SSE MCP server")?;
|
||||||
|
|
||||||
|
|||||||
@ -102,11 +102,7 @@ mod tests {
|
|||||||
let dir = tempdir().unwrap();
|
let dir = tempdir().unwrap();
|
||||||
let db_path = dir.path().join("test.db");
|
let db_path = dir.path().join("test.db");
|
||||||
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||||
let mm = Arc::new(MemoryManager::new(
|
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
|
||||||
storage,
|
|
||||||
"default".into(),
|
|
||||||
"default".into(),
|
|
||||||
));
|
|
||||||
(mm, dir)
|
(mm, dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,7 +131,13 @@ mod tests {
|
|||||||
async fn test_upsert_overwrites() {
|
async fn test_upsert_overwrites() {
|
||||||
let (mm, _dir) = setup_memory_manager().await;
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
|
mm.store(
|
||||||
|
"dup_key",
|
||||||
|
"original",
|
||||||
|
MemoryCategory::Knowledge,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mm.store(
|
mm.store(
|
||||||
@ -245,12 +247,7 @@ mod tests {
|
|||||||
|
|
||||||
// Recall scoped to session A — should get only tl_a
|
// Recall scoped to session A — should get only tl_a
|
||||||
let scoped = mm
|
let scoped = mm
|
||||||
.recall(
|
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
|
||||||
"summary",
|
|
||||||
10,
|
|
||||||
Some(MemoryCategory::Timeline),
|
|
||||||
Some("chan:chat:dialog_a"),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(scoped.len(), 1);
|
assert_eq!(scoped.len(), 1);
|
||||||
|
|||||||
@ -20,7 +20,10 @@ pub enum ObserverEvent {
|
|||||||
success: bool,
|
success: bool,
|
||||||
},
|
},
|
||||||
/// Emitted when the agent starts processing.
|
/// Emitted when the agent starts processing.
|
||||||
AgentStart { provider: String, model: String },
|
AgentStart {
|
||||||
|
provider: String,
|
||||||
|
model: String,
|
||||||
|
},
|
||||||
/// Emitted when the agent finishes processing.
|
/// Emitted when the agent finishes processing.
|
||||||
AgentEnd {
|
AgentEnd {
|
||||||
provider: String,
|
provider: String,
|
||||||
@ -91,11 +94,7 @@ impl ToolExecutionOutcome {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a failed outcome with duration.
|
/// Create a failed outcome with duration.
|
||||||
pub fn failure_with_duration(
|
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
||||||
output: String,
|
|
||||||
error_reason: Option<String>,
|
|
||||||
duration: Duration,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
output,
|
output,
|
||||||
success: false,
|
success: false,
|
||||||
|
|||||||
@ -4,24 +4,23 @@ use serde::{Deserialize, Serialize};
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use super::traits::Usage;
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::storage::Storage;
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
|
use super::traits::Usage;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use crate::storage::Storage;
|
||||||
|
|
||||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||||
blocks
|
blocks.iter().map(|b| match b {
|
||||||
.iter()
|
|
||||||
.map(|b| match b {
|
|
||||||
ContentBlock::Text { text } => {
|
ContentBlock::Text { text } => {
|
||||||
serde_json::json!({ "type": "text", "text": text })
|
serde_json::json!({ "type": "text", "text": text })
|
||||||
}
|
}
|
||||||
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
ContentBlock::ImageUrl { image_url } => {
|
||||||
})
|
convert_image_url_to_anthropic(&image_url.url)
|
||||||
.collect()
|
}
|
||||||
|
}).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||||
@ -198,13 +197,8 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
};
|
};
|
||||||
let content = if let Some(ref tc_id) = m.tool_call_id {
|
let content = if let Some(ref tc_id) = m.tool_call_id {
|
||||||
// Tool result: wrap as tool_result content block
|
// Tool result: wrap as tool_result content block
|
||||||
let output = m
|
let output = m.content.iter()
|
||||||
.content
|
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
|
||||||
.iter()
|
|
||||||
.filter_map(|b| match b {
|
|
||||||
ContentBlock::Text { text } => Some(text.as_str()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("");
|
.join("");
|
||||||
vec![serde_json::json!({
|
vec![serde_json::json!({
|
||||||
@ -250,7 +244,8 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
let resp = req_builder.json(&body).send().await
|
||||||
|
.inspect_err(|e| {
|
||||||
let is_timeout = e.is_timeout();
|
let is_timeout = e.is_timeout();
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
provider = %self.name,
|
provider = %self.name,
|
||||||
@ -286,21 +281,17 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
"LLM API returned error"
|
"LLM API returned error"
|
||||||
);
|
);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let _ = storage
|
let _ = storage.append_llm_call(
|
||||||
.append_llm_call(
|
&self.name, &self.model_id, &req_body_str,
|
||||||
&self.name,
|
Some(&body_text), Some(&error_msg),
|
||||||
&self.model_id,
|
|
||||||
&req_body_str,
|
|
||||||
Some(&body_text),
|
|
||||||
Some(&error_msg),
|
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
)
|
).await;
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
|
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
||||||
|
.map_err(|e| {
|
||||||
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let name = self.name.clone();
|
let name = self.name.clone();
|
||||||
@ -311,9 +302,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
let err = err_msg.clone();
|
let err = err_msg.clone();
|
||||||
let s = storage.clone();
|
let s = storage.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let _ = s
|
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
||||||
.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
|
|
||||||
.await;
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
err_msg
|
err_msg
|
||||||
@ -354,35 +343,21 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
reasoning_content: reasoning,
|
reasoning_content: reasoning,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
usage: Usage {
|
usage: Usage {
|
||||||
prompt_tokens: anthropic_resp
|
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||||
.usage
|
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||||
.as_ref()
|
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
||||||
.map(|u| u.input_tokens)
|
|
||||||
.unwrap_or(0),
|
|
||||||
completion_tokens: anthropic_resp
|
|
||||||
.usage
|
|
||||||
.as_ref()
|
|
||||||
.map(|u| u.output_tokens)
|
|
||||||
.unwrap_or(0),
|
|
||||||
total_tokens: anthropic_resp
|
|
||||||
.usage
|
|
||||||
.as_ref()
|
|
||||||
.map(|u| u.input_tokens + u.output_tokens)
|
|
||||||
.unwrap_or(0),
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let _ = storage
|
let _ = storage.append_llm_call(
|
||||||
.append_llm_call(
|
|
||||||
&self.name,
|
&self.name,
|
||||||
&self.model_id,
|
&self.model_id,
|
||||||
&req_body_str,
|
&req_body_str,
|
||||||
Some(&body_text),
|
Some(&body_text),
|
||||||
None,
|
None,
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
)
|
).await;
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
|
|||||||
@ -1,15 +1,12 @@
|
|||||||
pub mod anthropic;
|
|
||||||
pub mod openai;
|
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
pub mod openai;
|
||||||
|
pub mod anthropic;
|
||||||
|
|
||||||
pub use self::anthropic::AnthropicProvider;
|
|
||||||
pub use self::openai::OpenAIProvider;
|
pub use self::openai::OpenAIProvider;
|
||||||
|
pub use self::anthropic::AnthropicProvider;
|
||||||
|
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
pub use traits::{
|
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
||||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
|
|
||||||
ToolFunction, Usage,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||||
match config.provider_type.as_str() {
|
match config.provider_type.as_str() {
|
||||||
|
|||||||
@ -1,35 +1,29 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use super::traits::Usage;
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::storage::Storage;
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
|
use super::traits::Usage;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use crate::storage::Storage;
|
||||||
|
|
||||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||||
if blocks.len() == 1
|
if blocks.len() == 1
|
||||||
&& let ContentBlock::Text { text } = &blocks[0]
|
&& let ContentBlock::Text { text } = &blocks[0] {
|
||||||
{
|
|
||||||
return Value::String(text.clone());
|
return Value::String(text.clone());
|
||||||
}
|
}
|
||||||
Value::Array(
|
Value::Array(blocks.iter().map(|b| match b {
|
||||||
blocks
|
|
||||||
.iter()
|
|
||||||
.map(|b| match b {
|
|
||||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||||
ContentBlock::ImageUrl { image_url } => {
|
ContentBlock::ImageUrl { image_url } => {
|
||||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||||
}
|
}
|
||||||
})
|
}).collect())
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
pub struct OpenAIProvider {
|
||||||
@ -207,11 +201,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||||
for (j, item) in content.iter().enumerate() {
|
for (j, item) in content.iter().enumerate() {
|
||||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
|
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
|
||||||
&& let Some(url_str) = item
|
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
||||||
.get("image_url")
|
|
||||||
.and_then(|u| u.get("url"))
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
{
|
|
||||||
let prefix: String = url_str.chars().take(20).collect();
|
let prefix: String = url_str.chars().take(20).collect();
|
||||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||||
}
|
}
|
||||||
@ -234,7 +224,8 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
let resp = req_builder.json(&body).send().await
|
||||||
|
.inspect_err(|e| {
|
||||||
let is_timeout = e.is_timeout();
|
let is_timeout = e.is_timeout();
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
provider = %self.name,
|
provider = %self.name,
|
||||||
@ -262,23 +253,18 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
"LLM API returned error"
|
"LLM API returned error"
|
||||||
);
|
);
|
||||||
if let Some(ref storage) = self.storage
|
if let Some(ref storage) = self.storage
|
||||||
&& let Err(e) = storage
|
&& let Err(e) = storage.append_llm_call(
|
||||||
.append_llm_call(
|
&self.name, &self.model_id, &req_body_str,
|
||||||
&self.name,
|
Some(&text), Some(&error),
|
||||||
&self.model_id,
|
|
||||||
&req_body_str,
|
|
||||||
Some(&text),
|
|
||||||
Some(&error),
|
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
)
|
).await {
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::warn!("failed to persist LLM call: {}", e);
|
tracing::warn!("failed to persist LLM call: {}", e);
|
||||||
}
|
}
|
||||||
return Err(error.into());
|
return Err(error.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
|
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
||||||
|
.map_err(|e| {
|
||||||
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let name = self.name.clone();
|
let name = self.name.clone();
|
||||||
@ -289,10 +275,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let err = err_msg.clone();
|
let err = err_msg.clone();
|
||||||
let s = storage.clone();
|
let s = storage.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = s
|
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
|
||||||
.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -300,10 +283,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
err_msg
|
err_msg
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let first_choice = openai_resp
|
let first_choice = openai_resp.choices.into_iter().next()
|
||||||
.choices
|
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.ok_or("no choices in response")?;
|
.ok_or("no choices in response")?;
|
||||||
|
|
||||||
let content = first_choice
|
let content = first_choice
|
||||||
@ -320,8 +300,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
.map(|tc| ToolCall {
|
.map(|tc| ToolCall {
|
||||||
id: tc.id.clone(),
|
id: tc.id.clone(),
|
||||||
name: tc.function.name.clone(),
|
name: tc.function.name.clone(),
|
||||||
arguments: serde_json::from_str(&tc.function.arguments)
|
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
||||||
.unwrap_or(serde_json::Value::Null),
|
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@ -339,17 +318,11 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ref storage) = self.storage
|
if let Some(ref storage) = self.storage
|
||||||
&& let Err(e) = storage
|
&& let Err(e) = storage.append_llm_call(
|
||||||
.append_llm_call(
|
&self.name, &self.model_id, &req_body_str,
|
||||||
&self.name,
|
Some(&text), None,
|
||||||
&self.model_id,
|
|
||||||
&req_body_str,
|
|
||||||
Some(&text),
|
|
||||||
None,
|
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
)
|
).await {
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::warn!("failed to persist LLM call: {}", e);
|
tracing::warn!("failed to persist LLM call: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -413,9 +386,6 @@ mod tests {
|
|||||||
assert_eq!(tool_calls[0]["id"], "call_1");
|
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||||
assert_eq!(tool_calls[0]["type"], "function");
|
assert_eq!(tool_calls[0]["type"], "function");
|
||||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||||
assert_eq!(
|
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||||
tool_calls[0]["function"]["arguments"],
|
|
||||||
"{\"expression\":\"1+1\"}"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use crate::bus::message::ContentBlock;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
@ -61,11 +61,7 @@ impl Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(
|
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
tool_call_id: impl Into<String>,
|
|
||||||
tool_name: impl Into<String>,
|
|
||||||
content: impl Into<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
|
|||||||
@ -5,11 +5,11 @@ use std::time::Instant;
|
|||||||
use tokio::time;
|
use tokio::time;
|
||||||
|
|
||||||
use crate::config::SchedulerConfig;
|
use crate::config::SchedulerConfig;
|
||||||
use crate::session::SessionManager;
|
|
||||||
use crate::session::session::HandleResult;
|
use crate::session::session::HandleResult;
|
||||||
use crate::storage::JobRun;
|
use crate::session::SessionManager;
|
||||||
use crate::storage::ScheduledJob;
|
use crate::storage::ScheduledJob;
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
use crate::storage::JobRun;
|
||||||
|
|
||||||
pub use types::Schedule;
|
pub use types::Schedule;
|
||||||
|
|
||||||
@ -89,11 +89,7 @@ impl Scheduler {
|
|||||||
|
|
||||||
let now = now_ms();
|
let now = now_ms();
|
||||||
|
|
||||||
let due = match self
|
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
|
||||||
.storage
|
|
||||||
.due_scheduled_jobs(now, self.config.max_concurrent)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(jobs) => jobs,
|
Ok(jobs) => jobs,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
||||||
@ -111,11 +107,7 @@ impl Scheduler {
|
|||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let started_at = now_ms();
|
let started_at = now_ms();
|
||||||
|
|
||||||
if let Err(e) = self
|
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
|
||||||
.storage
|
|
||||||
.touch_scheduled_job_last_run(&job.id, started_at)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -143,10 +135,7 @@ impl Scheduler {
|
|||||||
match result {
|
match result {
|
||||||
Ok(HandleResult::AgentResponse(output)) => {
|
Ok(HandleResult::AgentResponse(output)) => {
|
||||||
let output_truncated = if output.len() > 8000 {
|
let output_truncated = if output.len() > 8000 {
|
||||||
format!(
|
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
||||||
"{}...[truncated]",
|
|
||||||
&output[..output.ceil_char_boundary(8000)]
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
output.clone()
|
output.clone()
|
||||||
};
|
};
|
||||||
@ -166,11 +155,7 @@ impl Scheduler {
|
|||||||
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = self
|
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
|
||||||
.storage
|
|
||||||
.set_scheduled_job_last_status(&job.id, "ok", None)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,11 +199,9 @@ impl Scheduler {
|
|||||||
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e2) = self
|
if let Err(e2) = self.storage.set_scheduled_job_last_status(
|
||||||
.storage
|
&job.id, "error", Some(&error_str),
|
||||||
.set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
|
).await {
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,23 +231,17 @@ impl Scheduler {
|
|||||||
self.storage.remove_scheduled_job(&job.id).await?;
|
self.storage.remove_scheduled_job(&job.id).await?;
|
||||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
||||||
} else {
|
} else {
|
||||||
self.storage
|
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||||
.set_scheduled_job_enabled(&job.id, false)
|
|
||||||
.await?;
|
|
||||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
||||||
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
||||||
self.storage
|
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
|
||||||
.set_scheduled_job_next_run(&job.id, next)
|
|
||||||
.await?;
|
|
||||||
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
||||||
} else {
|
} else {
|
||||||
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
||||||
self.storage
|
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||||
.set_scheduled_job_enabled(&job.id, false)
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,20 +22,32 @@ pub enum SessionCommand {
|
|||||||
dialog_id: String,
|
dialog_id: String,
|
||||||
},
|
},
|
||||||
/// Get the current dialog for a chat
|
/// Get the current dialog for a chat
|
||||||
GetCurrentDialog { channel: String, chat_id: String },
|
GetCurrentDialog {
|
||||||
|
channel: String,
|
||||||
|
chat_id: String,
|
||||||
|
},
|
||||||
/// Rename a dialog
|
/// Rename a dialog
|
||||||
RenameDialog {
|
RenameDialog {
|
||||||
session_id: UnifiedSessionId,
|
session_id: UnifiedSessionId,
|
||||||
title: String,
|
title: String,
|
||||||
},
|
},
|
||||||
/// Archive a dialog
|
/// Archive a dialog
|
||||||
ArchiveDialog { session_id: UnifiedSessionId },
|
ArchiveDialog {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// Delete a dialog
|
/// Delete a dialog
|
||||||
DeleteDialog { session_id: UnifiedSessionId },
|
DeleteDialog {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// Clear dialog history
|
/// Clear dialog history
|
||||||
ClearHistory { session_id: UnifiedSessionId },
|
ClearHistory {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// Get list of available slash commands
|
/// Get list of available slash commands
|
||||||
GetSlashCommands { channel: String, chat_id: String },
|
GetSlashCommands {
|
||||||
|
channel: String,
|
||||||
|
chat_id: String,
|
||||||
|
},
|
||||||
/// Execute a slash command
|
/// Execute a slash command
|
||||||
ExecuteSlashCommand {
|
ExecuteSlashCommand {
|
||||||
command: String,
|
command: String,
|
||||||
@ -48,11 +60,7 @@ pub enum SessionCommand {
|
|||||||
|
|
||||||
impl SessionCommand {
|
impl SessionCommand {
|
||||||
/// Create a CreateDialog command
|
/// Create a CreateDialog command
|
||||||
pub fn create_dialog(
|
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self {
|
||||||
channel: impl Into<String>,
|
|
||||||
chat_id: impl Into<String>,
|
|
||||||
title: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self::CreateDialog {
|
Self::CreateDialog {
|
||||||
channel: channel.into(),
|
channel: channel.into(),
|
||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
@ -61,11 +69,7 @@ impl SessionCommand {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a ListDialogs command
|
/// Create a ListDialogs command
|
||||||
pub fn list_dialogs(
|
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self {
|
||||||
channel: impl Into<String>,
|
|
||||||
chat_id: impl Into<String>,
|
|
||||||
include_archived: bool,
|
|
||||||
) -> Self {
|
|
||||||
Self::ListDialogs {
|
Self::ListDialogs {
|
||||||
channel: channel.into(),
|
channel: channel.into(),
|
||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use super::session::SlashCommand;
|
|
||||||
use super::session_id::UnifiedSessionId;
|
use super::session_id::UnifiedSessionId;
|
||||||
|
use super::session::SlashCommand;
|
||||||
|
|
||||||
/// Dialog information returned by SessionManager
|
/// Dialog information returned by SessionManager
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -30,20 +30,30 @@ pub enum SessionEvent {
|
|||||||
session_id: Option<UnifiedSessionId>,
|
session_id: Option<UnifiedSessionId>,
|
||||||
},
|
},
|
||||||
/// Dialog switched successfully
|
/// Dialog switched successfully
|
||||||
DialogSwitched { session_id: UnifiedSessionId },
|
DialogSwitched {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// Dialog renamed
|
/// Dialog renamed
|
||||||
DialogRenamed {
|
DialogRenamed {
|
||||||
session_id: UnifiedSessionId,
|
session_id: UnifiedSessionId,
|
||||||
title: String,
|
title: String,
|
||||||
},
|
},
|
||||||
/// Dialog archived
|
/// Dialog archived
|
||||||
DialogArchived { session_id: UnifiedSessionId },
|
DialogArchived {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// Dialog deleted
|
/// Dialog deleted
|
||||||
DialogDeleted { session_id: UnifiedSessionId },
|
DialogDeleted {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// Dialog history cleared
|
/// Dialog history cleared
|
||||||
HistoryCleared { session_id: UnifiedSessionId },
|
HistoryCleared {
|
||||||
|
session_id: UnifiedSessionId,
|
||||||
|
},
|
||||||
/// List of available slash commands
|
/// List of available slash commands
|
||||||
SlashCommandsList { commands: Vec<SlashCommand> },
|
SlashCommandsList {
|
||||||
|
commands: Vec<SlashCommand>,
|
||||||
|
},
|
||||||
/// Slash command executed successfully
|
/// Slash command executed successfully
|
||||||
SlashCommandExecuted {
|
SlashCommandExecuted {
|
||||||
new_session_id: Option<UnifiedSessionId>,
|
new_session_id: Option<UnifiedSessionId>,
|
||||||
@ -60,5 +70,8 @@ pub enum SessionEvent {
|
|||||||
message_count: usize,
|
message_count: usize,
|
||||||
},
|
},
|
||||||
/// Error occurred
|
/// Error occurred
|
||||||
Error { code: String, message: String },
|
Error {
|
||||||
|
code: String,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
pub mod commands;
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod commands;
|
||||||
pub mod events;
|
pub mod events;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod session_id;
|
pub mod session_id;
|
||||||
|
|
||||||
pub use commands::SessionCommand;
|
|
||||||
pub use error::SessionError;
|
pub use error::SessionError;
|
||||||
pub use events::{DialogInfo, SessionEvent};
|
pub use commands::SessionCommand;
|
||||||
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
|
pub use events::{SessionEvent, DialogInfo};
|
||||||
|
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS};
|
||||||
pub use session_id::UnifiedSessionId;
|
pub use session_id::UnifiedSessionId;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -8,6 +8,7 @@
|
|||||||
///
|
///
|
||||||
/// For simple cases where only one dialog exists per chat:
|
/// For simple cases where only one dialog exists per chat:
|
||||||
/// - `dialog_id` defaults to `"default"`
|
/// - `dialog_id` defaults to `"default"`
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub const DEFAULT_DIALOG_ID: &str = "default";
|
pub const DEFAULT_DIALOG_ID: &str = "default";
|
||||||
@ -21,11 +22,7 @@ pub struct UnifiedSessionId {
|
|||||||
|
|
||||||
impl UnifiedSessionId {
|
impl UnifiedSessionId {
|
||||||
/// Create a new UnifiedSessionId
|
/// Create a new UnifiedSessionId
|
||||||
pub fn new(
|
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self {
|
||||||
channel: impl Into<String>,
|
|
||||||
chat_id: impl Into<String>,
|
|
||||||
dialog_id: impl Into<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
channel: channel.into(),
|
channel: channel.into(),
|
||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
|
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS};
|
||||||
|
|
||||||
pub fn install_builtin_skills(target_dir: &Path) {
|
pub fn install_builtin_skills(target_dir: &Path) {
|
||||||
for skill in EMBEDDED_SKILLS {
|
for skill in EMBEDDED_SKILLS {
|
||||||
@ -22,7 +22,8 @@ pub fn install_builtin_skills(target_dir: &Path) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
|
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
|
||||||
let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
|
let decompressed = zstd::decode_all(skill.data)
|
||||||
|
.map_err(|e| format!("zstd decode: {}", e))?;
|
||||||
|
|
||||||
let mut archive = tar::Archive::new(decompressed.as_slice());
|
let mut archive = tar::Archive::new(decompressed.as_slice());
|
||||||
archive
|
archive
|
||||||
|
|||||||
@ -120,11 +120,7 @@ impl SkillsLoader {
|
|||||||
let count = loaded.len();
|
let count = loaded.len();
|
||||||
let mut replaced = 0usize;
|
let mut replaced = 0usize;
|
||||||
for skill in loaded {
|
for skill in loaded {
|
||||||
if let Some(existing) = state
|
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
||||||
.loaded_skills
|
|
||||||
.iter_mut()
|
|
||||||
.find(|s| s.name == skill.name)
|
|
||||||
{
|
|
||||||
*existing = skill;
|
*existing = skill;
|
||||||
replaced += 1;
|
replaced += 1;
|
||||||
} else {
|
} else {
|
||||||
@ -142,17 +138,12 @@ impl SkillsLoader {
|
|||||||
|
|
||||||
// Load from workspace skills dir (highest priority) — replace same-name skills
|
// Load from workspace skills dir (highest priority) — replace same-name skills
|
||||||
if let Some(ref ws_dir) = self.workspace_skills_dir
|
if let Some(ref ws_dir) = self.workspace_skills_dir
|
||||||
&& ws_dir.exists()
|
&& ws_dir.exists() {
|
||||||
{
|
|
||||||
let loaded = self.load_skills_from_dir(ws_dir);
|
let loaded = self.load_skills_from_dir(ws_dir);
|
||||||
let count = loaded.len();
|
let count = loaded.len();
|
||||||
let mut replaced = 0usize;
|
let mut replaced = 0usize;
|
||||||
for skill in loaded {
|
for skill in loaded {
|
||||||
if let Some(existing) = state
|
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
||||||
.loaded_skills
|
|
||||||
.iter_mut()
|
|
||||||
.find(|s| s.name == skill.name)
|
|
||||||
{
|
|
||||||
*existing = skill;
|
*existing = skill;
|
||||||
replaced += 1;
|
replaced += 1;
|
||||||
} else {
|
} else {
|
||||||
@ -173,11 +164,7 @@ impl SkillsLoader {
|
|||||||
if state.loaded_skills.is_empty() {
|
if state.loaded_skills.is_empty() {
|
||||||
tracing::debug!("No skills found in any skills directory");
|
tracing::debug!("No skills found in any skills directory");
|
||||||
} else {
|
} else {
|
||||||
tracing::info!(
|
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
|
||||||
count = state.loaded_skills.len(),
|
|
||||||
"Loaded {} skills total",
|
|
||||||
state.loaded_skills.len()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,8 +215,7 @@ impl SkillsLoader {
|
|||||||
let mut max_mtime = None;
|
let mut max_mtime = None;
|
||||||
|
|
||||||
if let Ok(metadata) = std::fs::metadata(dir)
|
if let Ok(metadata) = std::fs::metadata(dir)
|
||||||
&& let Ok(mtime) = metadata.modified()
|
&& let Ok(mtime) = metadata.modified() {
|
||||||
{
|
|
||||||
max_mtime = Some(mtime);
|
max_mtime = Some(mtime);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,8 +224,7 @@ impl SkillsLoader {
|
|||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
if let Ok(metadata) = std::fs::metadata(&path)
|
if let Ok(metadata) = std::fs::metadata(&path)
|
||||||
&& let Ok(mtime) = metadata.modified()
|
&& let Ok(mtime) = metadata.modified()
|
||||||
&& max_mtime.is_none_or(|current| mtime > current)
|
&& max_mtime.is_none_or(|current| mtime > current) {
|
||||||
{
|
|
||||||
max_mtime = Some(mtime);
|
max_mtime = Some(mtime);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -259,12 +244,7 @@ impl SkillsLoader {
|
|||||||
pub fn get_always_skills(&self) -> Vec<Skill> {
|
pub fn get_always_skills(&self) -> Vec<Skill> {
|
||||||
self.reload_if_changed();
|
self.reload_if_changed();
|
||||||
let state = self.state.lock().unwrap();
|
let state = self.state.lock().unwrap();
|
||||||
state
|
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
|
||||||
.loaded_skills
|
|
||||||
.iter()
|
|
||||||
.filter(|s| s.always)
|
|
||||||
.cloned()
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a specific skill by name (checks for changes first)
|
/// Get a specific skill by name (checks for changes first)
|
||||||
@ -278,8 +258,7 @@ impl SkillsLoader {
|
|||||||
pub fn list_skills(&self) -> Vec<(String, String)> {
|
pub fn list_skills(&self) -> Vec<(String, String)> {
|
||||||
self.reload_if_changed();
|
self.reload_if_changed();
|
||||||
let state = self.state.lock().unwrap();
|
let state = self.state.lock().unwrap();
|
||||||
state
|
state.loaded_skills
|
||||||
.loaded_skills
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| (s.name.clone(), s.description.clone()))
|
.map(|s| (s.name.clone(), s.description.clone()))
|
||||||
.collect()
|
.collect()
|
||||||
@ -300,21 +279,15 @@ impl SkillsLoader {
|
|||||||
prompt.push_str("### 目录说明\n\n");
|
prompt.push_str("### 目录说明\n\n");
|
||||||
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n");
|
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n");
|
||||||
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
|
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
|
||||||
prompt.push_str(
|
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n");
|
||||||
"- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n",
|
prompt.push_str("安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n");
|
||||||
);
|
|
||||||
prompt.push_str(
|
|
||||||
"安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n",
|
|
||||||
);
|
|
||||||
|
|
||||||
// Always skills summary
|
// Always skills summary
|
||||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
||||||
if !always_skills.is_empty() {
|
if !always_skills.is_empty() {
|
||||||
prompt.push_str("### 常用技能\n\n");
|
prompt.push_str("### 常用技能\n\n");
|
||||||
for skill in &always_skills {
|
for skill in &always_skills {
|
||||||
let path_str = skill
|
let path_str = skill.path.as_ref()
|
||||||
.path
|
|
||||||
.as_ref()
|
|
||||||
.map(|p| p.to_string_lossy().to_string())
|
.map(|p| p.to_string_lossy().to_string())
|
||||||
.unwrap_or_else(|| "—".to_string());
|
.unwrap_or_else(|| "—".to_string());
|
||||||
prompt.push_str(&format!(
|
prompt.push_str(&format!(
|
||||||
@ -327,12 +300,8 @@ impl SkillsLoader {
|
|||||||
|
|
||||||
// Usage instructions
|
// Usage instructions
|
||||||
prompt.push_str("### 使用方法\n\n");
|
prompt.push_str("### 使用方法\n\n");
|
||||||
prompt.push_str(
|
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n");
|
||||||
"- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
|
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n");
|
||||||
);
|
|
||||||
prompt.push_str(
|
|
||||||
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
|
|
||||||
);
|
|
||||||
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
|
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
|
||||||
|
|
||||||
// Always skills full content
|
// Always skills full content
|
||||||
@ -369,7 +338,8 @@ impl SkillsLoader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
match std::fs::read_to_string(&skill_file) {
|
match std::fs::read_to_string(&skill_file) {
|
||||||
Ok(content) => match self.parse_skill(&path, &content) {
|
Ok(content) => {
|
||||||
|
match self.parse_skill(&path, &content) {
|
||||||
Some(skill) => {
|
Some(skill) => {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
skill = %skill.name,
|
skill = %skill.name,
|
||||||
@ -385,7 +355,8 @@ impl SkillsLoader {
|
|||||||
"Failed to parse skill"
|
"Failed to parse skill"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
path = %skill_file.display(),
|
path = %skill_file.display(),
|
||||||
@ -476,6 +447,7 @@ impl Default for SkillsLoader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Extract first non-empty, non-heading line as description
|
/// Extract first non-empty, non-heading line as description
|
||||||
fn extract_description(content: &str) -> String {
|
fn extract_description(content: &str) -> String {
|
||||||
content
|
content
|
||||||
|
|||||||
@ -1,19 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct BackgroundTask {
|
|
||||||
pub id: String,
|
|
||||||
pub session_id: String,
|
|
||||||
pub channel: String,
|
|
||||||
pub chat_id: String,
|
|
||||||
pub prompt: String,
|
|
||||||
pub allowed_tools: Option<String>,
|
|
||||||
pub status: String,
|
|
||||||
pub result: Option<String>,
|
|
||||||
pub error: Option<String>,
|
|
||||||
pub tool_calls_count: i64,
|
|
||||||
pub iterations: i64,
|
|
||||||
pub started_at: Option<i64>,
|
|
||||||
pub finished_at: Option<i64>,
|
|
||||||
pub created_at: i64,
|
|
||||||
}
|
|
||||||
@ -241,8 +241,9 @@ impl super::Storage {
|
|||||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
||||||
let cutoff_str = cutoff.to_rfc3339();
|
let cutoff_str = cutoff.to_rfc3339();
|
||||||
|
|
||||||
let result =
|
let result = sqlx::query(
|
||||||
sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
|
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
|
||||||
|
)
|
||||||
.bind(&cutoff_str)
|
.bind(&cutoff_str)
|
||||||
.execute(self.pool())
|
.execute(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
@ -275,7 +276,9 @@ impl super::Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
|
fn parse_memory_rows(
|
||||||
|
rows: &[sqlx::sqlite::SqliteRow],
|
||||||
|
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
rows.iter()
|
rows.iter()
|
||||||
.map(|row| {
|
.map(|row| {
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
|
|||||||
@ -1,17 +1,15 @@
|
|||||||
pub mod background_task;
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod message;
|
pub mod message;
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
|
||||||
pub use background_task::BackgroundTask;
|
|
||||||
pub use error::StorageError;
|
pub use error::StorageError;
|
||||||
pub use scheduler::{JobRun, ScheduledJob};
|
pub use scheduler::{JobRun, ScheduledJob};
|
||||||
|
|
||||||
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
use sqlx::{Pool, Row, Sqlite, SqlitePool};
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use tokio::time::{Duration, sleep};
|
|
||||||
|
|
||||||
pub struct Storage {
|
pub struct Storage {
|
||||||
pub(crate) pool: Pool<Sqlite>,
|
pub(crate) pool: Pool<Sqlite>,
|
||||||
@ -42,7 +40,6 @@ impl Storage {
|
|||||||
last_active_at INTEGER NOT NULL,
|
last_active_at INTEGER NOT NULL,
|
||||||
message_count INTEGER DEFAULT 0,
|
message_count INTEGER DEFAULT 0,
|
||||||
routing_info TEXT,
|
routing_info TEXT,
|
||||||
archived_at INTEGER,
|
|
||||||
deleted_at INTEGER,
|
deleted_at INTEGER,
|
||||||
last_consolidated_at INTEGER,
|
last_consolidated_at INTEGER,
|
||||||
last_compressed_message_at INTEGER,
|
last_compressed_message_at INTEGER,
|
||||||
@ -93,59 +90,21 @@ impl Storage {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Migration: add source column if upgrading from older schema
|
// Migration: add source column if upgrading from older schema
|
||||||
sqlx::query(r#"ALTER TABLE messages ADD COLUMN source TEXT"#)
|
sqlx::query(
|
||||||
|
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
|
||||||
|
)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
// Migration: add reasoning_content column if upgrading from older schema
|
// Migration: add reasoning_content column if upgrading from older schema
|
||||||
sqlx::query(r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#)
|
sqlx::query(
|
||||||
|
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#,
|
||||||
|
)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
// Background tasks table — for async sub-agent tasks.
|
|
||||||
// Note: No FOREIGN KEY on session_id because sessions use soft delete (deleted_at IS NULL).
|
|
||||||
// Session and task association is maintained at the application level.
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
CREATE TABLE IF NOT EXISTS background_tasks (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
channel TEXT NOT NULL,
|
|
||||||
chat_id TEXT NOT NULL,
|
|
||||||
prompt TEXT NOT NULL,
|
|
||||||
allowed_tools TEXT,
|
|
||||||
status TEXT NOT NULL DEFAULT 'pending',
|
|
||||||
result TEXT,
|
|
||||||
error TEXT,
|
|
||||||
tool_calls_count INTEGER DEFAULT 0,
|
|
||||||
iterations INTEGER DEFAULT 0,
|
|
||||||
started_at INTEGER,
|
|
||||||
finished_at INTEGER,
|
|
||||||
created_at INTEGER NOT NULL
|
|
||||||
)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_session ON background_tasks(session_id)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_bg_tasks_status ON background_tasks(status)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE IF NOT EXISTS memories (
|
CREATE TABLE IF NOT EXISTS memories (
|
||||||
@ -213,19 +172,11 @@ impl Storage {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Rebuild FTS5 index for any existing records
|
// Rebuild FTS5 index for any existing records
|
||||||
sqlx::query("INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')")
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Migration: add last_consolidated_at column if not exists
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
"INSERT INTO memory_fts(memory_fts) VALUES ('rebuild')",
|
||||||
ALTER TABLE sessions ADD COLUMN archived_at INTEGER
|
|
||||||
"#,
|
|
||||||
)
|
)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await
|
.await?;
|
||||||
.ok();
|
|
||||||
|
|
||||||
// Migration: add last_consolidated_at column if not exists
|
// Migration: add last_consolidated_at column if not exists
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
@ -265,10 +216,7 @@ impl Storage {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
|
if let Err(e) = Self::init_scheduler_schema(&self.pool).await {
|
||||||
tracing::warn!(
|
tracing::warn!("Failed to init scheduler schema (tables may already exist): {}", e);
|
||||||
"Failed to init scheduler schema (tables may already exist): {}",
|
|
||||||
e
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -382,20 +330,16 @@ impl Storage {
|
|||||||
&self.pool
|
&self.pool
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn upsert_session(
|
pub async fn upsert_session(&self, meta: &crate::storage::session::SessionMeta) -> Result<(), StorageError> {
|
||||||
&self,
|
|
||||||
meta: &crate::storage::session::SessionMeta,
|
|
||||||
) -> Result<(), StorageError> {
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at)
|
INSERT INTO sessions (id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT(id) DO UPDATE SET
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
title = excluded.title,
|
title = excluded.title,
|
||||||
last_active_at = excluded.last_active_at,
|
last_active_at = excluded.last_active_at,
|
||||||
message_count = excluded.message_count,
|
message_count = excluded.message_count,
|
||||||
routing_info = excluded.routing_info,
|
routing_info = excluded.routing_info,
|
||||||
archived_at = excluded.archived_at,
|
|
||||||
deleted_at = excluded.deleted_at,
|
deleted_at = excluded.deleted_at,
|
||||||
last_consolidated_at = excluded.last_consolidated_at,
|
last_consolidated_at = excluded.last_consolidated_at,
|
||||||
last_compressed_message_at = excluded.last_compressed_message_at
|
last_compressed_message_at = excluded.last_compressed_message_at
|
||||||
@ -410,7 +354,6 @@ impl Storage {
|
|||||||
.bind(meta.last_active_at)
|
.bind(meta.last_active_at)
|
||||||
.bind(meta.message_count)
|
.bind(meta.message_count)
|
||||||
.bind(&meta.routing_info)
|
.bind(&meta.routing_info)
|
||||||
.bind(meta.archived_at)
|
|
||||||
.bind(meta.deleted_at)
|
.bind(meta.deleted_at)
|
||||||
.bind(meta.last_consolidated_at)
|
.bind(meta.last_consolidated_at)
|
||||||
.bind(meta.last_compressed_message_at)
|
.bind(meta.last_compressed_message_at)
|
||||||
@ -420,13 +363,10 @@ impl Storage {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_session(
|
pub async fn get_session(&self, id: &str) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
||||||
&self,
|
|
||||||
id: &str,
|
|
||||||
) -> Result<crate::storage::session::SessionMeta, StorageError> {
|
|
||||||
let row = sqlx::query(
|
let row = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
FROM sessions WHERE id = ? AND deleted_at IS NULL
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
@ -445,7 +385,6 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
archived_at: row.get("archived_at"),
|
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -457,21 +396,18 @@ impl Storage {
|
|||||||
channel: &str,
|
channel: &str,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
limit: i64,
|
limit: i64,
|
||||||
include_archived: bool,
|
|
||||||
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
) -> Result<Vec<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
let rows = sqlx::query(
|
let rows = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||||
AND (? OR archived_at IS NULL)
|
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(channel)
|
.bind(channel)
|
||||||
.bind(chat_id)
|
.bind(chat_id)
|
||||||
.bind(include_archived)
|
|
||||||
.bind(limit)
|
.bind(limit)
|
||||||
.fetch_all(self.pool())
|
.fetch_all(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
@ -488,7 +424,6 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
archived_at: row.get("archived_at"),
|
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -519,18 +454,9 @@ impl Storage {
|
|||||||
|
|
||||||
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
pub async fn soft_delete_session(&self, id: &str) -> Result<(), StorageError> {
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
sqlx::query(r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#)
|
sqlx::query(
|
||||||
.bind(now)
|
r#"UPDATE sessions SET deleted_at = ? WHERE id = ?"#,
|
||||||
.bind(id)
|
)
|
||||||
.execute(self.pool())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn archive_session(&self, id: &str) -> Result<(), StorageError> {
|
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
|
||||||
sqlx::query(r#"UPDATE sessions SET archived_at = ? WHERE id = ? AND deleted_at IS NULL"#)
|
|
||||||
.bind(now)
|
.bind(now)
|
||||||
.bind(id)
|
.bind(id)
|
||||||
.execute(self.pool())
|
.execute(self.pool())
|
||||||
@ -546,9 +472,9 @@ impl Storage {
|
|||||||
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
|
) -> Result<Option<crate::storage::session::SessionMeta>, StorageError> {
|
||||||
let row = sqlx::query(
|
let row = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL AND archived_at IS NULL
|
WHERE channel = ? AND chat_id = ? AND deleted_at IS NULL
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"#,
|
"#,
|
||||||
@ -569,7 +495,6 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
archived_at: row.get("archived_at"),
|
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -578,11 +503,7 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn append_message(
|
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
msg: &crate::storage::message::MessageMeta,
|
|
||||||
) -> Result<i64, StorageError> {
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
||||||
@ -709,15 +630,16 @@ impl Storage {
|
|||||||
offset: i64,
|
offset: i64,
|
||||||
limit: i64,
|
limit: i64,
|
||||||
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
|
) -> Result<(Vec<crate::storage::session::SessionMeta>, i64), StorageError> {
|
||||||
let count_row =
|
let count_row = sqlx::query(
|
||||||
sqlx::query("SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL")
|
"SELECT COUNT(*) as total FROM sessions WHERE deleted_at IS NULL",
|
||||||
|
)
|
||||||
.fetch_one(self.pool())
|
.fetch_one(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
let total: i64 = count_row.get("total");
|
let total: i64 = count_row.get("total");
|
||||||
|
|
||||||
let rows = sqlx::query(
|
let rows = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, archived_at, deleted_at, last_consolidated_at, last_compressed_message_at
|
SELECT id, channel, chat_id, dialog_id, title, created_at, last_active_at, message_count, routing_info, deleted_at, last_consolidated_at, last_compressed_message_at
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
ORDER BY last_active_at DESC
|
ORDER BY last_active_at DESC
|
||||||
@ -741,7 +663,6 @@ impl Storage {
|
|||||||
last_active_at: row.get("last_active_at"),
|
last_active_at: row.get("last_active_at"),
|
||||||
message_count: row.get("message_count"),
|
message_count: row.get("message_count"),
|
||||||
routing_info: row.get("routing_info"),
|
routing_info: row.get("routing_info"),
|
||||||
archived_at: row.get("archived_at"),
|
|
||||||
deleted_at: row.get("deleted_at"),
|
deleted_at: row.get("deleted_at"),
|
||||||
last_consolidated_at: row.get("last_consolidated_at"),
|
last_consolidated_at: row.get("last_consolidated_at"),
|
||||||
last_compressed_message_at: row.get("last_compressed_message_at"),
|
last_compressed_message_at: row.get("last_compressed_message_at"),
|
||||||
@ -807,10 +728,7 @@ impl Storage {
|
|||||||
where_extra.push_str(" AND created_at > ?");
|
where_extra.push_str(" AND created_at > ?");
|
||||||
}
|
}
|
||||||
|
|
||||||
let count_sql = format!(
|
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
|
||||||
"SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}",
|
|
||||||
where_extra
|
|
||||||
);
|
|
||||||
let select_sql = format!(
|
let select_sql = format!(
|
||||||
r#"
|
r#"
|
||||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||||
@ -898,148 +816,6 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Background Task CRUD ──
|
|
||||||
|
|
||||||
pub async fn create_background_task(
|
|
||||||
&self,
|
|
||||||
task: &crate::storage::background_task::BackgroundTask,
|
|
||||||
) -> Result<(), StorageError> {
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
INSERT INTO background_tasks (id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error, tool_calls_count, iterations, started_at, finished_at, created_at)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&task.id)
|
|
||||||
.bind(&task.session_id)
|
|
||||||
.bind(&task.channel)
|
|
||||||
.bind(&task.chat_id)
|
|
||||||
.bind(&task.prompt)
|
|
||||||
.bind(&task.allowed_tools)
|
|
||||||
.bind(&task.status)
|
|
||||||
.bind(&task.result)
|
|
||||||
.bind(&task.error)
|
|
||||||
.bind(task.tool_calls_count)
|
|
||||||
.bind(task.iterations)
|
|
||||||
.bind(task.started_at)
|
|
||||||
.bind(task.finished_at)
|
|
||||||
.bind(task.created_at)
|
|
||||||
.execute(self.pool())
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_background_task_status(
|
|
||||||
&self,
|
|
||||||
id: &str,
|
|
||||||
status: &str,
|
|
||||||
result: Option<&str>,
|
|
||||||
error: Option<&str>,
|
|
||||||
started_at: Option<i64>,
|
|
||||||
finished_at: Option<i64>,
|
|
||||||
) -> Result<(), StorageError> {
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
UPDATE background_tasks
|
|
||||||
SET status = ?, result = COALESCE(?, result), error = COALESCE(?, error),
|
|
||||||
started_at = COALESCE(?, started_at), finished_at = COALESCE(?, finished_at)
|
|
||||||
WHERE id = ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(status)
|
|
||||||
.bind(result)
|
|
||||||
.bind(error)
|
|
||||||
.bind(started_at)
|
|
||||||
.bind(finished_at)
|
|
||||||
.bind(id)
|
|
||||||
.execute(self.pool())
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_background_task(
|
|
||||||
&self,
|
|
||||||
id: &str,
|
|
||||||
) -> Result<crate::storage::background_task::BackgroundTask, StorageError> {
|
|
||||||
let row = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
|
|
||||||
tool_calls_count, iterations, started_at, finished_at, created_at
|
|
||||||
FROM background_tasks WHERE id = ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(id)
|
|
||||||
.fetch_optional(self.pool())
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| StorageError::NotFound(id.to_string()))?;
|
|
||||||
|
|
||||||
Ok(crate::storage::background_task::BackgroundTask {
|
|
||||||
id: row.get("id"),
|
|
||||||
session_id: row.get("session_id"),
|
|
||||||
channel: row.get("channel"),
|
|
||||||
chat_id: row.get("chat_id"),
|
|
||||||
prompt: row.get("prompt"),
|
|
||||||
allowed_tools: row.get("allowed_tools"),
|
|
||||||
status: row.get("status"),
|
|
||||||
result: row.get("result"),
|
|
||||||
error: row.get("error"),
|
|
||||||
tool_calls_count: row.get("tool_calls_count"),
|
|
||||||
iterations: row.get("iterations"),
|
|
||||||
started_at: row.get("started_at"),
|
|
||||||
finished_at: row.get("finished_at"),
|
|
||||||
created_at: row.get("created_at"),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn list_background_tasks(
|
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<Vec<crate::storage::background_task::BackgroundTask>, StorageError> {
|
|
||||||
let rows = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT id, session_id, channel, chat_id, prompt, allowed_tools, status, result, error,
|
|
||||||
tool_calls_count, iterations, started_at, finished_at, created_at
|
|
||||||
FROM background_tasks
|
|
||||||
WHERE session_id = ?
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(session_id)
|
|
||||||
.fetch_all(self.pool())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(rows
|
|
||||||
.into_iter()
|
|
||||||
.map(|row| crate::storage::background_task::BackgroundTask {
|
|
||||||
id: row.get("id"),
|
|
||||||
session_id: row.get("session_id"),
|
|
||||||
channel: row.get("channel"),
|
|
||||||
chat_id: row.get("chat_id"),
|
|
||||||
prompt: row.get("prompt"),
|
|
||||||
allowed_tools: row.get("allowed_tools"),
|
|
||||||
status: row.get("status"),
|
|
||||||
result: row.get("result"),
|
|
||||||
error: row.get("error"),
|
|
||||||
tool_calls_count: row.get("tool_calls_count"),
|
|
||||||
iterations: row.get("iterations"),
|
|
||||||
started_at: row.get("started_at"),
|
|
||||||
finished_at: row.get("finished_at"),
|
|
||||||
created_at: row.get("created_at"),
|
|
||||||
})
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn cleanup_old_tasks(&self, ttl_ms: i64) -> Result<usize, StorageError> {
|
|
||||||
let cutoff = chrono::Utc::now().timestamp_millis() - ttl_ms;
|
|
||||||
let result = sqlx::query(
|
|
||||||
"DELETE FROM background_tasks WHERE status IN ('completed', 'failed', 'cancelled') AND finished_at IS NOT NULL AND finished_at < ?",
|
|
||||||
)
|
|
||||||
.bind(cutoff)
|
|
||||||
.execute(self.pool())
|
|
||||||
.await?;
|
|
||||||
Ok(result.rows_affected() as usize)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -1068,7 +844,6 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
routing_info: Some(r#"{"type":"cli"}"#.to_string()),
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1105,7 +880,6 @@ mod tests {
|
|||||||
last_active_at: i as i64 * 1000,
|
last_active_at: i as i64 * 1000,
|
||||||
message_count: i,
|
message_count: i,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1113,10 +887,7 @@ mod tests {
|
|||||||
storage.upsert_session(&meta).await.unwrap();
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let sessions = storage
|
let sessions = storage.list_sessions("cli_chat", "sid123", 10).await.unwrap();
|
||||||
.list_sessions("cli_chat", "sid123", 10, false)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(sessions.len(), 5);
|
assert_eq!(sessions.len(), 5);
|
||||||
// 按 last_active_at DESC 排序
|
// 按 last_active_at DESC 排序
|
||||||
assert_eq!(sessions[0].dialog_id, "dialog4");
|
assert_eq!(sessions[0].dialog_id, "dialog4");
|
||||||
@ -1136,7 +907,6 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1164,7 +934,6 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -1186,10 +955,7 @@ mod tests {
|
|||||||
created_at: 1000,
|
created_at: 1000,
|
||||||
};
|
};
|
||||||
|
|
||||||
let seq = storage
|
let seq = storage.append_message(&session_meta.id, &msg).await.unwrap();
|
||||||
.append_message(&session_meta.id, &msg)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(seq, 1);
|
assert_eq!(seq, 1);
|
||||||
|
|
||||||
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
let loaded = storage.load_messages(&session_meta.id, 0).await.unwrap();
|
||||||
@ -1211,7 +977,6 @@ mod tests {
|
|||||||
last_active_at: 1000,
|
last_active_at: 1000,
|
||||||
message_count: 0,
|
message_count: 0,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
|
|||||||
@ -165,11 +165,7 @@ impl crate::storage::Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Update next_run_at and last_run_at for a job.
|
/// Update next_run_at and last_run_at for a job.
|
||||||
pub async fn set_scheduled_job_next_run(
|
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
|
||||||
&self,
|
|
||||||
id: &str,
|
|
||||||
next_run_at: i64,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let now = now_ms();
|
let now = now_ms();
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
||||||
@ -335,9 +331,7 @@ mod tests {
|
|||||||
async fn setup_storage() -> Storage {
|
async fn setup_storage() -> Storage {
|
||||||
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||||
let storage = Storage { pool };
|
let storage = Storage { pool };
|
||||||
Storage::init_scheduler_schema(storage.pool())
|
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
storage
|
storage
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -456,10 +450,7 @@ mod tests {
|
|||||||
updated_at: t,
|
updated_at: t,
|
||||||
};
|
};
|
||||||
storage.add_scheduled_job(&job).await.unwrap();
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
storage
|
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
|
||||||
.set_scheduled_job_enabled("job-toggle", false)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
||||||
assert!(!got.enabled);
|
assert!(!got.enabled);
|
||||||
}
|
}
|
||||||
@ -470,55 +461,31 @@ mod tests {
|
|||||||
let t = now();
|
let t = now();
|
||||||
let jobs = vec![
|
let jobs = vec![
|
||||||
ScheduledJob {
|
ScheduledJob {
|
||||||
id: "due".into(),
|
id: "due".into(), name: "due".into(),
|
||||||
name: "due".into(),
|
schedule: Schedule::At { at: t }, prompt: "1".into(),
|
||||||
schedule: Schedule::At { at: t },
|
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
prompt: "1".into(),
|
model: None, enabled: true, delete_after_run: false,
|
||||||
channel: "cli_chat".into(),
|
next_run_at: t - 1000, last_run_at: None,
|
||||||
chat_id: "c".into(),
|
last_status: None, last_error: None,
|
||||||
model: None,
|
created_at: t, updated_at: t,
|
||||||
enabled: true,
|
|
||||||
delete_after_run: false,
|
|
||||||
next_run_at: t - 1000,
|
|
||||||
last_run_at: None,
|
|
||||||
last_status: None,
|
|
||||||
last_error: None,
|
|
||||||
created_at: t,
|
|
||||||
updated_at: t,
|
|
||||||
},
|
},
|
||||||
ScheduledJob {
|
ScheduledJob {
|
||||||
id: "future".into(),
|
id: "future".into(), name: "future".into(),
|
||||||
name: "future".into(),
|
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
|
||||||
schedule: Schedule::At { at: t + 99999999 },
|
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
prompt: "2".into(),
|
model: None, enabled: true, delete_after_run: false,
|
||||||
channel: "cli_chat".into(),
|
next_run_at: t + 99999999, last_run_at: None,
|
||||||
chat_id: "c".into(),
|
last_status: None, last_error: None,
|
||||||
model: None,
|
created_at: t, updated_at: t,
|
||||||
enabled: true,
|
|
||||||
delete_after_run: false,
|
|
||||||
next_run_at: t + 99999999,
|
|
||||||
last_run_at: None,
|
|
||||||
last_status: None,
|
|
||||||
last_error: None,
|
|
||||||
created_at: t,
|
|
||||||
updated_at: t,
|
|
||||||
},
|
},
|
||||||
ScheduledJob {
|
ScheduledJob {
|
||||||
id: "disabled-due".into(),
|
id: "disabled-due".into(), name: "disabled due".into(),
|
||||||
name: "disabled due".into(),
|
schedule: Schedule::At { at: t }, prompt: "3".into(),
|
||||||
schedule: Schedule::At { at: t },
|
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
prompt: "3".into(),
|
model: None, enabled: false, delete_after_run: false,
|
||||||
channel: "cli_chat".into(),
|
next_run_at: t - 1000, last_run_at: None,
|
||||||
chat_id: "c".into(),
|
last_status: None, last_error: None,
|
||||||
model: None,
|
created_at: t, updated_at: t,
|
||||||
enabled: false,
|
|
||||||
delete_after_run: false,
|
|
||||||
next_run_at: t - 1000,
|
|
||||||
last_run_at: None,
|
|
||||||
last_status: None,
|
|
||||||
last_error: None,
|
|
||||||
created_at: t,
|
|
||||||
updated_at: t,
|
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
for j in &jobs {
|
for j in &jobs {
|
||||||
@ -534,39 +501,24 @@ mod tests {
|
|||||||
let storage = setup_storage().await;
|
let storage = setup_storage().await;
|
||||||
let t = now();
|
let t = now();
|
||||||
let job = ScheduledJob {
|
let job = ScheduledJob {
|
||||||
id: "job-run".into(),
|
id: "job-run".into(), name: "run test".into(),
|
||||||
name: "run test".into(),
|
|
||||||
schedule: Schedule::Every { every_ms: 1000 },
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
prompt: "hi".into(),
|
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
|
||||||
channel: "cli_chat".into(),
|
model: None, enabled: true, delete_after_run: false,
|
||||||
chat_id: "c".into(),
|
next_run_at: t, last_run_at: None,
|
||||||
model: None,
|
last_status: None, last_error: None,
|
||||||
enabled: true,
|
created_at: t, updated_at: t,
|
||||||
delete_after_run: false,
|
|
||||||
next_run_at: t,
|
|
||||||
last_run_at: None,
|
|
||||||
last_status: None,
|
|
||||||
last_error: None,
|
|
||||||
created_at: t,
|
|
||||||
updated_at: t,
|
|
||||||
};
|
};
|
||||||
storage.add_scheduled_job(&job).await.unwrap();
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
let run = super::JobRun {
|
let run = super::JobRun {
|
||||||
id: 0,
|
id: 0, job_id: "job-run".into(),
|
||||||
job_id: "job-run".into(),
|
started_at: t, finished_at: t + 500,
|
||||||
started_at: t,
|
status: "ok".into(), output: Some("hello".into()),
|
||||||
finished_at: t + 500,
|
error: None, duration_ms: 500,
|
||||||
status: "ok".into(),
|
|
||||||
output: Some("hello".into()),
|
|
||||||
error: None,
|
|
||||||
duration_ms: 500,
|
|
||||||
};
|
};
|
||||||
storage.record_scheduled_job_run(&run).await.unwrap();
|
storage.record_scheduled_job_run(&run).await.unwrap();
|
||||||
let runs = storage
|
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
|
||||||
.list_scheduled_job_runs("job-run", 10)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(runs.len(), 1);
|
assert_eq!(runs.len(), 1);
|
||||||
assert_eq!(runs[0].status, "ok");
|
assert_eq!(runs[0].status, "ok");
|
||||||
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
||||||
@ -577,34 +529,22 @@ mod tests {
|
|||||||
let storage = setup_storage().await;
|
let storage = setup_storage().await;
|
||||||
let t = now();
|
let t = now();
|
||||||
let job = ScheduledJob {
|
let job = ScheduledJob {
|
||||||
id: "job-update".into(),
|
id: "job-update".into(), name: "old name".into(),
|
||||||
name: "old name".into(),
|
|
||||||
schedule: Schedule::Every { every_ms: 1000 },
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
prompt: "old prompt".into(),
|
prompt: "old prompt".into(), channel: "feishu".into(),
|
||||||
channel: "feishu".into(),
|
chat_id: "oc_1".into(), model: None,
|
||||||
chat_id: "oc_1".into(),
|
enabled: true, delete_after_run: false,
|
||||||
model: None,
|
next_run_at: t, last_run_at: None,
|
||||||
enabled: true,
|
last_status: None, last_error: None,
|
||||||
delete_after_run: false,
|
created_at: t, updated_at: t,
|
||||||
next_run_at: t,
|
|
||||||
last_run_at: None,
|
|
||||||
last_status: None,
|
|
||||||
last_error: None,
|
|
||||||
created_at: t,
|
|
||||||
updated_at: t,
|
|
||||||
};
|
};
|
||||||
storage.add_scheduled_job(&job).await.unwrap();
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
storage
|
storage.update_scheduled_job(
|
||||||
.update_scheduled_job(
|
|
||||||
"job-update",
|
"job-update",
|
||||||
Some("new prompt".into()),
|
Some("new prompt".into()),
|
||||||
Some(Schedule::Every { every_ms: 60000 }),
|
Some(Schedule::Every { every_ms: 60000 }),
|
||||||
None,
|
None, None, None,
|
||||||
None,
|
).await.unwrap();
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
||||||
assert_eq!(got.prompt, "new prompt");
|
assert_eq!(got.prompt, "new prompt");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,6 @@ pub struct SessionMeta {
|
|||||||
pub last_active_at: i64,
|
pub last_active_at: i64,
|
||||||
pub message_count: i64,
|
pub message_count: i64,
|
||||||
pub routing_info: Option<String>,
|
pub routing_info: Option<String>,
|
||||||
pub archived_at: Option<i64>,
|
|
||||||
pub deleted_at: Option<i64>,
|
pub deleted_at: Option<i64>,
|
||||||
pub last_consolidated_at: Option<i64>,
|
pub last_consolidated_at: Option<i64>,
|
||||||
pub last_compressed_message_at: Option<i64>,
|
pub last_compressed_message_at: Option<i64>,
|
||||||
|
|||||||
@ -167,7 +167,10 @@ impl Tool for BashTool {
|
|||||||
Err(_) => Ok(ToolResult {
|
Err(_) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
|
error: Some(format!(
|
||||||
|
"Command timed out after {} seconds",
|
||||||
|
timeout_secs
|
||||||
|
)),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -246,7 +249,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pwd_command() {
|
async fn test_pwd_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({ "command": "pwd" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
@ -254,10 +260,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_ls_command() {
|
async fn test_ls_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
||||||
.execute(json!({ "command": "ls -la /tmp" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
|
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT};
|
||||||
use fantoccini::key::Key;
|
use fantoccini::key::Key;
|
||||||
use fantoccini::{Client, ClientBuilder, Locator};
|
use fantoccini::{Client, ClientBuilder, Locator};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@ -63,9 +63,7 @@ impl BrowserTool {
|
|||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum BrowserAction {
|
pub enum BrowserAction {
|
||||||
Open {
|
Open { url: String },
|
||||||
url: String,
|
|
||||||
},
|
|
||||||
Snapshot {
|
Snapshot {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
interactive_only: bool,
|
interactive_only: bool,
|
||||||
@ -74,20 +72,10 @@ pub enum BrowserAction {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
depth: Option<i64>,
|
depth: Option<i64>,
|
||||||
},
|
},
|
||||||
Click {
|
Click { selector: String },
|
||||||
selector: String,
|
Fill { selector: String, value: String },
|
||||||
},
|
Type { selector: Option<String>, text: String },
|
||||||
Fill {
|
GetText { selector: String },
|
||||||
selector: String,
|
|
||||||
value: String,
|
|
||||||
},
|
|
||||||
Type {
|
|
||||||
selector: Option<String>,
|
|
||||||
text: String,
|
|
||||||
},
|
|
||||||
GetText {
|
|
||||||
selector: String,
|
|
||||||
},
|
|
||||||
GetTitle,
|
GetTitle,
|
||||||
GetUrl,
|
GetUrl,
|
||||||
Screenshot {
|
Screenshot {
|
||||||
@ -96,9 +84,7 @@ pub enum BrowserAction {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
return_base64: bool,
|
return_base64: bool,
|
||||||
},
|
},
|
||||||
Focus {
|
Focus { selector: String },
|
||||||
selector: String,
|
|
||||||
},
|
|
||||||
Wait {
|
Wait {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
selector: Option<String>,
|
selector: Option<String>,
|
||||||
@ -107,16 +93,9 @@ pub enum BrowserAction {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
},
|
},
|
||||||
Press {
|
Press { key: String },
|
||||||
key: String,
|
Hover { selector: String },
|
||||||
},
|
ClickAt { x: u32, y: u32 },
|
||||||
Hover {
|
|
||||||
selector: String,
|
|
||||||
},
|
|
||||||
ClickAt {
|
|
||||||
x: u32,
|
|
||||||
y: u32,
|
|
||||||
},
|
|
||||||
Scroll {
|
Scroll {
|
||||||
direction: String,
|
direction: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@ -141,8 +120,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
|||||||
.get("interactive_only")
|
.get("interactive_only")
|
||||||
.and_then(Value::as_bool)
|
.and_then(Value::as_bool)
|
||||||
.unwrap_or(true),
|
.unwrap_or(true),
|
||||||
compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
|
compact: args
|
||||||
depth: args.get("depth").and_then(|v| v.as_i64()),
|
.get("compact")
|
||||||
|
.and_then(Value::as_bool)
|
||||||
|
.unwrap_or(true),
|
||||||
|
depth: args
|
||||||
|
.get("depth")
|
||||||
|
.and_then(|v| v.as_i64()),
|
||||||
}),
|
}),
|
||||||
"click" => {
|
"click" => {
|
||||||
let selector = args
|
let selector = args
|
||||||
@ -214,7 +198,10 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
|||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.map(String::from),
|
.map(String::from),
|
||||||
ms: args.get("ms").and_then(|v| v.as_u64()),
|
ms: args.get("ms").and_then(|v| v.as_u64()),
|
||||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
text: args
|
||||||
|
.get("text")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(String::from),
|
||||||
}),
|
}),
|
||||||
"press" => {
|
"press" => {
|
||||||
let key = args
|
let key = args
|
||||||
@ -252,13 +239,11 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
|||||||
let x = args
|
let x = args
|
||||||
.get("x")
|
.get("x")
|
||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
|
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32;
|
||||||
as u32;
|
|
||||||
let y = args
|
let y = args
|
||||||
.get("y")
|
.get("y")
|
||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
|
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32;
|
||||||
as u32;
|
|
||||||
Ok(BrowserAction::ClickAt { x, y })
|
Ok(BrowserAction::ClickAt { x, y })
|
||||||
}
|
}
|
||||||
other => anyhow::bail!("Unsupported browser action: {}", other),
|
other => anyhow::bail!("Unsupported browser action: {}", other),
|
||||||
@ -503,11 +488,7 @@ impl BrowserState {
|
|||||||
}
|
}
|
||||||
Err(e) => return Err(e.into()),
|
Err(e) => return Err(e.into()),
|
||||||
}
|
}
|
||||||
tracing::debug!(
|
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed");
|
||||||
action = "fill",
|
|
||||||
output_len = value.len(),
|
|
||||||
"Browser action completed"
|
|
||||||
);
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Filled {} with {}", selector, value),
|
output: format!("Filled {} with {}", selector, value),
|
||||||
@ -592,10 +573,7 @@ impl BrowserState {
|
|||||||
error: None,
|
error: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
BrowserAction::Screenshot {
|
BrowserAction::Screenshot { path, return_base64 } => {
|
||||||
path,
|
|
||||||
return_base64,
|
|
||||||
} => {
|
|
||||||
let client = self.active_client()?;
|
let client = self.active_client()?;
|
||||||
let png = client.screenshot().await?;
|
let png = client.screenshot().await?;
|
||||||
let save_path = path.unwrap_or_else(|| {
|
let save_path = path.unwrap_or_else(|| {
|
||||||
@ -610,25 +588,14 @@ impl BrowserState {
|
|||||||
tokio::fs::write(&save_path, &png).await?;
|
tokio::fs::write(&save_path, &png).await?;
|
||||||
if return_base64 {
|
if return_base64 {
|
||||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
|
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
|
||||||
tracing::debug!(
|
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed");
|
||||||
action = "screenshot",
|
|
||||||
output_len = b64.len(),
|
|
||||||
"Browser action completed"
|
|
||||||
);
|
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!(
|
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64),
|
||||||
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
|
|
||||||
save_path, b64
|
|
||||||
),
|
|
||||||
error: None,
|
error: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
tracing::debug!(
|
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed");
|
||||||
action = "screenshot",
|
|
||||||
output_len = save_path.len(),
|
|
||||||
"Browser action completed"
|
|
||||||
);
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Screenshot saved to {}", save_path),
|
output: format!("Screenshot saved to {}", save_path),
|
||||||
@ -644,18 +611,18 @@ impl BrowserState {
|
|||||||
vec![serde_json::to_value(el)?],
|
vec![serde_json::to_value(el)?],
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
tracing::debug!(
|
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed");
|
||||||
action = "focus",
|
|
||||||
output_len = selector.len(),
|
|
||||||
"Browser action completed"
|
|
||||||
);
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Focused {}", selector),
|
output: format!("Focused {}", selector),
|
||||||
error: None,
|
error: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
BrowserAction::Wait { selector, ms, text } => {
|
BrowserAction::Wait {
|
||||||
|
selector,
|
||||||
|
ms,
|
||||||
|
text,
|
||||||
|
} => {
|
||||||
if let Some(sel) = selector {
|
if let Some(sel) = selector {
|
||||||
let client = self.active_client()?;
|
let client = self.active_client()?;
|
||||||
wait_for_selector(client, &sel).await?;
|
wait_for_selector(client, &sel).await?;
|
||||||
@ -752,21 +719,9 @@ impl BrowserState {
|
|||||||
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
|
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
|
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let id_str = if id.is_empty() {
|
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") };
|
||||||
String::new()
|
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") };
|
||||||
} else {
|
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") };
|
||||||
format!("#{id}")
|
|
||||||
};
|
|
||||||
let type_str = if el_type.is_empty() {
|
|
||||||
String::new()
|
|
||||||
} else {
|
|
||||||
format!("[type={el_type}]")
|
|
||||||
};
|
|
||||||
let text_str = if text.is_empty() {
|
|
||||||
String::new()
|
|
||||||
} else {
|
|
||||||
format!(" ({text})")
|
|
||||||
};
|
|
||||||
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
|
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
|
||||||
}
|
}
|
||||||
None => format!("Clicked at ({}, {})", x, y),
|
None => format!("Clicked at ({}, {})", x, y),
|
||||||
@ -1135,7 +1090,10 @@ fn css_attr_escape(input: &str) -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn xpath_contains_text(text: &str) -> String {
|
fn xpath_contains_text(text: &str) -> String {
|
||||||
format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
|
format!(
|
||||||
|
"//*[contains(normalize-space(.), {})]",
|
||||||
|
xpath_literal(text)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn xpath_literal(input: &str) -> String {
|
fn xpath_literal(input: &str) -> String {
|
||||||
@ -1182,10 +1140,7 @@ fn webdriver_key(key: &str) -> String {
|
|||||||
"pagedown" => Key::PageDown.to_string(),
|
"pagedown" => Key::PageDown.to_string(),
|
||||||
"space" => " ".to_string(),
|
"space" => " ".to_string(),
|
||||||
other => {
|
other => {
|
||||||
tracing::warn!(
|
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other);
|
||||||
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
|
|
||||||
other
|
|
||||||
);
|
|
||||||
other.to_string()
|
other.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -659,7 +659,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_evaluate_missing_expression() {
|
async fn test_evaluate_missing_expression() {
|
||||||
let tool = CalculatorTool::new();
|
let tool = CalculatorTool::new();
|
||||||
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({"function": "evaluate"}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -126,10 +126,7 @@ impl ChatManagerTool {
|
|||||||
let start_num = offset + 1;
|
let start_num = offset + 1;
|
||||||
let end_num = offset + sessions.len() as i64;
|
let end_num = offset + sessions.len() as i64;
|
||||||
|
|
||||||
let mut output = format!(
|
let mut output = format!("全部会话 (共 {} 个,第 {}-{} 个):\n", total, start_num, end_num);
|
||||||
"全部会话 (共 {} 个,第 {}-{} 个):\n",
|
|
||||||
total, start_num, end_num
|
|
||||||
);
|
|
||||||
|
|
||||||
for s in &sessions {
|
for s in &sessions {
|
||||||
let ago = format_duration_ago(now_ms - s.last_active_at);
|
let ago = format_duration_ago(now_ms - s.last_active_at);
|
||||||
@ -303,7 +300,6 @@ mod tests {
|
|||||||
last_active_at: now - i * 3600_000,
|
last_active_at: now - i * 3600_000,
|
||||||
message_count: i * 5,
|
message_count: i * 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -339,7 +335,6 @@ mod tests {
|
|||||||
last_active_at: now,
|
last_active_at: now,
|
||||||
message_count: 3,
|
message_count: 3,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -351,11 +346,7 @@ mod tests {
|
|||||||
id: format!("msg{}", i),
|
id: format!("msg{}", i),
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
seq: i as i64 + 1,
|
seq: i as i64 + 1,
|
||||||
role: if i == 0 {
|
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||||
"user".to_string()
|
|
||||||
} else {
|
|
||||||
"assistant".to_string()
|
|
||||||
},
|
|
||||||
content: format!("消息内容 {}", i),
|
content: format!("消息内容 {}", i),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
media_refs: None,
|
media_refs: None,
|
||||||
@ -401,7 +392,6 @@ mod tests {
|
|||||||
last_active_at: now,
|
last_active_at: now,
|
||||||
message_count: 5,
|
message_count: 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -413,11 +403,7 @@ mod tests {
|
|||||||
id: format!("msg{}", i),
|
id: format!("msg{}", i),
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
seq: i as i64 + 1,
|
seq: i as i64 + 1,
|
||||||
role: if i % 2 == 0 {
|
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||||
"user".to_string()
|
|
||||||
} else {
|
|
||||||
"assistant".to_string()
|
|
||||||
},
|
|
||||||
content: format!("消息内容 {}", i),
|
content: format!("消息内容 {}", i),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
media_refs: None,
|
media_refs: None,
|
||||||
@ -461,7 +447,6 @@ mod tests {
|
|||||||
last_active_at: now,
|
last_active_at: now,
|
||||||
message_count: 5,
|
message_count: 5,
|
||||||
routing_info: None,
|
routing_info: None,
|
||||||
archived_at: None,
|
|
||||||
deleted_at: None,
|
deleted_at: None,
|
||||||
last_consolidated_at: None,
|
last_consolidated_at: None,
|
||||||
last_compressed_message_at: None,
|
last_compressed_message_at: None,
|
||||||
@ -507,7 +492,10 @@ mod tests {
|
|||||||
let (storage, _dir) = create_test_storage().await;
|
let (storage, _dir) = create_test_storage().await;
|
||||||
let tool = ChatManagerTool::new(storage, vec![]);
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
let result = tool.execute(json!({ "action": "unknown" })).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({ "action": "unknown" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Unknown action"));
|
assert!(result.error.unwrap().contains("Unknown action"));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,7 +31,10 @@ impl ContentSearchTool {
|
|||||||
for (i, line) in lines.iter().enumerate() {
|
for (i, line) in lines.iter().enumerate() {
|
||||||
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
||||||
let omitted = lines.len() - i;
|
let omitted = lines.len() - i;
|
||||||
output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
|
output.push_str(&format!(
|
||||||
|
"\n... ({} matches omitted) ...",
|
||||||
|
omitted
|
||||||
|
));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
@ -110,40 +113,18 @@ impl Tool for ContentSearchTool {
|
|||||||
|
|
||||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||||
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
|
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
|
||||||
let case_sensitive = args
|
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||||
.get("case_sensitive")
|
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
|
||||||
.and_then(|v| v.as_bool())
|
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
||||||
.unwrap_or(false);
|
|
||||||
let context_lines = args
|
|
||||||
.get("context_lines")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0) as usize;
|
|
||||||
let max_results = args
|
|
||||||
.get("max_results")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(MAX_RESULTS as u64) as usize;
|
|
||||||
|
|
||||||
let result = self
|
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
|
||||||
.run_search(
|
|
||||||
pattern,
|
|
||||||
&dir,
|
|
||||||
file_pattern,
|
|
||||||
case_sensitive,
|
|
||||||
context_lines,
|
|
||||||
max_results,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(lines) => {
|
Ok(lines) => {
|
||||||
let count = lines.len();
|
let count = lines.len();
|
||||||
let mut output = self.truncate_output(&lines);
|
let mut output = self.truncate_output(&lines);
|
||||||
output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
|
output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
|
||||||
Ok(ToolResult {
|
Ok(ToolResult { success: true, output, error: None })
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
Err(e) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
@ -165,52 +146,22 @@ impl ContentSearchTool {
|
|||||||
max_results: usize,
|
max_results: usize,
|
||||||
) -> anyhow::Result<Vec<String>> {
|
) -> anyhow::Result<Vec<String>> {
|
||||||
if which::which("rg").is_ok() {
|
if which::which("rg").is_ok() {
|
||||||
match self
|
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
||||||
.search_with_rg(
|
|
||||||
pattern,
|
|
||||||
dir,
|
|
||||||
file_pattern,
|
|
||||||
case_sensitive,
|
|
||||||
context_lines,
|
|
||||||
max_results,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(lines) => return Ok(lines),
|
Ok(lines) => return Ok(lines),
|
||||||
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
|
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if which::which("grep").is_ok() {
|
if which::which("grep").is_ok() {
|
||||||
match self
|
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
||||||
.search_with_grep(
|
|
||||||
pattern,
|
|
||||||
dir,
|
|
||||||
file_pattern,
|
|
||||||
case_sensitive,
|
|
||||||
context_lines,
|
|
||||||
max_results,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||||
Ok(_) => {}
|
Ok(_) => {},
|
||||||
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
|
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::warn!(
|
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
|
||||||
"No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
|
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
|
||||||
);
|
|
||||||
self.search_with_rust(
|
|
||||||
pattern,
|
|
||||||
dir,
|
|
||||||
file_pattern,
|
|
||||||
case_sensitive,
|
|
||||||
context_lines,
|
|
||||||
max_results,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search_with_rg(
|
async fn search_with_rg(
|
||||||
@ -225,10 +176,8 @@ impl ContentSearchTool {
|
|||||||
let mut cmd = Command::new("rg");
|
let mut cmd = Command::new("rg");
|
||||||
cmd.arg("-n")
|
cmd.arg("-n")
|
||||||
.arg("--no-heading")
|
.arg("--no-heading")
|
||||||
.arg("--color")
|
.arg("--color").arg("never")
|
||||||
.arg("never")
|
.arg("--max-count").arg(max_results.to_string())
|
||||||
.arg("--max-count")
|
|
||||||
.arg(max_results.to_string())
|
|
||||||
.arg(pattern)
|
.arg(pattern)
|
||||||
.arg(dir)
|
.arg(dir)
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
@ -244,7 +193,10 @@ impl ContentSearchTool {
|
|||||||
cmd.arg("--glob").arg(fp);
|
cmd.arg("--glob").arg(fp);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
let output = timeout(
|
||||||
|
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||||
|
cmd.output(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
@ -254,8 +206,7 @@ impl ContentSearchTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text
|
let lines: Vec<String> = text.lines()
|
||||||
.lines()
|
|
||||||
.take(max_results)
|
.take(max_results)
|
||||||
.map(|l| l.to_string())
|
.map(|l| l.to_string())
|
||||||
.collect();
|
.collect();
|
||||||
@ -291,13 +242,15 @@ impl ContentSearchTool {
|
|||||||
cmd.arg("--include").arg(fp);
|
cmd.arg("--include").arg(fp);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
let output = timeout(
|
||||||
|
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||||
|
cmd.output(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text
|
let lines: Vec<String> = text.lines()
|
||||||
.lines()
|
|
||||||
.take(max_results)
|
.take(max_results)
|
||||||
.map(|l| l.to_string())
|
.map(|l| l.to_string())
|
||||||
.collect();
|
.collect();
|
||||||
@ -327,9 +280,7 @@ impl ContentSearchTool {
|
|||||||
if case_sensitive {
|
if case_sensitive {
|
||||||
regex::Regex::new(&re_str)
|
regex::Regex::new(&re_str)
|
||||||
} else {
|
} else {
|
||||||
regex::RegexBuilder::new(&re_str)
|
regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
|
||||||
.case_insensitive(true)
|
|
||||||
.build()
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -340,14 +291,7 @@ impl ContentSearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
grep_dir(
|
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
|
||||||
Path::new(dir),
|
|
||||||
Path::new(dir),
|
|
||||||
&re,
|
|
||||||
file_re.as_ref(),
|
|
||||||
&mut results,
|
|
||||||
max_results,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
@ -406,17 +350,14 @@ fn grep_dir(
|
|||||||
|
|
||||||
if path.is_dir() {
|
if path.is_dir() {
|
||||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& name.starts_with('.')
|
&& name.starts_with('.') && name.len() > 1 {
|
||||||
&& name.len() > 1
|
|
||||||
{
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
grep_dir(base, &path, re, file_re, results, max)?;
|
grep_dir(base, &path, re, file_re, results, max)?;
|
||||||
} else if path.is_file() {
|
} else if path.is_file() {
|
||||||
if let Some(file_re) = file_re
|
if let Some(file_re) = file_re
|
||||||
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& !file_re.is_match(name)
|
&& !file_re.is_match(name) {
|
||||||
{
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -450,16 +391,8 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_content_search_rust_fallback() {
|
async fn test_content_search_rust_fallback() {
|
||||||
let dir = TempDir::new().unwrap();
|
let dir = TempDir::new().unwrap();
|
||||||
fs::write(
|
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
|
||||||
dir.path().join("main.rs"),
|
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
|
||||||
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
fs::write(
|
|
||||||
dir.path().join("lib.rs"),
|
|
||||||
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
|
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
|
||||||
|
|
||||||
let tool = ContentSearchTool::new();
|
let tool = ContentSearchTool::new();
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{json, Value};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::scheduler::{Schedule, next_run_for_schedule};
|
use crate::scheduler::{next_run_for_schedule, Schedule};
|
||||||
use crate::storage::{ScheduledJob, Storage};
|
use crate::storage::{ScheduledJob, Storage};
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
@ -229,7 +229,10 @@ impl Tool for CronListTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
|
let filter = args
|
||||||
|
.get("status")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("all");
|
||||||
let jobs = self.storage.list_scheduled_jobs().await?;
|
let jobs = self.storage.list_scheduled_jobs().await?;
|
||||||
|
|
||||||
let filtered: Vec<&ScheduledJob> = match filter {
|
let filtered: Vec<&ScheduledJob> = match filter {
|
||||||
@ -394,9 +397,7 @@ impl Tool for CronEnableTool {
|
|||||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
|
|
||||||
let next = next_run_for_schedule(&job.schedule, now_ms());
|
let next = next_run_for_schedule(&job.schedule, now_ms());
|
||||||
self.storage
|
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
|
||||||
.set_scheduled_job_enabled(&job_id, true)
|
|
||||||
.await?;
|
|
||||||
if let Some(n) = next {
|
if let Some(n) = next {
|
||||||
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
||||||
}
|
}
|
||||||
@ -463,9 +464,7 @@ impl Tool for CronDisableTool {
|
|||||||
.get_scheduled_job(&job_id)
|
.get_scheduled_job(&job_id)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
self.storage
|
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
|
||||||
.set_scheduled_job_enabled(&job_id, false)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
@ -581,9 +580,7 @@ impl Tool for CronUpdateTool {
|
|||||||
if args.get("schedule").is_some() {
|
if args.get("schedule").is_some() {
|
||||||
let job = self.storage.get_scheduled_job(&job_id).await?;
|
let job = self.storage.get_scheduled_job(&job_id).await?;
|
||||||
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
||||||
self.storage
|
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
|
||||||
.set_scheduled_job_next_run(&job_id, next)
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -768,7 +765,9 @@ mod tests {
|
|||||||
let job = ScheduledJob {
|
let job = ScheduledJob {
|
||||||
id: "job-update-tool".into(),
|
id: "job-update-tool".into(),
|
||||||
name: "old".into(),
|
name: "old".into(),
|
||||||
schedule: Schedule::Every { every_ms: 3600000 },
|
schedule: Schedule::Every {
|
||||||
|
every_ms: 3600000,
|
||||||
|
},
|
||||||
prompt: "old prompt".into(),
|
prompt: "old prompt".into(),
|
||||||
channel: "feishu".into(),
|
channel: "feishu".into(),
|
||||||
chat_id: "oc_1".into(),
|
chat_id: "oc_1".into(),
|
||||||
|
|||||||
@ -1,390 +0,0 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
use crate::agent::{ExecutionMode, SubAgentConfig, SubAgentManager, TaskStatus};
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
|
||||||
|
|
||||||
pub struct DelegateTool {
|
|
||||||
sub_agent_manager: Arc<SubAgentManager>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DelegateTool {
|
|
||||||
pub fn new(sub_agent_manager: Arc<SubAgentManager>) -> Self {
|
|
||||||
Self { sub_agent_manager }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for DelegateTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"delegate"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"子任务委托工具。创建子 Agent 处理独立任务,支持三种模式:\
|
|
||||||
inline (阻塞返回结果)、background (异步执行,完成后通知)、\
|
|
||||||
parallel (多个子 Agent 并发执行,聚合结果)。\
|
|
||||||
也可用于查询、取消和列出后台任务。"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"action": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["run", "check_task", "cancel_task", "list_tasks"],
|
|
||||||
"description": "操作类型: run 创建子Agent执行任务, check_task 查询后台任务, cancel_task 取消后台任务, list_tasks 列出后台任务"
|
|
||||||
},
|
|
||||||
"prompt": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "子任务描述(可内含额外约束,如:跳过 .tmp 文件)。action=run 时必填"
|
|
||||||
},
|
|
||||||
"mode": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["inline", "background", "parallel"],
|
|
||||||
"description": "执行模式: inline=阻塞返回结果, background=异步执行+通知, parallel=多子Agent并发。默认 inline"
|
|
||||||
},
|
|
||||||
"allowed_tools": {
|
|
||||||
"type": "array",
|
|
||||||
"items": { "type": "string" },
|
|
||||||
"description": "允许子Agent使用的工具列表。不填使用默认只读集: file_read,file_search,content_search,web_fetch,http_request,calculator"
|
|
||||||
},
|
|
||||||
"max_iterations": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "最大迭代次数,默认 99"
|
|
||||||
},
|
|
||||||
"timeout_secs": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "超时秒数,默认 3600(1小时)"
|
|
||||||
},
|
|
||||||
"tasks": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "并行模式下的多个子任务(仅 mode=parallel 时使用)",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"prompt": { "type": "string", "description": "子任务描述" },
|
|
||||||
"allowed_tools": {
|
|
||||||
"type": "array",
|
|
||||||
"items": { "type": "string" },
|
|
||||||
"description": "该子任务的工具列表"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["prompt"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"task_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "后台任务ID(action=check_task/cancel_task 时必填)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["action"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_only(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let action = args["action"]
|
|
||||||
.as_str()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: action"))?;
|
|
||||||
|
|
||||||
match action {
|
|
||||||
"run" => self.handle_run(&args).await,
|
|
||||||
"check_task" => self.handle_check_task(&args).await,
|
|
||||||
"cancel_task" => self.handle_cancel_task(&args).await,
|
|
||||||
"list_tasks" => self.handle_list_tasks(&args).await,
|
|
||||||
_ => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!(
|
|
||||||
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
|
|
||||||
action
|
|
||||||
)),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DelegateTool {
|
|
||||||
fn parse_config_from_args(&self, args: &serde_json::Value) -> anyhow::Result<SubAgentConfig> {
|
|
||||||
let prompt = args["prompt"]
|
|
||||||
.as_str()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
|
|
||||||
arr.iter()
|
|
||||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
|
||||||
.collect()
|
|
||||||
});
|
|
||||||
|
|
||||||
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
|
|
||||||
let timeout_secs = args["timeout_secs"].as_u64();
|
|
||||||
|
|
||||||
Ok(SubAgentConfig {
|
|
||||||
prompt,
|
|
||||||
mode: ExecutionMode::Inline,
|
|
||||||
allowed_tools,
|
|
||||||
max_iterations,
|
|
||||||
timeout_secs,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_run(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let mode_str = args["mode"].as_str().unwrap_or("inline");
|
|
||||||
let mode = match mode_str {
|
|
||||||
"inline" => ExecutionMode::Inline,
|
|
||||||
"background" => ExecutionMode::Background,
|
|
||||||
"parallel" => ExecutionMode::Parallel,
|
|
||||||
_ => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!(
|
|
||||||
"unknown mode: {}. Supported: inline, background, parallel",
|
|
||||||
mode_str
|
|
||||||
)),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match mode {
|
|
||||||
ExecutionMode::Inline => {
|
|
||||||
let config = self.parse_config_from_args(args)?;
|
|
||||||
let result = self
|
|
||||||
.sub_agent_manager
|
|
||||||
.run_inline(config)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
|
||||||
|
|
||||||
match result.status {
|
|
||||||
TaskStatus::Completed => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: result.content,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
TaskStatus::Failed(err) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: result.content,
|
|
||||||
error: Some(err),
|
|
||||||
}),
|
|
||||||
TaskStatus::TimedOut => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: result.content,
|
|
||||||
error: Some("sub-agent timed out".into()),
|
|
||||||
}),
|
|
||||||
TaskStatus::Cancelled => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: result.content,
|
|
||||||
error: Some("sub-agent cancelled".into()),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ExecutionMode::Background => {
|
|
||||||
let config = self.parse_config_from_args(args)?;
|
|
||||||
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
|
|
||||||
anyhow::anyhow!("delegate context not available: not in an agent worker")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let task_id = self
|
|
||||||
.sub_agent_manager
|
|
||||||
.run_background(config, ctx)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: format!("后台任务已启动。\ntask_id: {}", task_id),
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
ExecutionMode::Parallel => {
|
|
||||||
let tasks = args["tasks"]
|
|
||||||
.as_array()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("parallel mode requires 'tasks' array"))?;
|
|
||||||
|
|
||||||
let mut configs = Vec::new();
|
|
||||||
for task in tasks {
|
|
||||||
let prompt = task["prompt"]
|
|
||||||
.as_str()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
|
|
||||||
.to_string();
|
|
||||||
let allowed_tools: Option<Vec<String>> =
|
|
||||||
task["allowed_tools"].as_array().map(|arr| {
|
|
||||||
arr.iter()
|
|
||||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
|
||||||
.collect()
|
|
||||||
});
|
|
||||||
|
|
||||||
configs.push(SubAgentConfig {
|
|
||||||
prompt,
|
|
||||||
mode: ExecutionMode::Inline,
|
|
||||||
allowed_tools,
|
|
||||||
max_iterations: args["max_iterations"].as_u64().map(|v| v as usize),
|
|
||||||
timeout_secs: args["timeout_secs"].as_u64(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let has_args_allowed = args["allowed_tools"].as_array().is_some();
|
|
||||||
for c in &mut configs {
|
|
||||||
if c.allowed_tools.is_none() && has_args_allowed {
|
|
||||||
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
|
|
||||||
arr.iter()
|
|
||||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
|
||||||
.collect()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let results = self
|
|
||||||
.sub_agent_manager
|
|
||||||
.run_parallel(configs)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
|
||||||
|
|
||||||
let mut output = String::new();
|
|
||||||
for (i, r) in results.iter().enumerate() {
|
|
||||||
let status_icon = match r.status {
|
|
||||||
TaskStatus::Completed => "✅",
|
|
||||||
TaskStatus::Failed(_) => "❌",
|
|
||||||
TaskStatus::TimedOut => "⏱️ 超时",
|
|
||||||
TaskStatus::Cancelled => "🚫 已取消",
|
|
||||||
};
|
|
||||||
output.push_str(&format!("[task_{}] {}\n", i + 1, status_icon));
|
|
||||||
if !r.content.is_empty() {
|
|
||||||
output.push_str(&r.content);
|
|
||||||
output.push_str("\n\n");
|
|
||||||
}
|
|
||||||
if let TaskStatus::Failed(ref err) = r.status {
|
|
||||||
output.push_str(&format!("错误: {}\n\n", err));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let all_success = results
|
|
||||||
.iter()
|
|
||||||
.all(|r| matches!(r.status, TaskStatus::Completed));
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: all_success,
|
|
||||||
output: output.trim().to_string(),
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_check_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let task_id = args["task_id"]
|
|
||||||
.as_str()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
|
|
||||||
|
|
||||||
match self.sub_agent_manager.check_task(task_id).await {
|
|
||||||
Some(task) => {
|
|
||||||
let status_icon = match task.status.as_str() {
|
|
||||||
"completed" => "✅ 已完成",
|
|
||||||
"failed" => "❌ 失败",
|
|
||||||
"cancelled" => "🚫 已取消",
|
|
||||||
"running" => "🔄 运行中",
|
|
||||||
"pending" => "⏳ 等待中",
|
|
||||||
_ => task.status.as_str(),
|
|
||||||
};
|
|
||||||
let mut output = format!(
|
|
||||||
"任务 ID: {}\n状态: {}\n任务: {}",
|
|
||||||
task.id, status_icon, task.prompt
|
|
||||||
);
|
|
||||||
if let Some(ref result) = task.result {
|
|
||||||
output.push_str(&format!("\n\n结果:\n{}", result));
|
|
||||||
}
|
|
||||||
if let Some(ref error) = task.error {
|
|
||||||
output.push_str(&format!("\n错误: {}", error));
|
|
||||||
}
|
|
||||||
if let Some(started) = task.started_at {
|
|
||||||
if let Some(finished) = task.finished_at {
|
|
||||||
let duration = (finished - started) as f64 / 1000.0;
|
|
||||||
output.push_str(&format!("\n耗时: {:.1}s", duration));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
None => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!("task not found: {}", task_id)),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_cancel_task(&self, args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let task_id = args["task_id"]
|
|
||||||
.as_str()
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: task_id"))?;
|
|
||||||
|
|
||||||
match self.sub_agent_manager.cancel_task(task_id).await {
|
|
||||||
Ok(true) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: format!("后台任务 {} 已取消", task_id),
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Ok(false) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!("无法取消任务 {}(可能已完成或不存在)", task_id)),
|
|
||||||
}),
|
|
||||||
Err(e) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!("取消失败: {}", e)),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_list_tasks(&self, _args: &serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let ctx = crate::agent::sub_agent::get_delegate_context()
|
|
||||||
.map_err(|_| anyhow::anyhow!("delegate context not available"))?;
|
|
||||||
let tasks = self.sub_agent_manager.list_tasks(&ctx.session_id).await;
|
|
||||||
|
|
||||||
if tasks.is_empty() {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: "没有后台任务".to_string(),
|
|
||||||
error: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut output = String::from("后台任务列表:\n\n");
|
|
||||||
for task in &tasks {
|
|
||||||
let status_icon = match task.status.as_str() {
|
|
||||||
"completed" => "✅",
|
|
||||||
"failed" => "❌",
|
|
||||||
"cancelled" => "🚫",
|
|
||||||
"running" => "🔄",
|
|
||||||
"pending" => "⏳",
|
|
||||||
_ => "❓",
|
|
||||||
};
|
|
||||||
output.push_str(&format!(
|
|
||||||
"{} {} - {} - {} (created: {})\n",
|
|
||||||
status_icon,
|
|
||||||
&task.id[..std::cmp::min(8, task.id.len())],
|
|
||||||
task.prompt.chars().take(60).collect::<String>(),
|
|
||||||
task.status,
|
|
||||||
task.created_at,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
output.push_str(&format!("\n共 {} 个任务", tasks.len()));
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -243,8 +243,8 @@ impl Tool for FileEditTool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::io::Write;
|
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_edit_simple() {
|
async fn test_edit_simple() {
|
||||||
|
|||||||
@ -181,7 +181,10 @@ impl Tool for FileReadTool {
|
|||||||
}
|
}
|
||||||
result = lines[..end_idx].join("\n");
|
result = lines[..end_idx].join("\n");
|
||||||
let truncated = original_len - result.len();
|
let truncated = original_len - result.len();
|
||||||
result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
|
result.push_str(&format!(
|
||||||
|
"\n\n... ({} chars truncated) ...",
|
||||||
|
truncated
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if end < total {
|
if end < total {
|
||||||
@ -193,7 +196,10 @@ impl Tool for FileReadTool {
|
|||||||
end + 1
|
end + 1
|
||||||
));
|
));
|
||||||
} else {
|
} else {
|
||||||
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
|
result.push_str(&format!(
|
||||||
|
"\n\n(End of file — {} lines total)",
|
||||||
|
total
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(label) = encoding_label {
|
if let Some(label) = encoding_label {
|
||||||
@ -208,7 +214,7 @@ impl Tool for FileReadTool {
|
|||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// Truly binary file — base64 encode
|
// Truly binary file — base64 encode
|
||||||
use base64::{Engine, engine::general_purpose::STANDARD};
|
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||||
let encoded = STANDARD.encode(&bytes);
|
let encoded = STANDARD.encode(&bytes);
|
||||||
let mime = mime_guess::from_path(&resolved)
|
let mime = mime_guess::from_path(&resolved)
|
||||||
.first_or_octet_stream()
|
.first_or_octet_stream()
|
||||||
@ -272,8 +278,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::io::Write;
|
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_read_simple_file() {
|
async fn test_read_simple_file() {
|
||||||
@ -332,7 +338,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_is_directory() {
|
async fn test_is_directory() {
|
||||||
let tool = FileReadTool::new();
|
let tool = FileReadTool::new();
|
||||||
let result = tool.execute(json!({ "path": "." })).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({ "path": "." }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Not a file"));
|
assert!(result.error.unwrap().contains("Not a file"));
|
||||||
|
|||||||
@ -101,29 +101,17 @@ impl Tool for FileSearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||||
let case_sensitive = args
|
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
|
||||||
.get("case_sensitive")
|
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
||||||
.and_then(|v| v.as_bool())
|
|
||||||
.unwrap_or(true);
|
|
||||||
let max_results = args
|
|
||||||
.get("max_results")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(MAX_RESULTS as u64) as usize;
|
|
||||||
|
|
||||||
let result = self
|
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
|
||||||
.run_search(pattern, &dir, case_sensitive, max_results)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(lines) => {
|
Ok(lines) => {
|
||||||
let count = lines.len();
|
let count = lines.len();
|
||||||
let mut output = self.truncate_output(&lines);
|
let mut output = self.truncate_output(&lines);
|
||||||
output.push_str(&format!("\n\n---\n共 {} 个文件", count));
|
output.push_str(&format!("\n\n---\n共 {} 个文件", count));
|
||||||
Ok(ToolResult {
|
Ok(ToolResult { success: true, output, error: None })
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
Err(e) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
@ -151,12 +139,9 @@ impl FileSearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !fd_cmd.is_empty() {
|
if !fd_cmd.is_empty() {
|
||||||
match self
|
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await {
|
||||||
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||||
Ok(_) => {}
|
Ok(_) => {},
|
||||||
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
|
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -164,14 +149,13 @@ impl FileSearchTool {
|
|||||||
if which::which("find").is_ok() {
|
if which::which("find").is_ok() {
|
||||||
match self.search_with_find(pattern, dir, max_results).await {
|
match self.search_with_find(pattern, dir, max_results).await {
|
||||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||||
Ok(_) => {}
|
Ok(_) => {},
|
||||||
Err(e) => tracing::warn!("find failed: {}, falling back", e),
|
Err(e) => tracing::warn!("find failed: {}, falling back", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::warn!("No fd/find available, using built-in file search (slower)");
|
tracing::warn!("No fd/find available, using built-in file search (slower)");
|
||||||
self.search_with_rust(pattern, dir, case_sensitive, max_results)
|
self.search_with_rust(pattern, dir, case_sensitive, max_results).await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search_with_fd(
|
async fn search_with_fd(
|
||||||
@ -183,15 +167,11 @@ impl FileSearchTool {
|
|||||||
fd_cmd: &str,
|
fd_cmd: &str,
|
||||||
) -> anyhow::Result<Vec<String>> {
|
) -> anyhow::Result<Vec<String>> {
|
||||||
let mut cmd = Command::new(fd_cmd);
|
let mut cmd = Command::new(fd_cmd);
|
||||||
cmd.arg("--search-path")
|
cmd.arg("--search-path").arg(dir)
|
||||||
.arg(dir)
|
.arg("--glob").arg(pattern)
|
||||||
.arg("--glob")
|
.arg("--color").arg("never")
|
||||||
.arg(pattern)
|
|
||||||
.arg("--color")
|
|
||||||
.arg("never")
|
|
||||||
.arg("--strip-cwd-prefix")
|
.arg("--strip-cwd-prefix")
|
||||||
.arg("--max-results")
|
.arg("--max-results").arg(max_results.to_string())
|
||||||
.arg(max_results.to_string())
|
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped());
|
.stderr(Stdio::piped());
|
||||||
|
|
||||||
@ -199,7 +179,10 @@ impl FileSearchTool {
|
|||||||
cmd.arg("--ignore-case");
|
cmd.arg("--ignore-case");
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
let output = timeout(
|
||||||
|
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||||
|
cmd.output(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
@ -209,8 +192,7 @@ impl FileSearchTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text
|
let lines: Vec<String> = text.lines()
|
||||||
.lines()
|
|
||||||
.filter(|l| !l.is_empty())
|
.filter(|l| !l.is_empty())
|
||||||
.map(|l| l.to_string())
|
.map(|l| l.to_string())
|
||||||
.collect();
|
.collect();
|
||||||
@ -233,13 +215,15 @@ impl FileSearchTool {
|
|||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::null());
|
.stderr(Stdio::null());
|
||||||
|
|
||||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
let output = timeout(
|
||||||
|
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||||
|
cmd.output(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let mut lines: Vec<String> = text
|
let mut lines: Vec<String> = text.lines()
|
||||||
.lines()
|
|
||||||
.filter(|l| !l.is_empty())
|
.filter(|l| !l.is_empty())
|
||||||
.map(|l| {
|
.map(|l| {
|
||||||
let p = Path::new(l);
|
let p = Path::new(l);
|
||||||
@ -270,13 +254,7 @@ impl FileSearchTool {
|
|||||||
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
|
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
walk_dir(
|
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?;
|
||||||
Path::new(dir),
|
|
||||||
Path::new(dir),
|
|
||||||
&re,
|
|
||||||
&mut results,
|
|
||||||
max_results,
|
|
||||||
)?;
|
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -333,16 +311,13 @@ fn walk_dir(
|
|||||||
|
|
||||||
if path.is_dir() {
|
if path.is_dir() {
|
||||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& name.starts_with('.')
|
&& name.starts_with('.') && name.len() > 1 {
|
||||||
&& name.len() > 1
|
|
||||||
{
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
walk_dir(base, &path, re, results, max)?;
|
walk_dir(base, &path, re, results, max)?;
|
||||||
} else if path.is_file() {
|
} else if path.is_file() {
|
||||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& re.is_match(name)
|
&& re.is_match(name) {
|
||||||
{
|
|
||||||
results.push(rel.to_string_lossy().to_string());
|
results.push(rel.to_string_lossy().to_string());
|
||||||
}
|
}
|
||||||
if results.len() >= max {
|
if results.len() >= max {
|
||||||
|
|||||||
@ -90,8 +90,7 @@ impl Tool for FileWriteTool {
|
|||||||
// Create parent directories if needed
|
// Create parent directories if needed
|
||||||
if let Some(parent) = resolved.parent()
|
if let Some(parent) = resolved.parent()
|
||||||
&& !parent.exists()
|
&& !parent.exists()
|
||||||
&& let Err(e) = std::fs::create_dir_all(parent)
|
&& let Err(e) = std::fs::create_dir_all(parent) {
|
||||||
{
|
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
@ -169,7 +168,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_write_missing_path() {
|
async fn test_write_missing_path() {
|
||||||
let tool = FileWriteTool::new();
|
let tool = FileWriteTool::new();
|
||||||
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({ "content": "Hello" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("path"));
|
assert!(result.error.unwrap().contains("path"));
|
||||||
|
|||||||
@ -129,9 +129,7 @@ impl GetSkillTool {
|
|||||||
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
|
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
|
||||||
for s in &skills {
|
for s in &skills {
|
||||||
let always_mark = if s.always { " [常驻]" } else { "" };
|
let always_mark = if s.always { " [常驻]" } else { "" };
|
||||||
let path_str = s
|
let path_str = s.path.as_ref()
|
||||||
.path
|
|
||||||
.as_ref()
|
|
||||||
.map(|p| p.to_string_lossy().to_string())
|
.map(|p| p.to_string_lossy().to_string())
|
||||||
.unwrap_or_else(|| "—".to_string());
|
.unwrap_or_else(|| "—".to_string());
|
||||||
output.push_str(&format!(
|
output.push_str(&format!(
|
||||||
@ -150,10 +148,10 @@ impl GetSkillTool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use tempfile::tempdir;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use tempfile::tempdir;
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_existing_skill() {
|
async fn test_get_existing_skill() {
|
||||||
|
|||||||
@ -50,7 +50,10 @@ impl HttpRequestTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||||
return Err(format!("Host '{}' is not in allowed_domains", host));
|
return Err(format!(
|
||||||
|
"Host '{}' is not in allowed_domains",
|
||||||
|
host
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(url.to_string())
|
Ok(url.to_string())
|
||||||
@ -77,7 +80,8 @@ impl HttpRequestTool {
|
|||||||
for (key, value) in obj {
|
for (key, value) in obj {
|
||||||
if let Some(str_val) = value.as_str()
|
if let Some(str_val) = value.as_str()
|
||||||
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
|
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
|
||||||
&& let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
|
&& let Ok(val) =
|
||||||
|
reqwest::header::HeaderValue::from_str(str_val)
|
||||||
{
|
{
|
||||||
header_map.insert(name, val);
|
header_map.insert(name, val);
|
||||||
}
|
}
|
||||||
@ -187,9 +191,7 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
|||||||
|
|
||||||
allowed_domains.iter().any(|domain| {
|
allowed_domains.iter().any(|domain| {
|
||||||
host == domain
|
host == domain
|
||||||
|| host
|
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
||||||
.strip_suffix(domain)
|
|
||||||
.is_some_and(|prefix| prefix.ends_with('.'))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,11 +202,7 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check .local TLD
|
// Check .local TLD
|
||||||
if host
|
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||||
.rsplit('.')
|
|
||||||
.next()
|
|
||||||
.is_some_and(|label| label == "local")
|
|
||||||
{
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,7 +224,9 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
|||||||
|| v4.is_broadcast()
|
|| v4.is_broadcast()
|
||||||
|| v4.is_multicast()
|
|| v4.is_multicast()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,7 +278,10 @@ impl Tool for HttpRequestTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
let method_str = args
|
||||||
|
.get("method")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("GET");
|
||||||
|
|
||||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||||
let body = args.get("body").and_then(|v| v.as_str());
|
let body = args.get("body").and_then(|v| v.as_str());
|
||||||
|
|||||||
@ -151,19 +151,10 @@ impl Tool for MemoryRecallTool {
|
|||||||
.and_then(|v| v.as_i64())
|
.and_then(|v| v.as_i64())
|
||||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
self.memory
|
self.memory
|
||||||
.recall_by_time(
|
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
|
||||||
since,
|
|
||||||
until,
|
|
||||||
Some(query),
|
|
||||||
limit,
|
|
||||||
Some(MemoryCategory::Knowledge),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await?
|
.await?
|
||||||
} else {
|
} else {
|
||||||
self.memory
|
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
|
||||||
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
|
|
||||||
.await?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
@ -177,11 +168,7 @@ impl Tool for MemoryRecallTool {
|
|||||||
let formatted = entries
|
let formatted = entries
|
||||||
.iter()
|
.iter()
|
||||||
.map(|e| {
|
.map(|e| {
|
||||||
let session = e
|
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||||
.session_id
|
|
||||||
.as_deref()
|
|
||||||
.map(|s| format!(" [session: {}]", s))
|
|
||||||
.unwrap_or_default();
|
|
||||||
format!(
|
format!(
|
||||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||||
e.key,
|
e.key,
|
||||||
@ -277,19 +264,10 @@ impl Tool for TimelineRecallTool {
|
|||||||
.and_then(|v| v.as_i64())
|
.and_then(|v| v.as_i64())
|
||||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
self.memory
|
self.memory
|
||||||
.recall_by_time(
|
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
|
||||||
since,
|
|
||||||
until,
|
|
||||||
Some(query),
|
|
||||||
limit,
|
|
||||||
Some(MemoryCategory::Timeline),
|
|
||||||
session_id,
|
|
||||||
)
|
|
||||||
.await?
|
.await?
|
||||||
} else {
|
} else {
|
||||||
self.memory
|
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
|
||||||
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
|
|
||||||
.await?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
@ -303,11 +281,7 @@ impl Tool for TimelineRecallTool {
|
|||||||
let formatted = entries
|
let formatted = entries
|
||||||
.iter()
|
.iter()
|
||||||
.map(|e| {
|
.map(|e| {
|
||||||
let session = e
|
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||||
.session_id
|
|
||||||
.as_deref()
|
|
||||||
.map(|s| format!(" [session: {}]", s))
|
|
||||||
.unwrap_or_default();
|
|
||||||
format!(
|
format!(
|
||||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||||
e.key,
|
e.key,
|
||||||
|
|||||||
@ -4,7 +4,6 @@ pub mod calculator;
|
|||||||
pub mod chat_manager;
|
pub mod chat_manager;
|
||||||
pub mod content_search;
|
pub mod content_search;
|
||||||
pub mod cron;
|
pub mod cron;
|
||||||
pub mod delegate;
|
|
||||||
pub mod file_edit;
|
pub mod file_edit;
|
||||||
pub mod file_read;
|
pub mod file_read;
|
||||||
pub mod file_search;
|
pub mod file_search;
|
||||||
@ -24,7 +23,6 @@ pub use browser::BrowserTool;
|
|||||||
pub use calculator::CalculatorTool;
|
pub use calculator::CalculatorTool;
|
||||||
pub use chat_manager::ChatManagerTool;
|
pub use chat_manager::ChatManagerTool;
|
||||||
pub use content_search::ContentSearchTool;
|
pub use content_search::ContentSearchTool;
|
||||||
pub use delegate::DelegateTool;
|
|
||||||
pub use file_edit::FileEditTool;
|
pub use file_edit::FileEditTool;
|
||||||
pub use file_read::FileReadTool;
|
pub use file_read::FileReadTool;
|
||||||
pub use file_search::FileSearchTool;
|
pub use file_search::FileSearchTool;
|
||||||
@ -37,11 +35,10 @@ pub use send_message::SendMessageTool;
|
|||||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|
||||||
use crate::agent::SubAgentManager;
|
use std::sync::Arc;
|
||||||
use crate::config::BrowserConfig;
|
use crate::config::BrowserConfig;
|
||||||
use crate::memory::MemoryManager;
|
use crate::memory::MemoryManager;
|
||||||
use crate::skills::SkillsLoader;
|
use crate::skills::SkillsLoader;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Create the base tool registry (without send_message).
|
/// Create the base tool registry (without send_message).
|
||||||
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
||||||
@ -49,7 +46,6 @@ use std::sync::Arc;
|
|||||||
pub fn create_default_tools(
|
pub fn create_default_tools(
|
||||||
skills_loader: Arc<SkillsLoader>,
|
skills_loader: Arc<SkillsLoader>,
|
||||||
memory: Arc<MemoryManager>,
|
memory: Arc<MemoryManager>,
|
||||||
sub_agent_manager: Option<Arc<SubAgentManager>>,
|
|
||||||
browser_config: Option<&BrowserConfig>,
|
browser_config: Option<&BrowserConfig>,
|
||||||
) -> ToolRegistry {
|
) -> ToolRegistry {
|
||||||
let registry = ToolRegistry::new();
|
let registry = ToolRegistry::new();
|
||||||
@ -80,9 +76,5 @@ pub fn create_default_tools(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(mgr) = sub_agent_manager {
|
|
||||||
registry.register(DelegateTool::new(mgr));
|
|
||||||
}
|
|
||||||
|
|
||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
|
|||||||
608
src/tools/pty.rs
608
src/tools/pty.rs
@ -1,608 +0,0 @@
|
|||||||
use std::collections::{HashMap, VecDeque};
|
|
||||||
use std::io::Write;
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
|
||||||
|
|
||||||
const MAX_OUTPUT_LINES: usize = 50_000;
|
|
||||||
const MAX_CHARS_PER_LINE: usize = 2_000;
|
|
||||||
const MAX_SESSIONS: usize = 10;
|
|
||||||
|
|
||||||
fn guard_command(command: &str) -> Option<String> {
|
|
||||||
let deny_patterns: &[&str] = &[
|
|
||||||
r"\brm\s+-[rf]{1,2}\b",
|
|
||||||
r"\bdel\s+/[fq]\b",
|
|
||||||
r"\brmdir\s+/s\b",
|
|
||||||
r":\(\)\s*\{.*\};\s*:",
|
|
||||||
];
|
|
||||||
let lower = command.to_lowercase();
|
|
||||||
for pattern in deny_patterns {
|
|
||||||
if regex::Regex::new(pattern)
|
|
||||||
.ok()
|
|
||||||
.map(|re| re.is_match(&lower))
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
return Some(format!(
|
|
||||||
"Command blocked by safety guard (dangerous pattern: {})",
|
|
||||||
pattern
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn unescape_control_chars(s: &str) -> String {
|
|
||||||
let mut result = String::with_capacity(s.len());
|
|
||||||
let mut chars = s.chars().peekable();
|
|
||||||
while let Some(c) = chars.next() {
|
|
||||||
if c == '\\' {
|
|
||||||
match chars.next() {
|
|
||||||
Some('n') => result.push('\n'),
|
|
||||||
Some('r') => result.push('\r'),
|
|
||||||
Some('t') => result.push('\t'),
|
|
||||||
Some('x') => {
|
|
||||||
let hex: String = chars.by_ref().take(2).collect();
|
|
||||||
if hex.len() == 2 {
|
|
||||||
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
|
|
||||||
result.push(byte as char);
|
|
||||||
} else {
|
|
||||||
result.push_str(&format!("\\x{}", hex));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result.push_str(&format!("\\x{}", hex));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some('\\') => result.push('\\'),
|
|
||||||
Some(other) => {
|
|
||||||
result.push('\\');
|
|
||||||
result.push(other);
|
|
||||||
}
|
|
||||||
None => result.push('\\'),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result.push(c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
fn truncate_command(cmd: &str, max_len: usize) -> String {
|
|
||||||
let cmd = cmd.trim();
|
|
||||||
let first_arg = cmd.split_whitespace().next().unwrap_or(cmd);
|
|
||||||
if first_arg.len() > max_len {
|
|
||||||
format!("{}...", &first_arg[..max_len])
|
|
||||||
} else {
|
|
||||||
first_arg.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn status_str(status: &SessionStatus) -> &str {
|
|
||||||
match status {
|
|
||||||
SessionStatus::Running => "running",
|
|
||||||
SessionStatus::Exited(_) => "exited",
|
|
||||||
SessionStatus::Killed => "killed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
enum SessionStatus {
|
|
||||||
Running,
|
|
||||||
Exited(i32),
|
|
||||||
Killed,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct PtySession {
|
|
||||||
#[allow(dead_code)]
|
|
||||||
id: String,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
command: String,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
started_at: Instant,
|
|
||||||
status: SessionStatus,
|
|
||||||
child: Arc<Mutex<Option<Box<dyn portable_pty::Child + Send + Sync>>>>,
|
|
||||||
writer: Arc<Mutex<Option<Box<dyn Write + Send>>>>,
|
|
||||||
output_buffer: VecDeque<String>,
|
|
||||||
output_total_lines: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PtySession {
|
|
||||||
fn new(
|
|
||||||
id: String,
|
|
||||||
command: String,
|
|
||||||
child: Box<dyn portable_pty::Child + Send + Sync>,
|
|
||||||
writer: Box<dyn Write + Send>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
id,
|
|
||||||
command,
|
|
||||||
started_at: Instant::now(),
|
|
||||||
status: SessionStatus::Running,
|
|
||||||
child: Arc::new(Mutex::new(Some(child))),
|
|
||||||
writer: Arc::new(Mutex::new(Some(writer))),
|
|
||||||
output_buffer: VecDeque::new(),
|
|
||||||
output_total_lines: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn push_line(&mut self, line: String) {
|
|
||||||
let line = if line.len() > MAX_CHARS_PER_LINE {
|
|
||||||
format!("{}...<truncated>", &line[..MAX_CHARS_PER_LINE])
|
|
||||||
} else {
|
|
||||||
line
|
|
||||||
};
|
|
||||||
self.output_total_lines += 1;
|
|
||||||
if self.output_buffer.len() >= MAX_OUTPUT_LINES {
|
|
||||||
self.output_buffer.pop_front();
|
|
||||||
}
|
|
||||||
self.output_buffer.push_back(line);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct PtyManager {
|
|
||||||
sessions: Mutex<HashMap<String, Arc<Mutex<PtySession>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PtyManager {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
sessions: Mutex::new(HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn cleanup_all(&self) {
|
|
||||||
let sessions: Vec<Arc<Mutex<PtySession>>> = {
|
|
||||||
let mut guard = self.sessions.lock().unwrap();
|
|
||||||
guard.drain().map(|(_, s)| s).collect()
|
|
||||||
};
|
|
||||||
for session in sessions {
|
|
||||||
let mut guard = session.lock().unwrap();
|
|
||||||
let mut child_guard = guard.child.lock().unwrap();
|
|
||||||
if let Some(ref mut child) = *child_guard {
|
|
||||||
let _ = child.kill();
|
|
||||||
}
|
|
||||||
*child_guard = None;
|
|
||||||
guard.status = SessionStatus::Killed;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn spawn(&self, command: &str) -> Result<String, String> {
|
|
||||||
if let Some(reason) = guard_command(command) {
|
|
||||||
return Err(reason);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut sessions = self.sessions.lock().unwrap();
|
|
||||||
if sessions.len() >= MAX_SESSIONS {
|
|
||||||
return Err(format!(
|
|
||||||
"Max sessions ({}) reached, kill some sessions first",
|
|
||||||
MAX_SESSIONS
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let session_id = format!("pty_{}", crate::util::short_id());
|
|
||||||
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
|
|
||||||
|
|
||||||
let pty_system = portable_pty::native_pty_system();
|
|
||||||
let pty_pair = pty_system
|
|
||||||
.openpty(portable_pty::PtySize {
|
|
||||||
rows: 24,
|
|
||||||
cols: 80,
|
|
||||||
pixel_width: 0,
|
|
||||||
pixel_height: 0,
|
|
||||||
})
|
|
||||||
.map_err(|e| format!("Failed to open PTY: {}", e))?;
|
|
||||||
|
|
||||||
let mut cmd = portable_pty::CommandBuilder::new("bash");
|
|
||||||
cmd.args(&["-c", command]);
|
|
||||||
cmd.cwd(cwd);
|
|
||||||
|
|
||||||
let child = pty_pair
|
|
||||||
.slave
|
|
||||||
.spawn_command(cmd)
|
|
||||||
.map_err(|e| format!("Failed to spawn: {}", e))?;
|
|
||||||
|
|
||||||
let pid = child.process_id().unwrap_or(0);
|
|
||||||
|
|
||||||
let writer = pty_pair
|
|
||||||
.master
|
|
||||||
.take_writer()
|
|
||||||
.map_err(|e| format!("Failed to take writer: {}", e))?;
|
|
||||||
|
|
||||||
let reader = pty_pair
|
|
||||||
.master
|
|
||||||
.try_clone_reader()
|
|
||||||
.map_err(|e| format!("Failed to clone reader: {}", e))?;
|
|
||||||
|
|
||||||
let session_id_clone = session_id.clone();
|
|
||||||
let session = PtySession::new(session_id_clone, command.to_string(), child, writer);
|
|
||||||
let session = Arc::new(Mutex::new(session));
|
|
||||||
sessions.insert(session_id.clone(), session.clone());
|
|
||||||
|
|
||||||
let session_for_reader = session.clone();
|
|
||||||
let child_for_reader = session_for_reader.lock().unwrap().child.clone();
|
|
||||||
tokio::task::spawn_blocking(move || {
|
|
||||||
let mut reader = reader;
|
|
||||||
let mut buf = [0u8; 4096];
|
|
||||||
let mut partial = String::new();
|
|
||||||
loop {
|
|
||||||
use std::io::Read;
|
|
||||||
match reader.read(&mut buf) {
|
|
||||||
Ok(0) => break,
|
|
||||||
Ok(n) => {
|
|
||||||
let s = String::from_utf8_lossy(&buf[..n]);
|
|
||||||
partial.push_str(&s);
|
|
||||||
let lines: Vec<&str> = partial.split('\n').collect();
|
|
||||||
if lines.len() <= 1 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let complete = lines.len() - 1;
|
|
||||||
let mut guard = session_for_reader.lock().unwrap();
|
|
||||||
for line in lines[..complete].iter() {
|
|
||||||
guard.push_line(line.to_string());
|
|
||||||
}
|
|
||||||
partial = lines[complete].to_string();
|
|
||||||
}
|
|
||||||
Err(_) => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mut guard = session_for_reader.lock().unwrap();
|
|
||||||
if !partial.is_empty() {
|
|
||||||
guard.push_line(partial);
|
|
||||||
}
|
|
||||||
let exit_code = {
|
|
||||||
let mut cg = child_for_reader.lock().unwrap();
|
|
||||||
if let Some(ref mut c) = *cg {
|
|
||||||
c.wait().map(|s| s.exit_code() as i32).ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
};
|
|
||||||
guard.status = SessionStatus::Exited(exit_code.unwrap_or(-1));
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(format!("session_id: {}, pid: {}", session_id, pid))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write(&self, session_id: &str, data: &str) -> Result<String, String> {
|
|
||||||
let unescaped = unescape_control_chars(data);
|
|
||||||
let byte_count = unescaped.len();
|
|
||||||
|
|
||||||
let sessions = self.sessions.lock().unwrap();
|
|
||||||
let session = sessions
|
|
||||||
.get(session_id)
|
|
||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
|
||||||
let mut guard = session.lock().unwrap();
|
|
||||||
if guard.status != SessionStatus::Running {
|
|
||||||
return Err("Session is not running".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let writer = guard.writer.clone();
|
|
||||||
drop(guard);
|
|
||||||
drop(sessions);
|
|
||||||
|
|
||||||
let mut writer_guard = writer.lock().unwrap();
|
|
||||||
match *writer_guard {
|
|
||||||
Some(ref mut w) => {
|
|
||||||
w.write_all(unescaped.as_bytes())
|
|
||||||
.map_err(|e| format!("Write error: {}", e))?;
|
|
||||||
w.flush().map_err(|e| format!("Flush error: {}", e))?;
|
|
||||||
}
|
|
||||||
None => return Err("Writer not available".to_string()),
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(format!("OK, wrote {} bytes", byte_count))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read(
|
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
offset: usize,
|
|
||||||
limit: usize,
|
|
||||||
) -> Result<String, String> {
|
|
||||||
let sessions = self.sessions.lock().unwrap();
|
|
||||||
let session = sessions
|
|
||||||
.get(session_id)
|
|
||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
|
||||||
let guard = session.lock().unwrap();
|
|
||||||
|
|
||||||
let total = guard.output_total_lines;
|
|
||||||
let buffer_len = guard.output_buffer.len();
|
|
||||||
let start = 0_usize.max(offset);
|
|
||||||
let skip_old = total.saturating_sub(buffer_len);
|
|
||||||
let view_start = start.saturating_sub(skip_old);
|
|
||||||
|
|
||||||
let lines: Vec<String> = guard
|
|
||||||
.output_buffer
|
|
||||||
.iter()
|
|
||||||
.skip(view_start)
|
|
||||||
.take(limit)
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, line)| format!("{}: {}", start + i, line))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let displayed = start + lines.len();
|
|
||||||
let has_more = displayed < total;
|
|
||||||
|
|
||||||
let mut output = format!(
|
|
||||||
"# Lines {}-{} (共 {} 行{})\n",
|
|
||||||
start,
|
|
||||||
displayed.saturating_sub(1),
|
|
||||||
total,
|
|
||||||
if has_more { ",还有更多" } else { "" }
|
|
||||||
);
|
|
||||||
output.push_str(&lines.join("\n"));
|
|
||||||
if has_more {
|
|
||||||
output.push_str(&format!(
|
|
||||||
"\n[还有 {} 行未显示,用 offset={} 继续读取]",
|
|
||||||
total.saturating_sub(displayed),
|
|
||||||
displayed
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn kill(&self, session_id: &str) -> Result<String, String> {
|
|
||||||
let mut sessions = self.sessions.lock().unwrap();
|
|
||||||
let session = sessions
|
|
||||||
.get(session_id)
|
|
||||||
.ok_or_else(|| format!("Session not found: {}", session_id))?;
|
|
||||||
let mut guard = session.lock().unwrap();
|
|
||||||
|
|
||||||
let mut child_guard = guard.child.lock().unwrap();
|
|
||||||
if let Some(ref mut child) = *child_guard {
|
|
||||||
let _ = child.kill();
|
|
||||||
let _ = child.wait();
|
|
||||||
}
|
|
||||||
*child_guard = None;
|
|
||||||
guard.status = SessionStatus::Killed;
|
|
||||||
drop(child_guard);
|
|
||||||
drop(guard);
|
|
||||||
sessions.remove(session_id);
|
|
||||||
|
|
||||||
Ok(format!("Session {} killed", session_id))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list(&self) -> String {
|
|
||||||
let sessions = self.sessions.lock().unwrap();
|
|
||||||
if sessions.is_empty() {
|
|
||||||
return "No active PTY sessions".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut lines: Vec<String> = sessions
|
|
||||||
.iter()
|
|
||||||
.map(|(id, session)| {
|
|
||||||
let guard = session.lock().unwrap();
|
|
||||||
let age = Instant::now().duration_since(guard.started_at);
|
|
||||||
let age_str = if age.as_secs() < 60 {
|
|
||||||
format!("{}s ago", age.as_secs())
|
|
||||||
} else {
|
|
||||||
format!("{}m ago", age.as_secs() / 60)
|
|
||||||
};
|
|
||||||
format!(
|
|
||||||
"{:<14} {:<12} {:<10} {:<6} lines {}",
|
|
||||||
id,
|
|
||||||
truncate_command(&guard.command, 10),
|
|
||||||
status_str(&guard.status),
|
|
||||||
guard.output_total_lines,
|
|
||||||
age_str,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
lines.sort();
|
|
||||||
lines.join("\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for PtyManager {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.cleanup_all();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct PtyTool {
|
|
||||||
pty_manager: Arc<PtyManager>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PtyTool {
|
|
||||||
pub fn new(pty_manager: Arc<PtyManager>) -> Self {
|
|
||||||
Self { pty_manager }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for PtyTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"pty"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"管理持久伪终端(PTY)会话。用于交互式程序、长运行服务、多步骤命令等需要保持终端状态的场景。支持操作: spawn(创建)/write(写入输入)/read(读取输出)/kill(终止)/list(列出所有会话)。"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"action": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["spawn", "write", "read", "kill", "list"],
|
|
||||||
"description": "操作类型: spawn=创建会话, write=写入输入, read=读取输出, kill=终止会话, list=列出所有会话"
|
|
||||||
},
|
|
||||||
"session_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "会话ID (write/read/kill 需要)"
|
|
||||||
},
|
|
||||||
"command": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "要执行的命令 (spawn 需要)"
|
|
||||||
},
|
|
||||||
"data": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "写入终端的数据,支持转义序列: \\n(换行) \\x03(Ctrl+C) \\x04(Ctrl+D) \\x1a(Ctrl+Z) (write 需要)"
|
|
||||||
},
|
|
||||||
"offset": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "输出读取起始行号,从 0 开始 (read 可选,默认 0)",
|
|
||||||
"minimum": 0
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "读取的最大行数 (read 可选,默认 500)",
|
|
||||||
"minimum": 1
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["action"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn exclusive(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
|
||||||
Some(a) => a,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: action".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match action {
|
|
||||||
"spawn" => {
|
|
||||||
let command = match args.get("command").and_then(|v| v.as_str()) {
|
|
||||||
Some(c) => c,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: command".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
match self.pty_manager.spawn(command) {
|
|
||||||
Ok(output) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Err(e) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(e),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"write" => {
|
|
||||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
|
||||||
Some(id) => id,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: session_id".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let data = match args.get("data").and_then(|v| v.as_str()) {
|
|
||||||
Some(d) => d,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: data".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
match self.pty_manager.write(session_id, data) {
|
|
||||||
Ok(output) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Err(e) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(e),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"read" => {
|
|
||||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
|
||||||
Some(id) => id,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: session_id".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let offset = args
|
|
||||||
.get("offset")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0) as usize;
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(500) as usize;
|
|
||||||
match self.pty_manager.read(session_id, offset, limit) {
|
|
||||||
Ok(output) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Err(e) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(e),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"kill" => {
|
|
||||||
let session_id = match args.get("session_id").and_then(|v| v.as_str()) {
|
|
||||||
Some(id) => id,
|
|
||||||
None => {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("Missing required parameter: session_id".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
match self.pty_manager.kill(session_id) {
|
|
||||||
Ok(output) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Err(e) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(e),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"list" => {
|
|
||||||
let output = self.pty_manager.list();
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output,
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
_ => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(format!("Unknown action: {}", action)),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -17,15 +17,7 @@ impl ToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
||||||
self.tools
|
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(tool.name().to_string(), Arc::new(tool));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register an existing Arc-wrapped tool by name
|
|
||||||
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
|
|
||||||
self.tools.lock().unwrap().insert(name, tool);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
|
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
|
||||||
@ -70,17 +62,6 @@ impl ToolRegistry {
|
|||||||
.map(|(k, v)| (k.clone(), v.clone()))
|
.map(|(k, v)| (k.clone(), v.clone()))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 生成工具列表描述,用于子 Agent 系统提示词
|
|
||||||
pub fn describe_for_prompt(&self) -> String {
|
|
||||||
let mut entries: Vec<String> = self
|
|
||||||
.iter()
|
|
||||||
.into_iter()
|
|
||||||
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
|
|
||||||
.collect();
|
|
||||||
entries.sort();
|
|
||||||
entries.join("\n")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ToolRegistry {
|
impl Default for ToolRegistry {
|
||||||
|
|||||||
@ -115,9 +115,7 @@ impl SchemaCleanr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(Value::String(t)) = obj.get("type")
|
if let Some(Value::String(t)) = obj.get("type")
|
||||||
&& t == "object"
|
&& t == "object" && !obj.contains_key("properties") {
|
||||||
&& !obj.contains_key("properties")
|
|
||||||
{
|
|
||||||
tracing::warn!("Object schema without 'properties' field may cause issues");
|
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,8 +173,7 @@ impl SchemaCleanr {
|
|||||||
|
|
||||||
// Handle anyOf/oneOf simplification
|
// Handle anyOf/oneOf simplification
|
||||||
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
|
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
|
||||||
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
|
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
||||||
{
|
|
||||||
return simplified;
|
return simplified;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -246,8 +243,7 @@ impl SchemaCleanr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(def_name) = Self::parse_local_ref(ref_value)
|
if let Some(def_name) = Self::parse_local_ref(ref_value)
|
||||||
&& let Some(definition) = defs.get(def_name.as_str())
|
&& let Some(definition) = defs.get(def_name.as_str()) {
|
||||||
{
|
|
||||||
ref_stack.insert(ref_value.to_string());
|
ref_stack.insert(ref_value.to_string());
|
||||||
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||||
ref_stack.remove(ref_value);
|
ref_stack.remove(ref_value);
|
||||||
@ -344,14 +340,11 @@ impl SchemaCleanr {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if let Some(Value::Array(arr)) = obj.get("enum")
|
if let Some(Value::Array(arr)) = obj.get("enum")
|
||||||
&& arr.len() == 1
|
&& arr.len() == 1 && matches!(arr[0], Value::Null) {
|
||||||
&& matches!(arr[0], Value::Null)
|
|
||||||
{
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if let Some(Value::String(t)) = obj.get("type")
|
if let Some(Value::String(t)) = obj.get("type")
|
||||||
&& t == "null"
|
&& t == "null" {
|
||||||
{
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -410,10 +403,7 @@ impl SchemaCleanr {
|
|||||||
|
|
||||||
match non_null.len() {
|
match non_null.len() {
|
||||||
0 => Value::String("null".to_string()),
|
0 => Value::String("null".to_string()),
|
||||||
1 => non_null
|
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.unwrap_or(Value::String("null".to_string())),
|
|
||||||
_ => Value::Array(non_null),
|
_ => Value::Array(non_null),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use std::collections::HashSet;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mime_guess::mime;
|
use mime_guess::mime;
|
||||||
@ -31,20 +31,14 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
|
|||||||
match parts.len() {
|
match parts.len() {
|
||||||
2 => {
|
2 => {
|
||||||
if parts[0].is_empty() || parts[1].is_empty() {
|
if parts[0].is_empty() || parts[1].is_empty() {
|
||||||
Err(format!(
|
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
|
||||||
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
|
|
||||||
raw
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
Ok((parts[0], parts[1], None))
|
Ok((parts[0], parts[1], None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
3 => {
|
3 => {
|
||||||
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
||||||
Err(format!(
|
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
|
||||||
"Invalid target_chat_id format '{}': all three parts must not be empty",
|
|
||||||
raw
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
Ok((parts[0], parts[1], Some(parts[2])))
|
Ok((parts[0], parts[1], Some(parts[2])))
|
||||||
}
|
}
|
||||||
@ -104,8 +98,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
|||||||
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
||||||
|
|
||||||
// 1. Parse target_chat_id
|
// 1. Parse target_chat_id
|
||||||
let (channel, chat_id, dialog_id) =
|
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
|
||||||
parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
|
.map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
|
||||||
// 2. Validate channel
|
// 2. Validate channel
|
||||||
if !self.available_channels.contains(channel) {
|
if !self.available_channels.contains(channel) {
|
||||||
@ -115,11 +109,7 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
|||||||
error: Some(format!(
|
error: Some(format!(
|
||||||
"Channel '{}' is not available. Available channels: {}",
|
"Channel '{}' is not available. Available channels: {}",
|
||||||
channel,
|
channel,
|
||||||
self.available_channels
|
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||||
.iter()
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ")
|
|
||||||
)),
|
)),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -139,8 +129,7 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
|||||||
let media = parse_files_arg(&args);
|
let media = parse_files_arg(&args);
|
||||||
|
|
||||||
// 4. Send via messenger
|
// 4. Send via messenger
|
||||||
match self
|
match self.messenger
|
||||||
.messenger
|
|
||||||
.send_message(channel, chat_id, dialog_id, content, source, media)
|
.send_message(channel, chat_id, dialog_id, content, source, media)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use crate::bus::{MediaItem, MessageSource};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use crate::bus::{MediaItem, MessageSource};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
|
|||||||
@ -239,11 +239,7 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if host
|
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||||
.rsplit('.')
|
|
||||||
.next()
|
|
||||||
.is_some_and(|label| label == "local")
|
|
||||||
{
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -252,9 +248,7 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
std::net::IpAddr::V4(v4) => {
|
std::net::IpAddr::V4(v4) => {
|
||||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => {
|
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use picobot::config::{Config, LLMProviderConfig};
|
|
||||||
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
||||||
|
use picobot::config::{Config, LLMProviderConfig};
|
||||||
|
|
||||||
fn load_config() -> Option<LLMProviderConfig> {
|
fn load_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -42,7 +42,8 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_simple_completion() {
|
async fn test_openai_simple_completion() {
|
||||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||||
@ -56,7 +57,8 @@ async fn test_openai_simple_completion() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_conversation() {
|
async fn test_openai_conversation() {
|
||||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -80,9 +82,7 @@ async fn test_openai_conversation() {
|
|||||||
async fn test_config_load() {
|
async fn test_config_load() {
|
||||||
// Test that config.json can be loaded and provider config created
|
// Test that config.json can be loaded and provider config created
|
||||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||||
let provider_config = config
|
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||||
.get_provider_config("default")
|
|
||||||
.expect("Failed to get provider config");
|
|
||||||
|
|
||||||
assert_eq!(provider_config.provider_type, "openai");
|
assert_eq!(provider_config.provider_type, "openai");
|
||||||
assert_eq!(provider_config.name, "aliyun");
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
|
||||||
use picobot::providers::{ChatCompletionRequest, Message};
|
use picobot::providers::{ChatCompletionRequest, Message};
|
||||||
|
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||||
|
|
||||||
/// Test that message with special characters is properly escaped
|
/// Test that message with special characters is properly escaped
|
||||||
#[test]
|
#[test]
|
||||||
@ -19,9 +19,7 @@ fn test_message_special_characters() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_multiline_system_prompt() {
|
fn test_multiline_system_prompt() {
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message::system(
|
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
|
||||||
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
|
|
||||||
),
|
|
||||||
Message::user("Hi"),
|
Message::user("Hi"),
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -35,7 +33,10 @@ fn test_multiline_system_prompt() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_chat_request_serialization() {
|
fn test_chat_request_serialization() {
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
|
messages: vec![
|
||||||
|
Message::system("You are helpful"),
|
||||||
|
Message::user("Hello"),
|
||||||
|
],
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
tools: None,
|
tools: None,
|
||||||
|
|||||||
@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
|
|||||||
/// Verify that next_run_for_schedule produces valid future timestamps.
|
/// Verify that next_run_for_schedule produces valid future timestamps.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_run_always_future() {
|
fn test_next_run_always_future() {
|
||||||
use picobot::scheduler::{Schedule, next_run_for_schedule};
|
use picobot::scheduler::{next_run_for_schedule, Schedule};
|
||||||
|
|
||||||
let now = 1700000000000_i64;
|
let now = 1700000000000_i64;
|
||||||
|
|
||||||
@ -56,10 +56,6 @@ fn test_next_run_always_future() {
|
|||||||
for s in &schedules {
|
for s in &schedules {
|
||||||
let next = next_run_for_schedule(s, now);
|
let next = next_run_for_schedule(s, now);
|
||||||
assert!(next.is_some(), "expected next run for {:?}", s);
|
assert!(next.is_some(), "expected next run for {:?}", s);
|
||||||
assert!(
|
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
|
||||||
next.unwrap() > now,
|
|
||||||
"next run should be after now for {:?}",
|
|
||||||
s
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use picobot::config::LLMProviderConfig;
|
|
||||||
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
||||||
|
use picobot::config::LLMProviderConfig;
|
||||||
|
|
||||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -53,7 +53,8 @@ fn make_weather_tool() -> Tool {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_tool_call() {
|
async fn test_openai_tool_call() {
|
||||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_openai_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -67,11 +68,7 @@ async fn test_openai_tool_call() {
|
|||||||
let response = provider.chat(request).await.unwrap();
|
let response = provider.chat(request).await.unwrap();
|
||||||
|
|
||||||
// Should have tool calls
|
// Should have tool calls
|
||||||
assert!(
|
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
||||||
!response.tool_calls.is_empty(),
|
|
||||||
"Expected tool call, got: {}",
|
|
||||||
response.content
|
|
||||||
);
|
|
||||||
|
|
||||||
let tool_call = &response.tool_calls[0];
|
let tool_call = &response.tool_calls[0];
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
@ -81,7 +78,8 @@ async fn test_openai_tool_call() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_tool_call_with_manual_execution() {
|
async fn test_openai_tool_call_with_manual_execution() {
|
||||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_openai_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -94,7 +92,8 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let response1 = provider.chat(request1).await.unwrap();
|
let response1 = provider.chat(request1).await.unwrap();
|
||||||
let tool_call = response1.tool_calls.first().expect("Expected tool call");
|
let tool_call = response1.tool_calls.first()
|
||||||
|
.expect("Expected tool call");
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
|
|
||||||
// Second request with tool result
|
// Second request with tool result
|
||||||
@ -117,7 +116,8 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_no_tool_when_not_provided() {
|
async fn test_openai_no_tool_when_not_provided() {
|
||||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
let config = load_openai_config()
|
||||||
|
.expect("Please configure tests/test.env with valid API keys");
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user