Compare commits
3 Commits
c971bc3639
...
eb0f6c0bc7
| Author | SHA1 | Date | |
|---|---|---|---|
| eb0f6c0bc7 | |||
| ef601107ac | |||
| 8bb32fa066 |
@ -27,3 +27,4 @@ mime_guess = "2.0"
|
|||||||
base64 = "0.22"
|
base64 = "0.22"
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
meval = "0.2"
|
meval = "0.2"
|
||||||
|
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||||
|
|||||||
300
PERSISTENCE.md
Normal file
300
PERSISTENCE.md
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
# PicoBot 持久化设计说明
|
||||||
|
|
||||||
|
本文档介绍 PicoBot 当前的会话持久化实现,目标读者是需要维护或集成该模块的技术人员。
|
||||||
|
|
||||||
|
## 1. 总览
|
||||||
|
|
||||||
|
PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据库文件:
|
||||||
|
|
||||||
|
- 默认路径:`~/.picobot/storage/sessions.db`
|
||||||
|
- 初始化入口:`SessionStore::new()`
|
||||||
|
- 核心实现:`src/storage/mod.rs`
|
||||||
|
|
||||||
|
数据库启动时会完成以下初始化:
|
||||||
|
|
||||||
|
- 打开 SQLite 连接
|
||||||
|
- 创建父目录
|
||||||
|
- 打开 WAL 模式
|
||||||
|
- 打开外键约束
|
||||||
|
- 自动建表和建索引
|
||||||
|
|
||||||
|
当前持久化只覆盖两类核心数据:
|
||||||
|
|
||||||
|
- `sessions`:会话元数据
|
||||||
|
- `messages`:会话内的消息流水
|
||||||
|
|
||||||
|
内存中的 `Session` 负责运行态处理,SQLite 负责跨进程、跨重启保留历史。整体设计是“内存缓存 + SQLite 事实来源”。
|
||||||
|
|
||||||
|
## 2. 持久化在系统中的位置
|
||||||
|
|
||||||
|
相关模块职责如下:
|
||||||
|
|
||||||
|
- `src/gateway/session.rs`
|
||||||
|
- 管理运行时 `Session`
|
||||||
|
- 在收到消息时确保持久化会话存在
|
||||||
|
- 首次访问某个 `chat_id` 时从数据库加载历史
|
||||||
|
- 在新消息产生后同时写入数据库和内存历史
|
||||||
|
- `src/storage/mod.rs`
|
||||||
|
- 封装 SQLite 访问
|
||||||
|
- 提供会话和消息的增删改查
|
||||||
|
- `src/bus/message.rs`
|
||||||
|
- 定义持久化消息结构 `ChatMessage`
|
||||||
|
- `src/providers/*`
|
||||||
|
- 将历史消息转换为不同 LLM provider 需要的格式
|
||||||
|
|
||||||
|
典型关系如下:
|
||||||
|
|
||||||
|
1. 网关收到用户消息。
|
||||||
|
2. `SessionManager` 定位到对应 channel 的运行时 `Session`。
|
||||||
|
3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。
|
||||||
|
4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。
|
||||||
|
5. 新的用户消息先写入 `messages`,再放入内存历史。
|
||||||
|
6. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`。
|
||||||
|
7. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
|
||||||
|
|
||||||
|
## 3. 会话标识规则
|
||||||
|
|
||||||
|
数据库中的会话主键并不总是随机 UUID,而是依据 channel 类型区分:
|
||||||
|
|
||||||
|
- CLI 会话:`session_id == chat_id`
|
||||||
|
- 非 CLI 会话:`session_id = "{channel_name}:{chat_id}"`
|
||||||
|
|
||||||
|
这套规则由 `persistent_session_id(channel_name, chat_id)` 统一生成,目的是:
|
||||||
|
|
||||||
|
- 对 CLI 支持显式创建、切换和管理多个会话
|
||||||
|
- 对外部渠道(例如飞书)让同一个 chat 稳定映射到同一条持久化会话
|
||||||
|
|
||||||
|
## 4. 表结构
|
||||||
|
|
||||||
|
### 4.1 sessions
|
||||||
|
|
||||||
|
保存会话级元数据,每条记录代表一个可被恢复的历史会话。
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 含义 | 当前用途 |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `id` | `TEXT PRIMARY KEY` | 会话主键 | 作为会话唯一标识,被 `messages.session_id` 引用 |
|
||||||
|
| `title` | `TEXT NOT NULL` | 会话标题 | CLI 展示、重命名 |
|
||||||
|
| `channel_name` | `TEXT NOT NULL` | 来源渠道名 | 例如 `cli`、`feishu` |
|
||||||
|
| `chat_id` | `TEXT NOT NULL` | 渠道侧会话标识 | 用于恢复和路由到同一聊天 |
|
||||||
|
| `summary` | `TEXT` | 会话摘要 | 预留字段,当前 schema 中存在,但当前代码未写入实际摘要 |
|
||||||
|
| `created_at` | `INTEGER NOT NULL` | 创建时间 | 毫秒级 Unix 时间戳 |
|
||||||
|
| `updated_at` | `INTEGER NOT NULL` | 最近元数据更新时间 | 重命名、归档、追加消息时更新 |
|
||||||
|
| `last_active_at` | `INTEGER NOT NULL` | 最近活跃时间 | 追加消息、清空历史时更新,用于排序 |
|
||||||
|
| `archived_at` | `INTEGER` | 归档时间 | 非空表示会话已归档 |
|
||||||
|
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
|
||||||
|
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
|
||||||
|
|
||||||
|
索引:
|
||||||
|
|
||||||
|
- `idx_sessions_channel_archived(channel_name, archived_at, last_active_at DESC)`
|
||||||
|
- 用于按渠道列出会话,并支持过滤归档态和按最近活跃排序
|
||||||
|
- `idx_sessions_updated_at(updated_at DESC)`
|
||||||
|
- 用于最近更新时间维度的查询优化
|
||||||
|
|
||||||
|
### 4.2 messages
|
||||||
|
|
||||||
|
保存会话中的消息流水。这里的“消息”不仅包括用户和助手文本,还包括工具调用结果。
|
||||||
|
|
||||||
|
字段说明:
|
||||||
|
|
||||||
|
| 字段 | 类型 | 含义 | 当前用途 |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `id` | `TEXT PRIMARY KEY` | 消息唯一标识 | 对应 `ChatMessage.id` |
|
||||||
|
| `session_id` | `TEXT NOT NULL` | 所属会话 | 外键指向 `sessions.id` |
|
||||||
|
| `seq` | `INTEGER NOT NULL` | 会话内顺序号 | 保证同一会话消息顺序稳定 |
|
||||||
|
| `role` | `TEXT NOT NULL` | 消息角色 | 典型值为 `user`、`assistant`、`system`、`tool` |
|
||||||
|
| `content` | `TEXT NOT NULL` | 消息正文 | 文本内容或工具结果文本 |
|
||||||
|
| `media_refs_json` | `TEXT NOT NULL` | 媒体引用列表 JSON | 存储附件、本地文件路径等上下文引用 |
|
||||||
|
| `tool_call_id` | `TEXT` | 工具调用 ID | 仅 `role=tool` 时通常有值,用来关联某次工具结果对应哪一个 tool call |
|
||||||
|
| `tool_name` | `TEXT` | 工具名称 | 例如 `calculator`、`file_write` |
|
||||||
|
| `tool_calls_json` | `TEXT` | assistant 发起的工具调用列表 JSON | 仅 assistant 发出工具调用时有值 |
|
||||||
|
| `created_at` | `INTEGER NOT NULL` | 消息创建时间 | 毫秒级 Unix 时间戳 |
|
||||||
|
|
||||||
|
约束和索引:
|
||||||
|
|
||||||
|
- 外键:`FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE`
|
||||||
|
- 唯一约束:`UNIQUE(session_id, seq)`,确保同一会话内顺序号不重复
|
||||||
|
- 索引:
|
||||||
|
- `idx_messages_session_seq(session_id, seq)`,按顺序读取历史
|
||||||
|
- `idx_messages_session_created(session_id, created_at)`,按时间维度检索
|
||||||
|
|
||||||
|
## 5. 字段与运行时结构的映射
|
||||||
|
|
||||||
|
持久化层存储的消息对象是 `ChatMessage`,关键映射关系如下:
|
||||||
|
|
||||||
|
| `ChatMessage` 字段 | 对应数据库字段 | 说明 |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `id` | `messages.id` | 消息唯一 ID |
|
||||||
|
| `role` | `messages.role` | 消息角色 |
|
||||||
|
| `content` | `messages.content` | 文本主体 |
|
||||||
|
| `media_refs` | `messages.media_refs_json` | 序列化为 JSON 数组 |
|
||||||
|
| `timestamp` | `messages.created_at` | 时间戳 |
|
||||||
|
| `tool_call_id` | `messages.tool_call_id` | 工具结果与调用的关联 ID |
|
||||||
|
| `tool_name` | `messages.tool_name` | 工具名 |
|
||||||
|
| `tool_calls` | `messages.tool_calls_json` | assistant 发起的工具调用列表 |
|
||||||
|
|
||||||
|
设计上分成 `tool_call_id` 和 `tool_calls_json` 两种字段,是因为两者表达的是不同方向的信息:
|
||||||
|
|
||||||
|
- `tool_calls_json` 表示“assistant 想调用哪些工具”
|
||||||
|
- `tool_call_id` 表示“这一条 tool 结果是在回应哪一次工具调用”
|
||||||
|
|
||||||
|
## 6. 数据写入流程
|
||||||
|
|
||||||
|
### 6.1 创建会话
|
||||||
|
|
||||||
|
有两种进入方式:
|
||||||
|
|
||||||
|
- CLI 模式调用 `create_cli_session()` 显式创建会话
|
||||||
|
- 渠道消息进入时调用 `ensure_channel_session()` 自动创建或复用会话
|
||||||
|
|
||||||
|
创建时会写入 `sessions` 表,初始状态:
|
||||||
|
|
||||||
|
- `summary = NULL`
|
||||||
|
- `archived_at = NULL`
|
||||||
|
- `deleted_at = NULL`
|
||||||
|
- `message_count = 0`
|
||||||
|
|
||||||
|
### 6.2 追加消息
|
||||||
|
|
||||||
|
消息持久化统一走 `append_message()`,写入过程是一个 SQLite 事务:
|
||||||
|
|
||||||
|
1. 查询当前会话 `MAX(seq) + 1` 作为下一条消息顺序。
|
||||||
|
2. 将 `media_refs` 序列化为 `media_refs_json`。
|
||||||
|
3. 将 `tool_calls` 序列化为 `tool_calls_json`。
|
||||||
|
4. 插入一条 `messages` 记录。
|
||||||
|
5. 更新 `sessions.message_count`、`updated_at`、`last_active_at`。
|
||||||
|
6. 将 `sessions.archived_at` 置空。
|
||||||
|
7. 提交事务。
|
||||||
|
|
||||||
|
其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。
|
||||||
|
|
||||||
|
### 6.3 读取历史
|
||||||
|
|
||||||
|
`load_messages(session_id)` 会按 `seq ASC` 读取整个消息历史,并把 JSON 字段反序列化回 `ChatMessage`。
|
||||||
|
|
||||||
|
因此它恢复的是“逻辑顺序”,而不是简单按创建时间排序。只要 `seq` 连续,重放顺序就稳定。
|
||||||
|
|
||||||
|
## 7. 典型时序
|
||||||
|
|
||||||
|
### 7.1 普通问答
|
||||||
|
|
||||||
|
1. 用户消息进入网关。
|
||||||
|
2. 如果数据库中没有对应会话,先插入一条 `sessions`。
|
||||||
|
3. 用户消息写入 `messages`,`role = user`。
|
||||||
|
4. Agent 基于历史生成回复。
|
||||||
|
5. assistant 回复写入 `messages`,`role = assistant`。
|
||||||
|
6. 会话的 `message_count` 增加 2,`last_active_at` 更新时间。
|
||||||
|
|
||||||
|
### 7.2 带工具调用的问答
|
||||||
|
|
||||||
|
1. assistant 先生成一条带 `tool_calls_json` 的消息,`role = assistant`。
|
||||||
|
2. 系统执行对应工具。
|
||||||
|
3. 每个工具结果作为独立消息写入 `messages`,`role = tool`。
|
||||||
|
4. 这些 `tool` 消息会带 `tool_call_id` 和 `tool_name`。
|
||||||
|
5. assistant 最终整理工具结果后再写入一条普通回复。
|
||||||
|
|
||||||
|
这样保存后,即使进程重启,后续仍能完整恢复:
|
||||||
|
|
||||||
|
- assistant 当时发起了哪些工具调用
|
||||||
|
- 每个工具调用返回了什么
|
||||||
|
- 最终 assistant 给了什么结论
|
||||||
|
|
||||||
|
## 8. 会话生命周期操作
|
||||||
|
|
||||||
|
### 8.1 重命名
|
||||||
|
|
||||||
|
`rename_session(session_id, title)`:
|
||||||
|
|
||||||
|
- 更新 `sessions.title`
|
||||||
|
- 更新 `sessions.updated_at`
|
||||||
|
|
||||||
|
### 8.2 归档
|
||||||
|
|
||||||
|
`archive_session(session_id)`:
|
||||||
|
|
||||||
|
- 将 `sessions.archived_at` 设为当前时间
|
||||||
|
- 更新 `sessions.updated_at`
|
||||||
|
- 不删除消息数据
|
||||||
|
|
||||||
|
列出会话时:
|
||||||
|
|
||||||
|
- `include_archived = false` 只返回 `archived_at IS NULL` 的会话
|
||||||
|
- `include_archived = true` 返回全部未删除会话
|
||||||
|
|
||||||
|
### 8.3 清空消息
|
||||||
|
|
||||||
|
`clear_messages(session_id)`:
|
||||||
|
|
||||||
|
- 删除该会话在 `messages` 中的所有记录
|
||||||
|
- 将 `sessions.message_count` 重置为 0
|
||||||
|
- 更新 `updated_at` 和 `last_active_at`
|
||||||
|
- 保留会话本身
|
||||||
|
|
||||||
|
这适合“保留会话入口,但丢弃聊天内容”的场景。
|
||||||
|
|
||||||
|
### 8.4 删除会话
|
||||||
|
|
||||||
|
`delete_session(session_id)`:
|
||||||
|
|
||||||
|
- 显式删除 `messages`
|
||||||
|
- 再删除 `sessions`
|
||||||
|
|
||||||
|
虽然表结构中存在 `deleted_at` 字段,并且查询时也会过滤 `deleted_at IS NULL`,但当前实现并没有做软删除,而是直接物理删除。换句话说:
|
||||||
|
|
||||||
|
- `deleted_at` 当前是保留字段
|
||||||
|
- 如果后续需要回收站或审计恢复,可以基于它演进成软删除
|
||||||
|
|
||||||
|
## 9. 并发与一致性
|
||||||
|
|
||||||
|
当前 `SessionStore` 的一致性策略比较直接:
|
||||||
|
|
||||||
|
- 进程内使用 `Arc<Mutex<Connection>>` 保护单连接访问
|
||||||
|
- 追加消息时使用 SQLite 事务
|
||||||
|
- 单条消息的写入与会话计数更新在同一事务中完成
|
||||||
|
|
||||||
|
这意味着:
|
||||||
|
|
||||||
|
- 对单进程场景,消息顺序和 `message_count` 是一致的
|
||||||
|
- `seq` 通过事务内 `MAX(seq) + 1` 分配,避免同一连接并发下的顺序错乱
|
||||||
|
- WAL 模式提升读取和写入并存时的稳定性
|
||||||
|
|
||||||
|
需要注意的是,当前设计主要面向单进程本地运行。如果未来要扩展到多进程或多实例共享同一数据库,需要重新评估:
|
||||||
|
|
||||||
|
- 单连接模型
|
||||||
|
- `MAX(seq) + 1` 的扩展性
|
||||||
|
- 会话加载缓存和跨实例同步
|
||||||
|
|
||||||
|
## 10. 当前实现中的保留点
|
||||||
|
|
||||||
|
下面这些字段或能力已经在 schema 中出现,但还没有完整业务闭环:
|
||||||
|
|
||||||
|
- `sessions.summary`
|
||||||
|
- 当前代码没有把 `ContextCompressor` 产出的摘要写回数据库
|
||||||
|
- 目前摘要只参与运行时上下文压缩,不参与持久化
|
||||||
|
- `sessions.deleted_at`
|
||||||
|
- 当前查询逻辑兼容软删除
|
||||||
|
- 当前删除实现仍然是物理删除
|
||||||
|
|
||||||
|
这说明当前 schema 已经为“会话摘要”和“软删除”预留了演进空间,但并未完全落地。
|
||||||
|
|
||||||
|
## 11. 给维护者的快速判断指南
|
||||||
|
|
||||||
|
如果你要排查持久化问题,可以先按下面的思路判断:
|
||||||
|
|
||||||
|
- 会话查不到:先看 `persistent_session_id` 是否和实际 `channel_name/chat_id` 一致
|
||||||
|
- 重启后没历史:检查 `ensure_chat_loaded()` 调用链,以及数据库文件路径是否正确
|
||||||
|
- 消息顺序不对:检查 `messages.seq`
|
||||||
|
- 工具调用上下文异常:同时检查 `tool_calls_json` 和 `tool_call_id`
|
||||||
|
- 会话列表里看不到记录:检查 `archived_at` 和 `include_archived` 参数
|
||||||
|
- 清空后仍有上下文:确认是内存历史没清掉,还是数据库 `messages` 没删掉
|
||||||
|
|
||||||
|
## 12. 总结
|
||||||
|
|
||||||
|
PicoBot 当前的持久化设计比较克制,核心目标只有两个:
|
||||||
|
|
||||||
|
- 让同一会话在重启后可以恢复上下文
|
||||||
|
- 让工具调用链可以被完整回放
|
||||||
|
|
||||||
|
从实现上看,它不是通用 ORM,也不是复杂事件存储,而是一层针对聊天历史的轻量 SQLite 封装。对于当前单机、单进程、聊天驱动的运行方式,这个设计足够直接,也便于继续演进。
|
||||||
@ -211,6 +211,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|||||||
content,
|
content,
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
tool_call_id: m.tool_call_id.clone(),
|
||||||
name: m.tool_name.clone(),
|
name: m.tool_name.clone(),
|
||||||
|
tool_calls: m.tool_calls.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -223,6 +224,12 @@ pub struct AgentLoop {
|
|||||||
max_iterations: usize,
|
max_iterations: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AgentProcessResult {
|
||||||
|
pub final_response: ChatMessage,
|
||||||
|
pub emitted_messages: Vec<ChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
@ -267,12 +274,13 @@ impl AgentLoop {
|
|||||||
/// it loops back to the LLM with the tool results until either:
|
/// it loops back to the LLM with the tool results until either:
|
||||||
/// - The LLM returns no more tool calls (final response)
|
/// - The LLM returns no more tool calls (final response)
|
||||||
/// - Maximum iterations are reached
|
/// - Maximum iterations are reached
|
||||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||||
|
|
||||||
// Track tool calls for loop detection
|
// Track tool calls for loop detection
|
||||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||||
|
let mut emitted_messages = Vec::new();
|
||||||
|
|
||||||
for iteration in 0..self.max_iterations {
|
for iteration in 0..self.max_iterations {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -316,15 +324,23 @@ impl AgentLoop {
|
|||||||
// If no tool calls, this is the final response
|
// If no tool calls, this is the final response
|
||||||
if response.tool_calls.is_empty() {
|
if response.tool_calls.is_empty() {
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
return Ok(assistant_message);
|
emitted_messages.push(assistant_message.clone());
|
||||||
|
return Ok(AgentProcessResult {
|
||||||
|
final_response: assistant_message,
|
||||||
|
emitted_messages,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tool calls
|
// Execute tool calls
|
||||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
||||||
|
|
||||||
// Add assistant message with tool calls
|
// Add assistant message with tool calls
|
||||||
let assistant_message = ChatMessage::assistant(response.content.clone());
|
let assistant_message = ChatMessage::assistant_with_tool_calls(
|
||||||
|
response.content.clone(),
|
||||||
|
response.tool_calls.clone(),
|
||||||
|
);
|
||||||
messages.push(assistant_message.clone());
|
messages.push(assistant_message.clone());
|
||||||
|
emitted_messages.push(assistant_message);
|
||||||
|
|
||||||
// Execute tools and add results to messages
|
// Execute tools and add results to messages
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
@ -356,7 +372,8 @@ impl AgentLoop {
|
|||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
||||||
);
|
);
|
||||||
messages.push(tool_message);
|
messages.push(tool_message.clone());
|
||||||
|
emitted_messages.push(tool_message);
|
||||||
}
|
}
|
||||||
LoopDetectionResult::Ok => {
|
LoopDetectionResult::Ok => {
|
||||||
let tool_message = ChatMessage::tool(
|
let tool_message = ChatMessage::tool(
|
||||||
@ -364,7 +381,8 @@ impl AgentLoop {
|
|||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
truncated_output,
|
truncated_output,
|
||||||
);
|
);
|
||||||
messages.push(tool_message);
|
messages.push(tool_message.clone());
|
||||||
|
emitted_messages.push(tool_message);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -400,7 +418,11 @@ impl AgentLoop {
|
|||||||
match (*self.provider).chat(request).await {
|
match (*self.provider).chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
Ok(assistant_message)
|
emitted_messages.push(assistant_message.clone());
|
||||||
|
Ok(AgentProcessResult {
|
||||||
|
final_response: assistant_message,
|
||||||
|
emitted_messages,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback if summary call fails
|
// Fallback if summary call fails
|
||||||
@ -408,7 +430,11 @@ impl AgentLoop {
|
|||||||
let final_message = ChatMessage::assistant(
|
let final_message = ChatMessage::assistant(
|
||||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||||
);
|
);
|
||||||
Ok(final_message)
|
emitted_messages.push(final_message.clone());
|
||||||
|
Ok(AgentProcessResult {
|
||||||
|
final_response: final_message,
|
||||||
|
emitted_messages,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -593,6 +619,25 @@ mod tests {
|
|||||||
// If there's only 1 tool, should return false regardless
|
// If there's only 1 tool, should return false regardless
|
||||||
assert_eq!(calls.len() <= 1, true);
|
assert_eq!(calls.len() <= 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
|
||||||
|
let chat_message = ChatMessage::assistant_with_tool_calls(
|
||||||
|
"calling tool",
|
||||||
|
vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: serde_json::json!({ "expression": "2+2" }),
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||||
|
|
||||||
|
assert_eq!(provider_message.role, "assistant");
|
||||||
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||||
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||||
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
pub mod agent_loop;
|
pub mod agent_loop;
|
||||||
pub mod context_compressor;
|
pub mod context_compressor;
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError};
|
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||||
pub use context_compressor::ContextCompressor;
|
pub use context_compressor::ContextCompressor;
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// ContentBlock - Multimodal content representation (OpenAI-style)
|
// ContentBlock - Multimodal content representation (OpenAI-style)
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@ -69,6 +71,8 @@ pub struct ChatMessage {
|
|||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_name: Option<String>,
|
pub tool_name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatMessage {
|
impl ChatMessage {
|
||||||
@ -81,6 +85,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,6 +98,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,6 +111,20 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||||
|
Self {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
|
timestamp: current_timestamp(),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,6 +137,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,6 +150,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
tool_name: Some(tool_name.into()),
|
tool_name: Some(tool_name.into()),
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,23 @@ use crate::bus::ChatMessage;
|
|||||||
|
|
||||||
use super::channel::CliChannel;
|
use super::channel::CliChannel;
|
||||||
|
|
||||||
|
pub enum InputEvent {
|
||||||
|
Message(ChatMessage),
|
||||||
|
Command(InputCommand),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum InputCommand {
|
||||||
|
Exit,
|
||||||
|
Clear,
|
||||||
|
New(Option<String>),
|
||||||
|
Sessions,
|
||||||
|
Use(String),
|
||||||
|
Rename(String),
|
||||||
|
Archive,
|
||||||
|
Delete,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct InputHandler {
|
pub struct InputHandler {
|
||||||
channel: CliChannel,
|
channel: CliChannel,
|
||||||
}
|
}
|
||||||
@ -13,7 +30,7 @@ impl InputHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn read_input(&mut self, prompt: &str) -> Result<Option<ChatMessage>, InputError> {
|
pub async fn read_input(&mut self, prompt: &str) -> Result<Option<InputEvent>, InputError> {
|
||||||
match self.channel.read_line(prompt).await {
|
match self.channel.read_line(prompt).await {
|
||||||
Ok(Some(line)) => {
|
Ok(Some(line)) => {
|
||||||
if line.trim().is_empty() {
|
if line.trim().is_empty() {
|
||||||
@ -21,10 +38,10 @@ impl InputHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(cmd) = self.handle_special_commands(&line) {
|
if let Some(cmd) = self.handle_special_commands(&line) {
|
||||||
return Ok(Some(cmd));
|
return Ok(Some(InputEvent::Command(cmd)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Some(ChatMessage::user(line)))
|
Ok(Some(InputEvent::Message(ChatMessage::user(line))))
|
||||||
}
|
}
|
||||||
Ok(None) => Ok(None),
|
Ok(None) => Ok(None),
|
||||||
Err(e) => Err(InputError::IoError(e)),
|
Err(e) => Err(InputError::IoError(e)),
|
||||||
@ -39,10 +56,21 @@ impl InputHandler {
|
|||||||
self.channel.write_response(content).await.map_err(InputError::IoError)
|
self.channel.write_response(content).await.map_err(InputError::IoError)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_special_commands(&self, line: &str) -> Option<ChatMessage> {
|
fn handle_special_commands(&self, line: &str) -> Option<InputCommand> {
|
||||||
match line.trim() {
|
let trimmed = line.trim();
|
||||||
"/quit" | "/exit" | "/q" => Some(ChatMessage::system("__EXIT__")),
|
let mut parts = trimmed.splitn(2, char::is_whitespace);
|
||||||
"/clear" => Some(ChatMessage::system("__CLEAR__")),
|
let command = parts.next()?;
|
||||||
|
let arg = parts.next().map(str::trim).filter(|value| !value.is_empty());
|
||||||
|
|
||||||
|
match command {
|
||||||
|
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
||||||
|
"/clear" => Some(InputCommand::Clear),
|
||||||
|
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
|
||||||
|
"/sessions" => Some(InputCommand::Sessions),
|
||||||
|
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
|
||||||
|
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
|
||||||
|
"/archive" => Some(InputCommand::Archive),
|
||||||
|
"/delete" => Some(InputCommand::Delete),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -68,3 +96,34 @@ impl std::fmt::Display for InputError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for InputError {}
|
impl std::error::Error for InputError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_special_command_parsing() {
|
||||||
|
let handler = InputHandler::new();
|
||||||
|
|
||||||
|
assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit));
|
||||||
|
assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear));
|
||||||
|
assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None)));
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/new planning"),
|
||||||
|
Some(InputCommand::New(Some("planning".to_string())))
|
||||||
|
);
|
||||||
|
assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions));
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/use abc123"),
|
||||||
|
Some(InputCommand::Use("abc123".to_string()))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/rename project alpha"),
|
||||||
|
Some(InputCommand::Rename("project alpha".to_string()))
|
||||||
|
);
|
||||||
|
assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive));
|
||||||
|
assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete));
|
||||||
|
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
||||||
|
assert_eq!(handler.handle_special_commands("/use"), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -2,4 +2,4 @@ pub mod channel;
|
|||||||
pub mod input;
|
pub mod input;
|
||||||
|
|
||||||
pub use channel::CliChannel;
|
pub use channel::CliChannel;
|
||||||
pub use input::InputHandler;
|
pub use input::{InputCommand, InputEvent, InputHandler};
|
||||||
|
|||||||
@ -3,7 +3,38 @@ pub use crate::protocol::{WsInbound, WsOutbound, serialize_inbound, serialize_ou
|
|||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||||
|
|
||||||
use crate::cli::InputHandler;
|
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
||||||
|
|
||||||
|
fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String {
|
||||||
|
if sessions.is_empty() {
|
||||||
|
return "No sessions found.".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut lines = Vec::with_capacity(sessions.len() + 1);
|
||||||
|
lines.push("Sessions:".to_string());
|
||||||
|
for session in sessions {
|
||||||
|
let marker = if current_session_id == Some(session.session_id.as_str()) {
|
||||||
|
"*"
|
||||||
|
} else {
|
||||||
|
"-"
|
||||||
|
};
|
||||||
|
let archived = if session.archived_at.is_some() {
|
||||||
|
" [archived]"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
lines.push(format!(
|
||||||
|
"{} {} | {} | {} messages{}",
|
||||||
|
marker,
|
||||||
|
session.session_id,
|
||||||
|
session.title,
|
||||||
|
session.message_count,
|
||||||
|
archived,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
||||||
serde_json::from_str(raw)
|
serde_json::from_str(raw)
|
||||||
@ -16,7 +47,8 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let (mut sender, mut receiver) = ws_stream.split();
|
let (mut sender, mut receiver) = ws_stream.split();
|
||||||
|
|
||||||
let mut input = InputHandler::new();
|
let mut input = InputHandler::new();
|
||||||
input.write_output("picobot CLI - Type /quit to exit, /clear to clear history\n").await?;
|
let mut current_session_id: Option<String> = None;
|
||||||
|
input.write_output("picobot CLI - Commands: /new [title], /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
||||||
|
|
||||||
// Main loop: poll both stdin and WebSocket
|
// Main loop: poll both stdin and WebSocket
|
||||||
loop {
|
loop {
|
||||||
@ -35,10 +67,38 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
input.write_output(&format!("Error: {}", message)).await?;
|
input.write_output(&format!("Error: {}", message)).await?;
|
||||||
}
|
}
|
||||||
WsOutbound::SessionEstablished { session_id } => {
|
WsOutbound::SessionEstablished { session_id } => {
|
||||||
|
current_session_id = Some(session_id.clone());
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %session_id, "Session established");
|
tracing::debug!(session_id = %session_id, "Session established");
|
||||||
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
input.write_output(&format!("Session: {}\n", session_id)).await?;
|
||||||
}
|
}
|
||||||
|
WsOutbound::SessionCreated { session_id, title } => {
|
||||||
|
current_session_id = Some(session_id.clone());
|
||||||
|
input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?;
|
||||||
|
}
|
||||||
|
WsOutbound::SessionList { sessions, current_session_id: listed_current } => {
|
||||||
|
let display = format_session_list(&sessions, listed_current.as_deref());
|
||||||
|
input.write_output(&format!("{}\n", display)).await?;
|
||||||
|
}
|
||||||
|
WsOutbound::SessionLoaded { session_id, title, message_count } => {
|
||||||
|
current_session_id = Some(session_id.clone());
|
||||||
|
input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?;
|
||||||
|
}
|
||||||
|
WsOutbound::SessionRenamed { session_id, title } => {
|
||||||
|
input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?;
|
||||||
|
}
|
||||||
|
WsOutbound::SessionArchived { session_id } => {
|
||||||
|
input.write_output(&format!("Archived session: {}\n", session_id)).await?;
|
||||||
|
}
|
||||||
|
WsOutbound::SessionDeleted { session_id } => {
|
||||||
|
if current_session_id.as_deref() == Some(session_id.as_str()) {
|
||||||
|
current_session_id = None;
|
||||||
|
}
|
||||||
|
input.write_output(&format!("Deleted session: {}\n", session_id)).await?;
|
||||||
|
}
|
||||||
|
WsOutbound::HistoryCleared { session_id } => {
|
||||||
|
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -54,26 +114,78 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
// Handle stdin input
|
// Handle stdin input
|
||||||
result = input.read_input("> ") => {
|
result = input.read_input("> ") => {
|
||||||
match result {
|
match result {
|
||||||
Ok(Some(msg)) => {
|
Ok(Some(event)) => {
|
||||||
match msg.content.as_str() {
|
match event {
|
||||||
"__EXIT__" => {
|
InputEvent::Command(InputCommand::Exit) => {
|
||||||
input.write_output("Goodbye!").await?;
|
input.write_output("Goodbye!").await?;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
"__CLEAR__" => {
|
InputEvent::Command(InputCommand::Clear) => {
|
||||||
let inbound = WsInbound::ClearHistory { chat_id: None };
|
let inbound = WsInbound::ClearHistory {
|
||||||
|
chat_id: None,
|
||||||
|
session_id: current_session_id.clone(),
|
||||||
|
};
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
_ => {}
|
InputEvent::Command(InputCommand::New(title)) => {
|
||||||
|
let inbound = WsInbound::CreateSession { title };
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
}
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Command(InputCommand::Sessions) => {
|
||||||
|
let inbound = WsInbound::ListSessions {
|
||||||
|
include_archived: true,
|
||||||
|
};
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Command(InputCommand::Use(session_id)) => {
|
||||||
|
let inbound = WsInbound::LoadSession { session_id };
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Command(InputCommand::Rename(title)) => {
|
||||||
|
let inbound = WsInbound::RenameSession {
|
||||||
|
session_id: current_session_id.clone(),
|
||||||
|
title,
|
||||||
|
};
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Command(InputCommand::Archive) => {
|
||||||
|
let inbound = WsInbound::ArchiveSession {
|
||||||
|
session_id: current_session_id.clone(),
|
||||||
|
};
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Command(InputCommand::Delete) => {
|
||||||
|
let inbound = WsInbound::DeleteSession {
|
||||||
|
session_id: current_session_id.clone(),
|
||||||
|
};
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Message(msg) => {
|
||||||
let inbound = WsInbound::UserInput {
|
let inbound = WsInbound::UserInput {
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
channel: None,
|
channel: None,
|
||||||
chat_id: None,
|
chat_id: current_session_id.clone(),
|
||||||
sender_id: None,
|
sender_id: None,
|
||||||
};
|
};
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
@ -83,6 +195,8 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(None) => break,
|
Ok(None) => break,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "Input error");
|
tracing::error!(error = %e, "Input error");
|
||||||
|
|||||||
@ -267,9 +267,54 @@ fn resolve_env_placeholders(content: &str) -> String {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
fn write_test_config() -> tempfile::NamedTempFile {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
},
|
||||||
|
"volcengine": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/volc",
|
||||||
|
"api_key": "test-key-2",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus",
|
||||||
|
"temperature": 0.0
|
||||||
|
},
|
||||||
|
"doubao-seed-2-0-lite-260215": {
|
||||||
|
"model_id": "doubao-seed-2-0-lite-260215"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gateway": {
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": 19876
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
file
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_config_load() {
|
fn test_config_load() {
|
||||||
let config = Config::load("config.json").unwrap();
|
let file = write_test_config();
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
|
||||||
// Check providers
|
// Check providers
|
||||||
assert!(config.providers.contains_key("volcengine"));
|
assert!(config.providers.contains_key("volcengine"));
|
||||||
@ -285,7 +330,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_get_provider_config() {
|
fn test_get_provider_config() {
|
||||||
let config = Config::load("config.json").unwrap();
|
let file = write_test_config();
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
let provider_config = config.get_provider_config("default").unwrap();
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
assert_eq!(provider_config.provider_type, "openai");
|
assert_eq!(provider_config.provider_type, "openai");
|
||||||
@ -296,7 +342,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_default_gateway_config() {
|
fn test_default_gateway_config() {
|
||||||
let config = Config::load("config.json").unwrap();
|
let file = write_test_config();
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
assert_eq!(config.gateway.host, "0.0.0.0");
|
assert_eq!(config.gateway.host, "0.0.0.0");
|
||||||
assert_eq!(config.gateway.port, 19876);
|
assert_eq!(config.gateway.port, 19876);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,7 +29,7 @@ impl GatewayState {
|
|||||||
// Session TTL from config (default 4 hours)
|
// Session TTL from config (default 4 hours)
|
||||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||||
|
|
||||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config);
|
let session_manager = SessionManager::new(session_ttl_hours, provider_config)?;
|
||||||
let channel_manager = ChannelManager::new();
|
let channel_manager = ChannelManager::new();
|
||||||
let bus = channel_manager.bus();
|
let bus = channel_manager.bus();
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ use crate::bus::ChatMessage;
|
|||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
|
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||||
HttpRequestTool, ToolRegistry, WebFetchTool,
|
HttpRequestTool, ToolRegistry, WebFetchTool,
|
||||||
@ -23,6 +24,7 @@ pub struct Session {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
compressor: ContextCompressor,
|
compressor: ContextCompressor,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@ -31,6 +33,7 @@ impl Session {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
user_tx: mpsc::Sender<WsOutbound>,
|
user_tx: mpsc::Sender<WsOutbound>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
@ -40,9 +43,33 @@ impl Session {
|
|||||||
provider_config: provider_config.clone(),
|
provider_config: provider_config.clone(),
|
||||||
tools,
|
tools,
|
||||||
compressor: ContextCompressor::new(provider_config.token_limit),
|
compressor: ContextCompressor::new(provider_config.token_limit),
|
||||||
|
store,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn persistent_session_id(&self, chat_id: &str) -> String {
|
||||||
|
persistent_session_id(&self.channel_name, chat_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||||||
|
self.store
|
||||||
|
.ensure_channel_session(&self.channel_name, chat_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
|
if self.chat_histories.contains_key(chat_id) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let history = self
|
||||||
|
.store
|
||||||
|
.load_messages(&self.persistent_session_id(chat_id))
|
||||||
|
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
|
||||||
|
self.chat_histories.insert(chat_id.to_string(), history);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// 获取或创建指定 chat_id 的会话历史
|
/// 获取或创建指定 chat_id 的会话历史
|
||||||
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||||||
self.chat_histories
|
self.chat_histories
|
||||||
@ -55,41 +82,72 @@ impl Session {
|
|||||||
self.chat_histories.get(chat_id)
|
self.chat_histories.get(chat_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 添加用户消息到指定 chat_id 的历史
|
/// 使用完整消息追加到历史
|
||||||
pub fn add_user_message(&mut self, chat_id: &str, content: &str) {
|
pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||||
let history = self.get_or_create_history(chat_id);
|
let history = self.get_or_create_history(chat_id);
|
||||||
history.push(ChatMessage::user(content));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 添加带媒体的用户消息到指定 chat_id 的历史
|
|
||||||
pub fn add_user_message_with_media(&mut self, chat_id: &str, content: &str, media_refs: Vec<String>) {
|
|
||||||
let history = self.get_or_create_history(chat_id);
|
|
||||||
history.push(ChatMessage::user_with_media(content, media_refs));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 添加助手响应到指定 chat_id 的历史
|
|
||||||
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
|
|
||||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
|
||||||
history.push(message);
|
history.push(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn remove_history(&mut self, chat_id: &str) {
|
||||||
|
self.chat_histories.remove(chat_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 清除指定 chat_id 的历史
|
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
pub fn clear_chat_history(&mut self, chat_id: &str) {
|
|
||||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||||
let len = history.len();
|
let len = history.len();
|
||||||
history.clear();
|
history.clear();
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.store
|
||||||
|
.clear_messages(&self.persistent_session_id(chat_id))
|
||||||
|
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将消息写入内存与持久化层
|
||||||
|
pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> {
|
||||||
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
|
self.store
|
||||||
|
.append_message(&session_id, &message)
|
||||||
|
.map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?;
|
||||||
|
self.add_message(chat_id, message);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn append_persisted_messages<I>(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError>
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = ChatMessage>,
|
||||||
|
{
|
||||||
|
for message in messages {
|
||||||
|
self.append_persisted_message(chat_id, message)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||||||
|
if media_refs.is_empty() {
|
||||||
|
ChatMessage::user(content)
|
||||||
|
} else {
|
||||||
|
ChatMessage::user_with_media(content, media_refs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 清除所有历史
|
/// 清除所有历史
|
||||||
pub fn clear_all_history(&mut self) {
|
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||||||
|
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||||
self.chat_histories.clear();
|
self.chat_histories.clear();
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||||
|
|
||||||
|
for chat_id in chat_ids {
|
||||||
|
self.store
|
||||||
|
.clear_messages(&self.persistent_session_id(&chat_id))
|
||||||
|
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send(&self, msg: WsOutbound) {
|
pub async fn send(&self, msg: WsOutbound) {
|
||||||
@ -118,6 +176,7 @@ pub struct SessionManager {
|
|||||||
inner: Arc<Mutex<SessionManagerInner>>,
|
inner: Arc<Mutex<SessionManagerInner>>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
|
store: Arc<SessionStore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SessionManagerInner {
|
struct SessionManagerInner {
|
||||||
@ -144,8 +203,13 @@ fn default_tools() -> ToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SessionManager {
|
impl SessionManager {
|
||||||
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Self {
|
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
Self {
|
let store = Arc::new(
|
||||||
|
SessionStore::new()
|
||||||
|
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||||||
sessions: HashMap::new(),
|
sessions: HashMap::new(),
|
||||||
session_timestamps: HashMap::new(),
|
session_timestamps: HashMap::new(),
|
||||||
@ -153,13 +217,66 @@ impl SessionManager {
|
|||||||
})),
|
})),
|
||||||
provider_config,
|
provider_config,
|
||||||
tools: Arc::new(default_tools()),
|
tools: Arc::new(default_tools()),
|
||||||
}
|
store,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
pub fn tools(&self) -> Arc<ToolRegistry> {
|
||||||
self.tools.clone()
|
self.tools.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn store(&self) -> Arc<SessionStore> {
|
||||||
|
self.store.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
|
||||||
|
self.store
|
||||||
|
.create_cli_session(title)
|
||||||
|
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_session_record(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
|
||||||
|
self.store
|
||||||
|
.get_session(session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_cli_sessions(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
|
||||||
|
self.store
|
||||||
|
.list_sessions("cli", include_archived)
|
||||||
|
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
|
||||||
|
self.store
|
||||||
|
.rename_session(session_id, title)
|
||||||
|
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> {
|
||||||
|
self.store
|
||||||
|
.archive_session(session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> {
|
||||||
|
self.store
|
||||||
|
.delete_session(session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
|
||||||
|
self.store
|
||||||
|
.clear_messages(session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
|
||||||
|
self.store
|
||||||
|
.load_messages(session_id)
|
||||||
|
.map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
/// 确保 session 存在且未超时,超时则重建
|
/// 确保 session 存在且未超时,超时则重建
|
||||||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||||
let mut inner = self.inner.lock().await;
|
let mut inner = self.inner.lock().await;
|
||||||
@ -189,6 +306,7 @@ impl SessionManager {
|
|||||||
self.provider_config.clone(),
|
self.provider_config.clone(),
|
||||||
user_tx,
|
user_tx,
|
||||||
self.tools.clone(),
|
self.tools.clone(),
|
||||||
|
self.store.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let arc = Arc::new(Mutex::new(session));
|
let arc = Arc::new(Mutex::new(session));
|
||||||
@ -251,15 +369,17 @@ impl SessionManager {
|
|||||||
let response = {
|
let response = {
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
|
session_guard.ensure_persistent_session(chat_id)?;
|
||||||
|
session_guard.ensure_chat_loaded(chat_id)?;
|
||||||
|
|
||||||
// 添加用户消息到历史
|
// 添加用户消息到历史
|
||||||
if media.is_empty() {
|
|
||||||
session_guard.add_user_message(chat_id, content);
|
|
||||||
} else {
|
|
||||||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
|
if !media_refs.is_empty() {
|
||||||
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||||||
session_guard.add_user_message_with_media(chat_id, content, media_refs);
|
|
||||||
}
|
}
|
||||||
|
let user_message = session_guard.create_user_message(content, media_refs);
|
||||||
|
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||||
|
|
||||||
// 获取完整历史
|
// 获取完整历史
|
||||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||||
@ -271,12 +391,12 @@ impl SessionManager {
|
|||||||
|
|
||||||
// 创建 agent 并处理
|
// 创建 agent 并处理
|
||||||
let agent = session_guard.create_agent()?;
|
let agent = session_guard.create_agent()?;
|
||||||
let response = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
// 添加助手响应到历史
|
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||||
session_guard.add_assistant_message(chat_id, response.clone());
|
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||||
|
|
||||||
response
|
result.final_response
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -294,7 +414,7 @@ impl SessionManager {
|
|||||||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||||
if let Some(session) = self.get(channel_name).await {
|
if let Some(session) = self.get(channel_name).await {
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
session_guard.clear_all_history();
|
session_guard.clear_all_history()?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,7 +4,7 @@ use axum::extract::State;
|
|||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||||
use super::{GatewayState, session::Session};
|
use super::{GatewayState, session::Session};
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||||
@ -24,8 +24,15 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// CLI 使用独立的 session,channel_name = "cli-{uuid}"
|
let initial_record = match state.session_manager.create_cli_session(None) {
|
||||||
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
|
Ok(record) => record,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "Failed to create initial CLI session");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let channel_name = "cli".to_string();
|
||||||
|
|
||||||
// 创建 CLI session
|
// 创建 CLI session
|
||||||
let session = match Session::new(
|
let session = match Session::new(
|
||||||
@ -33,6 +40,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
provider_config,
|
provider_config,
|
||||||
sender,
|
sender,
|
||||||
state.session_manager.tools(),
|
state.session_manager.tools(),
|
||||||
|
state.session_manager.store(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
@ -43,21 +51,27 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let session_id = session.lock().await.id;
|
if let Err(e) = session.lock().await.ensure_chat_loaded(&initial_record.id) {
|
||||||
tracing::info!(session_id = %session_id, "CLI session established");
|
tracing::error!(error = %e, session_id = %initial_record.id, "Failed to load initial CLI session history");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let runtime_session_id = session.lock().await.id;
|
||||||
|
let mut current_session_id = initial_record.id.clone();
|
||||||
|
tracing::info!(runtime_session_id = %runtime_session_id, session_id = %current_session_id, "CLI session established");
|
||||||
|
|
||||||
let _ = session
|
let _ = session
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.send(WsOutbound::SessionEstablished {
|
.send(WsOutbound::SessionEstablished {
|
||||||
session_id: session_id.to_string(),
|
session_id: current_session_id.clone(),
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||||||
|
|
||||||
let mut receiver = receiver;
|
let mut receiver = receiver;
|
||||||
let session_id_for_sender = session_id;
|
let session_id_for_sender = runtime_session_id;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
if let Ok(text) = serialize_outbound(&msg) {
|
if let Ok(text) = serialize_outbound(&msg) {
|
||||||
@ -76,7 +90,17 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
let text = text.to_string();
|
let text = text.to_string();
|
||||||
match parse_inbound(&text) {
|
match parse_inbound(&text) {
|
||||||
Ok(inbound) => {
|
Ok(inbound) => {
|
||||||
handle_inbound(&session, inbound).await;
|
if let Err(e) = handle_inbound(&state, &session, &mut current_session_id, inbound).await {
|
||||||
|
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||||||
|
let _ = session
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.send(WsOutbound::Error {
|
||||||
|
code: "SESSION_ERROR".to_string(),
|
||||||
|
message: e.to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||||
@ -93,92 +117,203 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
}
|
}
|
||||||
Ok(WsMessage::Close(_)) | Err(_) => {
|
Ok(WsMessage::Close(_)) | Err(_) => {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(session_id = %session_id, "WebSocket closed");
|
tracing::debug!(session_id = %runtime_session_id, "WebSocket closed");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::info!(session_id = %session_id, "CLI session ended");
|
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
||||||
let inbound_clone = inbound.clone();
|
SessionSummary {
|
||||||
|
session_id: record.id,
|
||||||
// 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id)
|
title: record.title,
|
||||||
let (content, chat_id) = match inbound_clone {
|
channel_name: record.channel_name,
|
||||||
WsInbound::UserInput {
|
chat_id: record.chat_id,
|
||||||
content,
|
message_count: record.message_count,
|
||||||
channel: _,
|
last_active_at: record.last_active_at,
|
||||||
chat_id,
|
archived_at: record.archived_at,
|
||||||
sender_id: _,
|
}
|
||||||
} => {
|
|
||||||
// CLI 使用 session 中的 channel_name 作为标识
|
|
||||||
// chat_id 使用传入的或使用默认
|
|
||||||
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
|
|
||||||
(content, chat_id)
|
|
||||||
}
|
}
|
||||||
_ => return,
|
|
||||||
};
|
|
||||||
|
|
||||||
|
async fn handle_inbound(
|
||||||
|
state: &Arc<GatewayState>,
|
||||||
|
session: &Arc<Mutex<Session>>,
|
||||||
|
current_session_id: &mut String,
|
||||||
|
inbound: WsInbound,
|
||||||
|
) -> Result<(), crate::agent::AgentError> {
|
||||||
|
match inbound {
|
||||||
|
WsInbound::UserInput { content, chat_id, .. } => {
|
||||||
|
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
// 添加用户消息到历史
|
session_guard.ensure_persistent_session(&chat_id)?;
|
||||||
session_guard.add_user_message(&chat_id, &content);
|
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||||
|
|
||||||
// 获取完整历史
|
let user_message = session_guard.create_user_message(&content, Vec::new());
|
||||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
session_guard.append_persisted_message(&chat_id, user_message)?;
|
||||||
|
|
||||||
// 压缩历史(如果需要)
|
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
||||||
let history = match session_guard.compressor()
|
let history = match session_guard
|
||||||
.compress_if_needed(history, session_guard.provider_config())
|
.compressor()
|
||||||
|
.compress_if_needed(raw_history, session_guard.provider_config())
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(h) => h,
|
Ok(history) => history,
|
||||||
Err(e) => {
|
Err(error) => {
|
||||||
tracing::warn!(chat_id = %chat_id, error = %e, "Compression failed, using original history");
|
tracing::warn!(chat_id = %chat_id, error = %error, "Compression failed, using original history");
|
||||||
session_guard.get_or_create_history(&chat_id).clone()
|
session_guard.get_or_create_history(&chat_id).clone()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 创建 agent 并处理
|
let agent = session_guard.create_agent()?;
|
||||||
let agent = match session_guard.create_agent() {
|
|
||||||
Ok(a) => a,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(chat_id = %chat_id, error = %e, "Failed to create agent");
|
|
||||||
let _ = session_guard
|
|
||||||
.send(WsOutbound::Error {
|
|
||||||
code: "AGENT_ERROR".to_string(),
|
|
||||||
message: e.to_string(),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match agent.process(history).await {
|
match agent.process(history).await {
|
||||||
Ok(response) => {
|
Ok(result) => {
|
||||||
#[cfg(debug_assertions)]
|
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||||
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
|
||||||
// 添加助手响应到历史
|
|
||||||
session_guard.add_assistant_message(&chat_id, response.clone());
|
|
||||||
let _ = session_guard
|
let _ = session_guard
|
||||||
.send(WsOutbound::AssistantResponse {
|
.send(WsOutbound::AssistantResponse {
|
||||||
id: response.id,
|
id: result.final_response.id,
|
||||||
content: response.content,
|
content: result.final_response.content,
|
||||||
role: response.role,
|
role: result.final_response.role,
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(error) => {
|
||||||
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
|
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
||||||
let _ = session_guard
|
let _ = session_guard
|
||||||
.send(WsOutbound::Error {
|
.send(WsOutbound::Error {
|
||||||
code: "LLM_ERROR".to_string(),
|
code: "LLM_ERROR".to_string(),
|
||||||
message: e.to_string(),
|
message: error.to_string(),
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::ClearHistory { session_id, chat_id } => {
|
||||||
|
let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
|
||||||
|
state.session_manager.clear_session_messages(&target)?;
|
||||||
|
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
session_guard.remove_history(&target);
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::HistoryCleared {
|
||||||
|
session_id: target,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::CreateSession { title } => {
|
||||||
|
let record = state.session_manager.create_cli_session(title.as_deref())?;
|
||||||
|
*current_session_id = record.id.clone();
|
||||||
|
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
session_guard.ensure_chat_loaded(&record.id)?;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionCreated {
|
||||||
|
session_id: record.id,
|
||||||
|
title: record.title,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::ListSessions { include_archived } => {
|
||||||
|
let records = state.session_manager.list_cli_sessions(include_archived)?;
|
||||||
|
let summaries = records.into_iter().map(to_session_summary).collect();
|
||||||
|
|
||||||
|
let session_guard = session.lock().await;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionList {
|
||||||
|
sessions: summaries,
|
||||||
|
current_session_id: Some(current_session_id.clone()),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::LoadSession { session_id } => {
|
||||||
|
let Some(record) = state.session_manager.get_session_record(&session_id)? else {
|
||||||
|
let session_guard = session.lock().await;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::Error {
|
||||||
|
code: "SESSION_NOT_FOUND".to_string(),
|
||||||
|
message: format!("Session not found: {}", session_id),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
*current_session_id = record.id.clone();
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
session_guard.ensure_chat_loaded(&record.id)?;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionLoaded {
|
||||||
|
session_id: record.id,
|
||||||
|
title: record.title,
|
||||||
|
message_count: record.message_count,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::RenameSession { session_id, title } => {
|
||||||
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
|
state.session_manager.rename_session(&target, &title)?;
|
||||||
|
let session_guard = session.lock().await;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionRenamed {
|
||||||
|
session_id: target,
|
||||||
|
title,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::ArchiveSession { session_id } => {
|
||||||
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
|
state.session_manager.archive_session(&target)?;
|
||||||
|
let session_guard = session.lock().await;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionArchived { session_id: target })
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::DeleteSession { session_id } => {
|
||||||
|
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
|
state.session_manager.delete_session(&target)?;
|
||||||
|
|
||||||
|
let replacement = if target == *current_session_id {
|
||||||
|
Some(state.session_manager.create_cli_session(None)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
session_guard.remove_history(&target);
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionDeleted {
|
||||||
|
session_id: target.clone(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Some(record) = replacement {
|
||||||
|
*current_session_id = record.id.clone();
|
||||||
|
session_guard.ensure_chat_loaded(&record.id)?;
|
||||||
|
let _ = session_guard
|
||||||
|
.send(WsOutbound::SessionCreated {
|
||||||
|
session_id: record.id,
|
||||||
|
title: record.title,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
WsInbound::Ping => {
|
||||||
|
let session_guard = session.lock().await;
|
||||||
|
let _ = session_guard.send(WsOutbound::Pong).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,4 +9,5 @@ pub mod protocol;
|
|||||||
pub mod channels;
|
pub mod channels;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
|
pub mod storage;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|||||||
@ -1,5 +1,17 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SessionSummary {
|
||||||
|
pub session_id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub channel_name: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub message_count: i64,
|
||||||
|
pub last_active_at: i64,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub archived_at: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum WsInbound {
|
pub enum WsInbound {
|
||||||
@ -17,6 +29,38 @@ pub enum WsInbound {
|
|||||||
ClearHistory {
|
ClearHistory {
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
chat_id: Option<String>,
|
chat_id: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
session_id: Option<String>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "create_session")]
|
||||||
|
CreateSession {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
title: Option<String>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "list_sessions")]
|
||||||
|
ListSessions {
|
||||||
|
#[serde(default)]
|
||||||
|
include_archived: bool,
|
||||||
|
},
|
||||||
|
#[serde(rename = "load_session")]
|
||||||
|
LoadSession {
|
||||||
|
session_id: String,
|
||||||
|
},
|
||||||
|
#[serde(rename = "rename_session")]
|
||||||
|
RenameSession {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
session_id: Option<String>,
|
||||||
|
title: String,
|
||||||
|
},
|
||||||
|
#[serde(rename = "archive_session")]
|
||||||
|
ArchiveSession {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
session_id: Option<String>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "delete_session")]
|
||||||
|
DeleteSession {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
session_id: Option<String>,
|
||||||
},
|
},
|
||||||
#[serde(rename = "ping")]
|
#[serde(rename = "ping")]
|
||||||
Ping,
|
Ping,
|
||||||
@ -31,6 +75,28 @@ pub enum WsOutbound {
|
|||||||
Error { code: String, message: String },
|
Error { code: String, message: String },
|
||||||
#[serde(rename = "session_established")]
|
#[serde(rename = "session_established")]
|
||||||
SessionEstablished { session_id: String },
|
SessionEstablished { session_id: String },
|
||||||
|
#[serde(rename = "session_created")]
|
||||||
|
SessionCreated { session_id: String, title: String },
|
||||||
|
#[serde(rename = "session_list")]
|
||||||
|
SessionList {
|
||||||
|
sessions: Vec<SessionSummary>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
current_session_id: Option<String>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "session_loaded")]
|
||||||
|
SessionLoaded {
|
||||||
|
session_id: String,
|
||||||
|
title: String,
|
||||||
|
message_count: i64,
|
||||||
|
},
|
||||||
|
#[serde(rename = "session_renamed")]
|
||||||
|
SessionRenamed { session_id: String, title: String },
|
||||||
|
#[serde(rename = "session_archived")]
|
||||||
|
SessionArchived { session_id: String },
|
||||||
|
#[serde(rename = "session_deleted")]
|
||||||
|
SessionDeleted { session_id: String },
|
||||||
|
#[serde(rename = "history_cleared")]
|
||||||
|
HistoryCleared { session_id: String },
|
||||||
#[serde(rename = "pong")]
|
#[serde(rename = "pong")]
|
||||||
Pong,
|
Pong,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,6 +57,54 @@ impl OpenAIProvider {
|
|||||||
model_extra,
|
model_extra,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
|
let mut body = json!({
|
||||||
|
"model": self.model_id,
|
||||||
|
"messages": request.messages.iter().map(|m| {
|
||||||
|
if m.role == "tool" {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": convert_content_blocks(&m.content),
|
||||||
|
"tool_call_id": m.tool_call_id,
|
||||||
|
"name": m.name,
|
||||||
|
})
|
||||||
|
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": convert_content_blocks(&m.content),
|
||||||
|
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||||
|
calls.iter().map(|call| json!({
|
||||||
|
"id": call.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": call.name,
|
||||||
|
"arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string())
|
||||||
|
}
|
||||||
|
})).collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": convert_content_blocks(&m.content)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}).collect::<Vec<_>>(),
|
||||||
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
||||||
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
||||||
|
});
|
||||||
|
|
||||||
|
for (key, value) in &self.model_extra {
|
||||||
|
body[key] = value.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tools) = &request.tools {
|
||||||
|
body["tools"] = json!(tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -116,35 +164,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
let mut body = json!({
|
let body = self.build_request_body(&request);
|
||||||
"model": self.model_id,
|
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
if m.role == "tool" {
|
|
||||||
json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": convert_content_blocks(&m.content),
|
|
||||||
"tool_call_id": m.tool_call_id,
|
|
||||||
"name": m.name,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": convert_content_blocks(&m.content)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
|
||||||
"max_tokens": request.max_tokens.or(self.max_tokens),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add model extra fields
|
|
||||||
for (key, value) in &self.model_extra {
|
|
||||||
body[key] = value.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(tools) = &request.tools {
|
|
||||||
body["tools"] = json!(tools);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Log LLM request summary (only in debug builds)
|
// Debug: Log LLM request summary (only in debug builds)
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -242,3 +262,50 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
&self.model_id
|
&self.model_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::providers::Message;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_body_includes_assistant_tool_calls() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::new(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![ContentBlock::text("calling tool")],
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: Some(vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: json!({"expression": "1+1"}),
|
||||||
|
}]),
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = provider.build_request_body(&request);
|
||||||
|
let messages = body["messages"].as_array().unwrap();
|
||||||
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tool_calls.len(), 1);
|
||||||
|
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||||
|
assert_eq!(tool_calls[0]["type"], "function");
|
||||||
|
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||||
|
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,8 @@ pub struct Message {
|
|||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
@ -19,6 +21,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,6 +31,7 @@ impl Message {
|
|||||||
content,
|
content,
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +41,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,6 +51,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,6 +61,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
name: Some(tool_name.into()),
|
name: Some(tool_name.into()),
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
506
src/storage/mod.rs
Normal file
506
src/storage/mod.rs
Normal file
@ -0,0 +1,506 @@
|
|||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use rusqlite::{Connection, OptionalExtension, params};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum StorageError {
|
||||||
|
#[error("database error: {0}")]
|
||||||
|
Database(#[from] rusqlite::Error),
|
||||||
|
#[error("io error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
#[error("serialization error: {0}")]
|
||||||
|
Serialization(#[from] serde_json::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SessionRecord {
|
||||||
|
pub id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub channel_name: String,
|
||||||
|
pub chat_id: String,
|
||||||
|
pub summary: Option<String>,
|
||||||
|
pub created_at: i64,
|
||||||
|
pub updated_at: i64,
|
||||||
|
pub last_active_at: i64,
|
||||||
|
pub archived_at: Option<i64>,
|
||||||
|
pub deleted_at: Option<i64>,
|
||||||
|
pub message_count: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SessionStore {
|
||||||
|
conn: Arc<Mutex<Connection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionStore {
|
||||||
|
pub fn new() -> Result<Self, StorageError> {
|
||||||
|
let db_path = default_session_db_path()?;
|
||||||
|
Self::open_at_path(&db_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open_at_path(path: &Path) -> Result<Self, StorageError> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
std::fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let conn = Connection::open(path)?;
|
||||||
|
Self::from_connection(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_connection(conn: Connection) -> Result<Self, StorageError> {
|
||||||
|
conn.execute_batch(
|
||||||
|
"
|
||||||
|
PRAGMA journal_mode = WAL;
|
||||||
|
PRAGMA foreign_keys = ON;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
channel_name TEXT NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
summary TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
updated_at INTEGER NOT NULL,
|
||||||
|
last_active_at INTEGER NOT NULL,
|
||||||
|
archived_at INTEGER,
|
||||||
|
deleted_at INTEGER,
|
||||||
|
message_count INTEGER NOT NULL DEFAULT 0
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
|
||||||
|
ON sessions(channel_name, archived_at, last_active_at DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at
|
||||||
|
ON sessions(updated_at DESC);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
seq INTEGER NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
media_refs_json TEXT NOT NULL,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
tool_name TEXT,
|
||||||
|
tool_calls_json TEXT,
|
||||||
|
created_at INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(session_id, seq)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_session_seq
|
||||||
|
ON messages(session_id, seq);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_session_created
|
||||||
|
ON messages(session_id, created_at);
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
conn: Arc::new(Mutex::new(conn)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn in_memory() -> Result<Self, StorageError> {
|
||||||
|
Self::from_connection(Connection::open_in_memory()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
|
||||||
|
let now = current_timestamp();
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let title = title
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.unwrap_or_else(|| format!("CLI Session {}", &id[..8]));
|
||||||
|
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
conn.execute(
|
||||||
|
"
|
||||||
|
INSERT INTO sessions (
|
||||||
|
id, title, channel_name, chat_id, summary,
|
||||||
|
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
|
||||||
|
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0)
|
||||||
|
",
|
||||||
|
params![id, title, id, now],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
drop(conn);
|
||||||
|
self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ensure_channel_session(
|
||||||
|
&self,
|
||||||
|
channel_name: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
) -> Result<SessionRecord, StorageError> {
|
||||||
|
let session_id = persistent_session_id(channel_name, chat_id);
|
||||||
|
if let Some(record) = self.get_session(&session_id)? {
|
||||||
|
return Ok(record);
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = current_timestamp();
|
||||||
|
let title = format!("{}:{}", channel_name, chat_id);
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
conn.execute(
|
||||||
|
"
|
||||||
|
INSERT INTO sessions (
|
||||||
|
id, title, channel_name, chat_id, summary,
|
||||||
|
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
|
||||||
|
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0)
|
||||||
|
",
|
||||||
|
params![session_id, title, channel_name, chat_id, now],
|
||||||
|
)?;
|
||||||
|
drop(conn);
|
||||||
|
|
||||||
|
self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"
|
||||||
|
SELECT id, title, channel_name, chat_id, summary,
|
||||||
|
created_at, updated_at, last_active_at,
|
||||||
|
archived_at, deleted_at, message_count
|
||||||
|
FROM sessions
|
||||||
|
WHERE id = ?1 AND deleted_at IS NULL
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
stmt.query_row(params![session_id], map_session_record)
|
||||||
|
.optional()
|
||||||
|
.map_err(StorageError::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_sessions(
|
||||||
|
&self,
|
||||||
|
channel_name: &str,
|
||||||
|
include_archived: bool,
|
||||||
|
) -> Result<Vec<SessionRecord>, StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let mut sql = String::from(
|
||||||
|
"
|
||||||
|
SELECT id, title, channel_name, chat_id, summary,
|
||||||
|
created_at, updated_at, last_active_at,
|
||||||
|
archived_at, deleted_at, message_count
|
||||||
|
FROM sessions
|
||||||
|
WHERE channel_name = ?1
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
",
|
||||||
|
);
|
||||||
|
|
||||||
|
if !include_archived {
|
||||||
|
sql.push_str(" AND archived_at IS NULL");
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.push_str(" ORDER BY last_active_at DESC, created_at DESC");
|
||||||
|
|
||||||
|
let mut stmt = conn.prepare(&sql)?;
|
||||||
|
let rows = stmt.query_map(params![channel_name], map_session_record)?;
|
||||||
|
let mut sessions = Vec::new();
|
||||||
|
for row in rows {
|
||||||
|
sessions.push(row?);
|
||||||
|
}
|
||||||
|
Ok(sessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), StorageError> {
|
||||||
|
let now = current_timestamp();
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE sessions SET title = ?2, updated_at = ?3 WHERE id = ?1 AND deleted_at IS NULL",
|
||||||
|
params![session_id, title.trim(), now],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn archive_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
|
let now = current_timestamp();
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE sessions SET archived_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL",
|
||||||
|
params![session_id, now],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
|
||||||
|
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
|
let now = current_timestamp();
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
|
||||||
|
conn.execute(
|
||||||
|
"
|
||||||
|
UPDATE sessions
|
||||||
|
SET message_count = 0, updated_at = ?2, last_active_at = ?2
|
||||||
|
WHERE id = ?1 AND deleted_at IS NULL
|
||||||
|
",
|
||||||
|
params![session_id, now],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let tx = conn.unchecked_transaction()?;
|
||||||
|
|
||||||
|
let seq: i64 = tx.query_row(
|
||||||
|
"SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE session_id = ?1",
|
||||||
|
params![session_id],
|
||||||
|
|row| row.get(0),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||||
|
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
||||||
|
tx.execute(
|
||||||
|
"
|
||||||
|
INSERT INTO messages (
|
||||||
|
id, session_id, seq, role, content,
|
||||||
|
media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||||
|
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
|
||||||
|
",
|
||||||
|
params![
|
||||||
|
message.id,
|
||||||
|
session_id,
|
||||||
|
seq,
|
||||||
|
message.role,
|
||||||
|
message.content,
|
||||||
|
media_refs_json,
|
||||||
|
message.tool_call_id,
|
||||||
|
message.tool_name,
|
||||||
|
tool_calls_json,
|
||||||
|
message.timestamp,
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let now = current_timestamp();
|
||||||
|
tx.execute(
|
||||||
|
"
|
||||||
|
UPDATE sessions
|
||||||
|
SET message_count = message_count + 1,
|
||||||
|
updated_at = ?2,
|
||||||
|
last_active_at = ?2,
|
||||||
|
archived_at = NULL
|
||||||
|
WHERE id = ?1 AND deleted_at IS NULL
|
||||||
|
",
|
||||||
|
params![session_id, now],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
tx.commit()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||||
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"
|
||||||
|
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ?1
|
||||||
|
ORDER BY seq ASC
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let rows = stmt.query_map(params![session_id], |row| {
|
||||||
|
let media_refs_json: String = row.get(3)?;
|
||||||
|
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||||
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
|
media_refs_json.len(),
|
||||||
|
rusqlite::types::Type::Text,
|
||||||
|
Box::new(err),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let tool_calls_json: Option<String> = row.get(7)?;
|
||||||
|
let tool_calls = tool_calls_json
|
||||||
|
.as_deref()
|
||||||
|
.map(serde_json::from_str)
|
||||||
|
.transpose()
|
||||||
|
.map_err(|err| {
|
||||||
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
|
7,
|
||||||
|
rusqlite::types::Type::Text,
|
||||||
|
Box::new(err),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(ChatMessage {
|
||||||
|
id: row.get(0)?,
|
||||||
|
role: row.get(1)?,
|
||||||
|
content: row.get(2)?,
|
||||||
|
media_refs,
|
||||||
|
timestamp: row.get(4)?,
|
||||||
|
tool_call_id: row.get(5)?,
|
||||||
|
tool_name: row.get(6)?,
|
||||||
|
tool_calls,
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
for row in rows {
|
||||||
|
messages.push(row?);
|
||||||
|
}
|
||||||
|
Ok(messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
|
||||||
|
if channel_name == "cli" {
|
||||||
|
chat_id.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{}:{}", channel_name, chat_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_session_db_path() -> Result<PathBuf, std::io::Error> {
|
||||||
|
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||||
|
Ok(home.join(".picobot").join("storage").join("sessions.db"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord> {
|
||||||
|
Ok(SessionRecord {
|
||||||
|
id: row.get(0)?,
|
||||||
|
title: row.get(1)?,
|
||||||
|
channel_name: row.get(2)?,
|
||||||
|
chat_id: row.get(3)?,
|
||||||
|
summary: row.get(4)?,
|
||||||
|
created_at: row.get(5)?,
|
||||||
|
updated_at: row.get(6)?,
|
||||||
|
last_active_at: row.get(7)?,
|
||||||
|
archived_at: row.get(8)?,
|
||||||
|
deleted_at: row.get(9)?,
|
||||||
|
message_count: row.get(10)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_timestamp() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.expect("system clock before unix epoch")
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_persistent_session_id_for_cli_and_channel() {
|
||||||
|
assert_eq!(persistent_session_id("cli", "abc"), "abc");
|
||||||
|
assert_eq!(persistent_session_id("feishu", "abc"), "feishu:abc");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_store_roundtrip_and_lifecycle() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
|
||||||
|
let session = store.create_cli_session(Some("demo")).unwrap();
|
||||||
|
assert_eq!(session.title, "demo");
|
||||||
|
assert_eq!(session.channel_name, "cli");
|
||||||
|
assert_eq!(session.chat_id, session.id);
|
||||||
|
assert_eq!(session.message_count, 0);
|
||||||
|
|
||||||
|
let first = ChatMessage::user("hello");
|
||||||
|
let second = ChatMessage::assistant("world");
|
||||||
|
store.append_message(&session.id, &first).unwrap();
|
||||||
|
store.append_message(&session.id, &second).unwrap();
|
||||||
|
|
||||||
|
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||||
|
assert_eq!(stored.message_count, 2);
|
||||||
|
assert!(stored.archived_at.is_none());
|
||||||
|
|
||||||
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(messages.len(), 2);
|
||||||
|
assert_eq!(messages[0].role, "user");
|
||||||
|
assert_eq!(messages[0].content, "hello");
|
||||||
|
assert_eq!(messages[1].role, "assistant");
|
||||||
|
assert_eq!(messages[1].content, "world");
|
||||||
|
|
||||||
|
store.rename_session(&session.id, "renamed").unwrap();
|
||||||
|
let renamed = store.get_session(&session.id).unwrap().unwrap();
|
||||||
|
assert_eq!(renamed.title, "renamed");
|
||||||
|
|
||||||
|
store.archive_session(&session.id).unwrap();
|
||||||
|
let archived = store.get_session(&session.id).unwrap().unwrap();
|
||||||
|
assert!(archived.archived_at.is_some());
|
||||||
|
|
||||||
|
let active_only = store.list_sessions("cli", false).unwrap();
|
||||||
|
assert!(active_only.is_empty());
|
||||||
|
|
||||||
|
let including_archived = store.list_sessions("cli", true).unwrap();
|
||||||
|
assert_eq!(including_archived.len(), 1);
|
||||||
|
|
||||||
|
store.clear_messages(&session.id).unwrap();
|
||||||
|
let cleared = store.load_messages(&session.id).unwrap();
|
||||||
|
assert!(cleared.is_empty());
|
||||||
|
let cleared_session = store.get_session(&session.id).unwrap().unwrap();
|
||||||
|
assert_eq!(cleared_session.message_count, 0);
|
||||||
|
|
||||||
|
store.delete_session(&session.id).unwrap();
|
||||||
|
assert!(store.get_session(&session.id).unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ensure_channel_session_is_stable() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
|
||||||
|
let first = store.ensure_channel_session("feishu", "chat-1").unwrap();
|
||||||
|
let second = store.ensure_channel_session("feishu", "chat-1").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(first.id, second.id);
|
||||||
|
assert_eq!(first.chat_id, "chat-1");
|
||||||
|
assert_eq!(second.channel_name, "feishu");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_assistant_tool_calls_roundtrip() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
let session = store.create_cli_session(Some("tools")).unwrap();
|
||||||
|
|
||||||
|
let assistant = ChatMessage::assistant_with_tool_calls(
|
||||||
|
"calling tool",
|
||||||
|
vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: serde_json::json!({ "expression": "3*7" }),
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
store.append_message(&session.id, &assistant).unwrap();
|
||||||
|
|
||||||
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(messages.len(), 1);
|
||||||
|
assert_eq!(messages[0].role, "assistant");
|
||||||
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
|
||||||
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||||
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_result_roundtrip() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
let session = store.create_cli_session(Some("tool-result")).unwrap();
|
||||||
|
|
||||||
|
let tool_message = ChatMessage::tool("call_9", "file_write", "saved to /tmp/output.txt");
|
||||||
|
store.append_message(&session.id, &tool_message).unwrap();
|
||||||
|
|
||||||
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(messages.len(), 1);
|
||||||
|
assert_eq!(messages[0].role, "tool");
|
||||||
|
assert_eq!(messages[0].content, "saved to /tmp/output.txt");
|
||||||
|
assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_9"));
|
||||||
|
assert_eq!(messages[0].tool_name.as_deref(), Some("file_write"));
|
||||||
|
assert!(messages[0].tool_calls.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use PicoBot::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message};
|
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
||||||
use PicoBot::config::{Config, LLMProviderConfig};
|
use picobot::config::{Config, LLMProviderConfig};
|
||||||
|
|
||||||
fn load_config() -> Option<LLMProviderConfig> {
|
fn load_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -24,15 +24,13 @@ fn load_config() -> Option<LLMProviderConfig> {
|
|||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
max_tool_iterations: 20,
|
||||||
|
token_limit: 128_000,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_request(content: &str) -> ChatCompletionRequest {
|
fn create_request(content: &str) -> ChatCompletionRequest {
|
||||||
ChatCompletionRequest {
|
ChatCompletionRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message::user(content)],
|
||||||
role: "user".to_string(),
|
|
||||||
content: content.to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
tools: None,
|
tools: None,
|
||||||
@ -64,9 +62,9 @@ async fn test_openai_conversation() {
|
|||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message { role: "user".to_string(), content: "My name is Alice".to_string() },
|
Message::user("My name is Alice"),
|
||||||
Message { role: "assistant".to_string(), content: "Hello Alice!".to_string() },
|
Message::assistant("Hello Alice!"),
|
||||||
Message { role: "user".to_string(), content: "What is my name?".to_string() },
|
Message::user("What is my name?"),
|
||||||
],
|
],
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(50),
|
max_tokens: Some(50),
|
||||||
|
|||||||
@ -1,31 +1,26 @@
|
|||||||
use PicoBot::providers::{ChatCompletionRequest, Message};
|
use picobot::providers::{ChatCompletionRequest, Message};
|
||||||
|
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||||
|
|
||||||
/// Test that message with special characters is properly escaped
|
/// Test that message with special characters is properly escaped
|
||||||
#[test]
|
#[test]
|
||||||
fn test_message_special_characters() {
|
fn test_message_special_characters() {
|
||||||
let msg = Message {
|
let msg = Message::user("Hello \"world\"\nNew line\tTab");
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Hello \"world\"\nNew line\tTab".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&msg).unwrap();
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
let deserialized: Message = serde_json::from_str(&json).unwrap();
|
let deserialized: Message = serde_json::from_str(&json).unwrap();
|
||||||
|
|
||||||
assert_eq!(deserialized.content, "Hello \"world\"\nNew line\tTab");
|
assert_eq!(deserialized.role, "user");
|
||||||
|
assert_eq!(deserialized.content.len(), 1);
|
||||||
|
let encoded = serde_json::to_string(&deserialized.content).unwrap();
|
||||||
|
assert!(encoded.contains("Hello \\\"world\\\"\\nNew line\\tTab"));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test that multi-line system prompt is preserved
|
/// Test that multi-line system prompt is preserved
|
||||||
#[test]
|
#[test]
|
||||||
fn test_multiline_system_prompt() {
|
fn test_multiline_system_prompt() {
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message {
|
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
|
||||||
role: "system".to_string(),
|
Message::user("Hi"),
|
||||||
content: "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Hi".to_string(),
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let json = serde_json::to_string(&messages[0]).unwrap();
|
let json = serde_json::to_string(&messages[0]).unwrap();
|
||||||
@ -39,14 +34,8 @@ fn test_multiline_system_prompt() {
|
|||||||
fn test_chat_request_serialization() {
|
fn test_chat_request_serialization() {
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message::system("You are helpful"),
|
||||||
role: "system".to_string(),
|
Message::user("Hello"),
|
||||||
content: "You are helpful".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Hello".to_string(),
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
@ -58,8 +47,73 @@ fn test_chat_request_serialization() {
|
|||||||
// Verify structure
|
// Verify structure
|
||||||
assert!(json.contains(r#""role":"system""#));
|
assert!(json.contains(r#""role":"system""#));
|
||||||
assert!(json.contains(r#""role":"user""#));
|
assert!(json.contains(r#""role":"user""#));
|
||||||
assert!(json.contains(r#""content":"You are helpful""#));
|
assert!(json.contains("You are helpful"));
|
||||||
assert!(json.contains(r#""content":"Hello""#));
|
assert!(json.contains("Hello"));
|
||||||
assert!(json.contains(r#""temperature":0.7"#));
|
assert!(json.contains(r#""temperature":0.7"#));
|
||||||
assert!(json.contains(r#""max_tokens":100"#));
|
assert!(json.contains(r#""max_tokens":100"#));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_inbound_serialization() {
|
||||||
|
let msg = WsInbound::CreateSession {
|
||||||
|
title: Some("demo".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(json.contains(r#""type":"create_session""#));
|
||||||
|
assert!(json.contains(r#""title":"demo""#));
|
||||||
|
|
||||||
|
let decoded: WsInbound = serde_json::from_str(&json).unwrap();
|
||||||
|
match decoded {
|
||||||
|
WsInbound::CreateSession { title } => {
|
||||||
|
assert_eq!(title.as_deref(), Some("demo"));
|
||||||
|
}
|
||||||
|
other => panic!("unexpected decoded variant: {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_list_outbound_serialization() {
|
||||||
|
let msg = WsOutbound::SessionList {
|
||||||
|
sessions: vec![SessionSummary {
|
||||||
|
session_id: "session-1".to_string(),
|
||||||
|
title: "demo".to_string(),
|
||||||
|
channel_name: "cli".to_string(),
|
||||||
|
chat_id: "session-1".to_string(),
|
||||||
|
message_count: 2,
|
||||||
|
last_active_at: 123,
|
||||||
|
archived_at: None,
|
||||||
|
}],
|
||||||
|
current_session_id: Some("session-1".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(json.contains(r#""type":"session_list""#));
|
||||||
|
assert!(json.contains(r#""session_id":"session-1""#));
|
||||||
|
assert!(json.contains(r#""message_count":2"#));
|
||||||
|
|
||||||
|
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||||
|
match decoded {
|
||||||
|
WsOutbound::SessionList {
|
||||||
|
sessions,
|
||||||
|
current_session_id,
|
||||||
|
} => {
|
||||||
|
assert_eq!(sessions.len(), 1);
|
||||||
|
assert_eq!(sessions[0].title, "demo");
|
||||||
|
assert_eq!(current_session_id.as_deref(), Some("session-1"));
|
||||||
|
}
|
||||||
|
other => panic!("unexpected decoded variant: {:?}", other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_clear_history_with_session_id_serialization() {
|
||||||
|
let msg = WsInbound::ClearHistory {
|
||||||
|
chat_id: None,
|
||||||
|
session_id: Some("session-1".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(json.contains(r#""type":"clear_history""#));
|
||||||
|
assert!(json.contains(r#""session_id":"session-1""#));
|
||||||
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use PicoBot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
||||||
use PicoBot::config::LLMProviderConfig;
|
use picobot::config::LLMProviderConfig;
|
||||||
|
|
||||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -24,6 +24,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
|||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
max_tool_iterations: 20,
|
||||||
|
token_limit: 128_000,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,10 +57,7 @@ async fn test_openai_tool_call() {
|
|||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message::user("What is the weather in Tokyo?")],
|
||||||
role: "user".to_string(),
|
|
||||||
content: "What is the weather in Tokyo?".to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(200),
|
max_tokens: Some(200),
|
||||||
tools: Some(vec![make_weather_tool()]),
|
tools: Some(vec![make_weather_tool()]),
|
||||||
@ -85,10 +83,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
|
|
||||||
// First request with tool
|
// First request with tool
|
||||||
let request1 = ChatCompletionRequest {
|
let request1 = ChatCompletionRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message::user("What is the weather in Tokyo?")],
|
||||||
role: "user".to_string(),
|
|
||||||
content: "What is the weather in Tokyo?".to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(200),
|
max_tokens: Some(200),
|
||||||
tools: Some(vec![make_weather_tool()]),
|
tools: Some(vec![make_weather_tool()]),
|
||||||
@ -102,14 +97,8 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
// Second request with tool result
|
// Second request with tool result
|
||||||
let request2 = ChatCompletionRequest {
|
let request2 = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message::user("What is the weather in Tokyo?"),
|
||||||
role: "user".to_string(),
|
Message::assistant(r#"I'll check the weather for you using the get_weather tool."#),
|
||||||
content: "What is the weather in Tokyo?".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: r#"I'll check the weather for you using the get_weather tool."#.to_string(),
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(200),
|
max_tokens: Some(200),
|
||||||
@ -131,10 +120,7 @@ async fn test_openai_no_tool_when_not_provided() {
|
|||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message::user("Say hello in one word.")],
|
||||||
role: "user".to_string(),
|
|
||||||
content: "Say hello in one word.".to_string(),
|
|
||||||
}],
|
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(10),
|
max_tokens: Some(10),
|
||||||
tools: None,
|
tools: None,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user