Compare commits
4 Commits
b4ef56803f
...
b3fa0bb978
| Author | SHA1 | Date | |
|---|---|---|---|
| b3fa0bb978 | |||
| cbb384a4e6 | |||
| f68e915b04 | |||
| b6f2de053d |
@ -37,3 +37,12 @@ rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
wechatbot = { path = "vendor/wechatbot" }
|
||||
encoding_rs = "0.8"
|
||||
# MCP (Model Context Protocol) support
|
||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
||||
"client",
|
||||
"transport-child-process",
|
||||
"transport-streamable-http-client-reqwest",
|
||||
"reqwest",
|
||||
] }
|
||||
schemars = "1.0"
|
||||
http = "1"
|
||||
|
||||
11
README.md
11
README.md
@ -249,13 +249,9 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
|
||||
用户请求、关键标识符、文件路径、URL、工具调用、命令、结果、错误、决策、偏好和当前任务状态。
|
||||
如果摘要调用失败,会退化为直接截断 transcript,而不会中断主流程。
|
||||
9. 摘要结果会被包装成一条新的 system 消息,并打上 SYSTEM_CONTEXT_HISTORY_COMPACTION 标记,内容前缀为 [Compressed History]。
|
||||
10. 后台提交阶段不会直接修改旧消息,而是向消息表尾部追加一段“新的活动段”:
|
||||
依次写入保留的关键 system 消息、压缩摘要消息、最近保留的消息,以及在压缩快照之后新产生的 delta 消息。
|
||||
11. 提交成功后,sessions.reset_cutoff_seq 会被推进到压缩前的最大 seq。
|
||||
这样旧消息仍然留在数据库里用于审计或全量导出,但默认恢复到运行时上下文时,只会加载新的活动段。
|
||||
12. 为避免并发覆盖,压缩提交前会检查快照是否过期:
|
||||
如果 reset_cutoff_seq 已变化,或者压缩期间又有更新导致快照不再匹配,本次压缩会跳过,不会覆盖较新的上下文。
|
||||
13. 压缩提交成功后,Session 会重新加载当前 chat 的活动历史,后续轮次看到的就是“关键 system 消息 + 压缩摘要 + 最近若干完整 turn”的新上下文。
|
||||
10. 后台提交阶段会删除旧消息,并追加新的活动段:
|
||||
依次写入保留的关键 system 消息、压缩摘要消息、最近保留的消息。
|
||||
11. 压缩提交成功后,Session 会重新加载当前 chat 的活动历史,后续轮次看到的就是"关键 system 消息 + 压缩摘要 + 最近若干完整 turn"的新上下文。
|
||||
|
||||
这套机制的目标不是简单删历史,而是把“远端历史变成可恢复摘要”,同时保证:
|
||||
|
||||
@ -695,7 +691,6 @@ cargo run -- agent
|
||||
CLI 中已实现的交互命令包括:
|
||||
|
||||
- /new [title] - 创建新会话
|
||||
- /reset - 重置当前会话上下文
|
||||
- /sessions - 列出当前通道的所有会话(支持跨通道隔离)
|
||||
- /use <session> - 切换到指定会话
|
||||
- /rename <title> - 重命名当前会话
|
||||
|
||||
@ -90,9 +90,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
| `archived_at` | `INTEGER` | 归档时间 | 非空表示会话已归档 |
|
||||
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
|
||||
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
|
||||
| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 |
|
||||
| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史和 `/reset` 时归零 |
|
||||
| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次“达到配置阈值后的下一轮前注入”就递增,清空历史和 `/reset` 时归零 |
|
||||
| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史时归零 |
|
||||
| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次”达到配置阈值后的下一轮前注入”就递增,清空历史时归零 |
|
||||
|
||||
索引:
|
||||
|
||||
@ -224,12 +223,9 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
|
||||
### 7.3 读取历史
|
||||
|
||||
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是:
|
||||
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。
|
||||
|
||||
- 只返回 `seq > sessions.reset_cutoff_seq` 的消息
|
||||
- 因此 `/reset` 之后,旧消息仍然保留在数据库中,但不会默认回灌到运行时上下文
|
||||
|
||||
如果需要审计、导出或查看完整历史,应使用全量读取接口 `load_all_messages(session_id)`。
|
||||
如果需要审计、导出或查看历史,可使用全量读取接口 `load_all_messages(session_id)`(当前与 load_messages 相同)。
|
||||
|
||||
因此运行态恢复的是“当前活动段的逻辑顺序”,而不是简单按创建时间排序。只要 `seq` 连续,重放顺序就稳定。
|
||||
|
||||
@ -286,30 +282,14 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
|
||||
- 删除该会话在 `messages` 中的所有记录
|
||||
- 将 `sessions.message_count` 重置为 0
|
||||
- 将 `sessions.reset_cutoff_seq` 重置为 0
|
||||
- 将 `sessions.user_turn_count` 重置为 0
|
||||
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
|
||||
- 更新 `updated_at` 和 `last_active_at`
|
||||
- 保留会话本身
|
||||
|
||||
这适合“保留会话入口,但丢弃聊天内容”的场景。
|
||||
这适合”保留会话入口,但丢弃聊天内容”的场景。
|
||||
|
||||
### 8.4 逻辑重置
|
||||
|
||||
`reset_session(session_id)`:
|
||||
|
||||
- 不删除 `messages` 中的任何记录
|
||||
- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq`
|
||||
- 将 `sessions.user_turn_count` 重置为 0
|
||||
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
|
||||
- 更新 `updated_at` 和 `last_active_at`
|
||||
- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息
|
||||
|
||||
这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。
|
||||
|
||||
由于 AGENT.md 注入消息也会持久化,`/reset` 前的 Agent 设定消息仍会保留在完整历史中,但不会继续出现在新的活动段。下一次活动段首次加载时,系统会重新读取当前版本的 `~/.picobot/agent/AGENT.md`,并把它作为新的首条系统消息写入活动段。
|
||||
|
||||
### 8.5 删除会话
|
||||
### 8.4 删除会话
|
||||
|
||||
`delete_session(session_id)`:
|
||||
|
||||
@ -351,11 +331,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
- `sessions.deleted_at`
|
||||
- 当前查询逻辑兼容软删除
|
||||
- 当前删除实现仍然是物理删除
|
||||
- `sessions.reset_cutoff_seq`
|
||||
- 当前已用于实现 `/reset` 的非破坏性逻辑重置
|
||||
- 只影响默认恢复的活动段,不影响数据库中的全量历史
|
||||
|
||||
这说明当前 schema 已经为“会话摘要”和“软删除”预留了演进空间,但并未完全落地。
|
||||
这说明当前 schema 已经为”会话摘要”和”软删除”预留了演进空间,但并未完全落地。
|
||||
|
||||
## 11. 给维护者的快速判断指南
|
||||
|
||||
@ -363,7 +340,6 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
|
||||
- 会话查不到:先看 `persistent_session_id` 是否和实际 `channel_name/chat_id` 一致
|
||||
- 重启后没历史:检查 `ensure_chat_loaded()` 调用链,以及数据库文件路径是否正确
|
||||
- `/reset` 后重启又带回旧上下文:检查 `sessions.reset_cutoff_seq` 是否已写入,以及恢复路径是否走了活动段读取而不是全量读取
|
||||
- 消息顺序不对:检查 `messages.seq`
|
||||
- 工具调用上下文异常:同时检查 `tool_calls_json` 和 `tool_call_id`
|
||||
- 会话列表里看不到记录:检查 `archived_at` 和 `include_archived` 参数
|
||||
|
||||
@ -76,6 +76,7 @@ impl InitWizard {
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
tools: crate::config::ToolsConfig::default(),
|
||||
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
||||
mcp: crate::mcp::McpConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -826,6 +827,7 @@ impl InitWizard {
|
||||
skills: existing.skills.clone(),
|
||||
tools: existing.tools.clone(),
|
||||
memory_maintenance: existing.memory_maintenance.clone(),
|
||||
mcp: existing.mcp.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -740,7 +740,6 @@ mod tests {
|
||||
archived_at: None,
|
||||
deleted_at: None,
|
||||
message_count: 0,
|
||||
reset_cutoff_seq: 0,
|
||||
user_turn_count: 0,
|
||||
agent_prompt_reinjection_count: 0,
|
||||
}
|
||||
|
||||
@ -29,6 +29,8 @@ pub struct Config {
|
||||
pub tools: ToolsConfig,
|
||||
#[serde(default)]
|
||||
pub memory_maintenance: MemoryMaintenanceConfig,
|
||||
#[serde(default)]
|
||||
pub mcp: crate::mcp::McpConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -397,8 +399,6 @@ pub struct GatewayConfig {
|
||||
pub port: u16,
|
||||
#[serde(default)]
|
||||
pub show_tool_results: bool,
|
||||
#[serde(default, rename = "chat_history_ttl_hours")]
|
||||
pub chat_history_ttl_hours: Option<u64>,
|
||||
#[serde(
|
||||
default = "default_agent_prompt_reinject_every",
|
||||
rename = "agent_prompt_reinject_every"
|
||||
@ -714,7 +714,6 @@ impl Default for GatewayConfig {
|
||||
host: default_gateway_host(),
|
||||
port: default_gateway_port(),
|
||||
show_tool_results: false,
|
||||
chat_history_ttl_hours: Some(4),
|
||||
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
||||
max_concurrent_requests: default_max_concurrent_requests(),
|
||||
session_ttl_hours: Some(24),
|
||||
|
||||
@ -30,7 +30,6 @@ pub(crate) async fn schedule_background_history_compaction(
|
||||
(
|
||||
session_guard.store(),
|
||||
session_guard.persistent_session_id(&chat_id),
|
||||
session_record.reset_cutoff_seq,
|
||||
session_record.message_count,
|
||||
history,
|
||||
compressor,
|
||||
@ -41,7 +40,6 @@ pub(crate) async fn schedule_background_history_compaction(
|
||||
let (
|
||||
store,
|
||||
session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
history,
|
||||
compressor,
|
||||
@ -61,7 +59,6 @@ pub(crate) async fn schedule_background_history_compaction(
|
||||
match compaction_result {
|
||||
Ok(Some(plan)) => match store.compact_active_history(
|
||||
&session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
&plan.preserved_system_messages,
|
||||
&plan.summary_message,
|
||||
|
||||
@ -71,7 +71,21 @@
|
||||
- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。
|
||||
|
||||
## PICO配置
|
||||
- Skill安装在[basedir]/skills
|
||||
|
||||
### 技能系统
|
||||
|
||||
- **技能存储路径**:
|
||||
- 项目级: `{project-root}/.picobot/skills/{skill-name}/SKILL.md`
|
||||
- 用户级: `~/.picobot/skills/{skill-name}/SKILL.md`
|
||||
|
||||
- **创建/修改技能**:
|
||||
- 必须使用 `skill_manage` 工具的 `create` 或 `update` action
|
||||
- 不要使用 `write` 工具直接写入技能文件
|
||||
- `skill_manage` 会自动创建正确的目录结构
|
||||
|
||||
- **使用技能**:
|
||||
- Skill 不是工具名,不能直接调用
|
||||
- 必须先调用 `skill_activate` 工具激活技能,再按指令执行
|
||||
|
||||
## 补充要求
|
||||
|
||||
|
||||
@ -64,7 +64,6 @@ pub(crate) struct ScheduledExecutionRequest<'a> {
|
||||
pub(crate) prompt: &'a str,
|
||||
pub(crate) sender_id: &'a str,
|
||||
pub(crate) provider_config: LLMProviderConfig,
|
||||
pub(crate) fresh_session: bool,
|
||||
pub(crate) system_prompt: Option<&'a str>,
|
||||
pub(crate) metadata: &'a HashMap<String, String>,
|
||||
}
|
||||
@ -272,10 +271,6 @@ impl AgentExecutionService {
|
||||
|
||||
session_guard.ensure_persistent_session(request.chat_id)?;
|
||||
|
||||
if request.fresh_session {
|
||||
session_guard.reset_chat_context(request.chat_id)?;
|
||||
}
|
||||
|
||||
session_guard.ensure_chat_loaded(request.chat_id)?;
|
||||
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
|
||||
|
||||
|
||||
@ -63,8 +63,6 @@ impl GatewayState {
|
||||
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
||||
}
|
||||
|
||||
// Chat history TTL from config (default 4 hours)
|
||||
let chat_history_ttl_hours = config.gateway.chat_history_ttl_hours;
|
||||
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
||||
let show_tool_results = config.gateway.show_tool_results;
|
||||
|
||||
@ -85,8 +83,8 @@ impl GatewayState {
|
||||
std::collections::HashSet::new(),
|
||||
config.tools.task.clone(),
|
||||
config.memory_maintenance.clone(),
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
config.mcp.clone(),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
|
||||
@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||
use crate::mcp::{McpClientManager, McpConfig};
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{
|
||||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||||
@ -35,8 +36,8 @@ pub(crate) fn build_session_manager(
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
mcp_config: McpConfig,
|
||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||
build_session_manager_with_sender(
|
||||
agent_prompt_reinject_every,
|
||||
@ -49,8 +50,8 @@ pub(crate) fn build_session_manager(
|
||||
disabled_tools,
|
||||
task_config,
|
||||
maintenance_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
mcp_config,
|
||||
)
|
||||
}
|
||||
|
||||
@ -65,8 +66,8 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
mcp_config: McpConfig,
|
||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||
let store = Arc::new(
|
||||
SessionStore::new()
|
||||
@ -103,6 +104,36 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
task_config.clone(),
|
||||
);
|
||||
|
||||
// 创建 MCP Client Manager(如果启用)
|
||||
let mcp_manager = if mcp_config.has_enabled_servers() {
|
||||
let manager = Arc::new(McpClientManager::new());
|
||||
|
||||
// 在 tokio runtime 中连接 MCP servers
|
||||
// 使用 block_in_place 允许在同步上下文中执行异步代码
|
||||
let servers = mcp_config.enabled_servers();
|
||||
let servers_clone: Vec<_> = servers.into_iter().cloned().collect();
|
||||
|
||||
tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(async {
|
||||
tracing::info!("Connecting to MCP servers...");
|
||||
if let Err(e) = manager.connect_all(&servers_clone).await {
|
||||
tracing::error!(error = %e, "Failed to connect to some MCP servers");
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
Some(manager)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// 将 MCP manager 添加到 factory
|
||||
let factory = if let Some(ref manager) = mcp_manager {
|
||||
factory.with_mcp_manager(manager.clone())
|
||||
} else {
|
||||
factory
|
||||
};
|
||||
|
||||
// 创建 SubAgentRuntime(如果 task 工具启用)
|
||||
let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
|
||||
let task_repository = Arc::new(InMemoryTaskRepository::new());
|
||||
@ -131,7 +162,20 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
(factory, Arc::new(InMemoryTaskRepository::new()))
|
||||
};
|
||||
|
||||
let tools = Arc::new(factory.build());
|
||||
let mut tools = factory.build();
|
||||
|
||||
// 注册 MCP tools(如果有 MCP manager)
|
||||
if let Some(manager) = &mcp_manager {
|
||||
tokio::task::block_in_place(|| {
|
||||
tokio::runtime::Handle::current().block_on(async {
|
||||
if let Err(e) = crate::mcp::register_mcp_tools(manager.clone(), &mut tools).await {
|
||||
tracing::error!(error = %e, "Failed to register MCP tools");
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
let tools = Arc::new(tools);
|
||||
|
||||
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
||||
let agent_factory = AgentFactory::new(
|
||||
@ -147,7 +191,6 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
conversations,
|
||||
skill_events,
|
||||
store.clone(),
|
||||
chat_history_ttl_hours,
|
||||
);
|
||||
let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);
|
||||
let cli_sessions = CliSessionService::new(store.clone());
|
||||
|
||||
@ -50,7 +50,6 @@ impl ScheduledAgentTaskService {
|
||||
prompt,
|
||||
sender_id: &sender_id,
|
||||
provider_config,
|
||||
fresh_session: options.fresh_session,
|
||||
system_prompt: options.system_prompt.as_deref(),
|
||||
metadata: &options.metadata,
|
||||
})
|
||||
|
||||
@ -103,7 +103,6 @@ impl Session {
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
agent_prompt_reinject_every: u64,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||||
@ -122,7 +121,6 @@ impl Session {
|
||||
agent_factory,
|
||||
conversations,
|
||||
skill_events,
|
||||
chat_history_ttl_hours,
|
||||
store,
|
||||
)
|
||||
.await
|
||||
@ -136,7 +134,6 @@ impl Session {
|
||||
agent_factory: AgentFactory,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
@ -151,7 +148,6 @@ impl Session {
|
||||
channel_name,
|
||||
conversations,
|
||||
skill_events,
|
||||
chat_history_ttl_hours,
|
||||
),
|
||||
store,
|
||||
})
|
||||
@ -267,10 +263,6 @@ impl Session {
|
||||
self.history.clear_chat_history(chat_id)
|
||||
}
|
||||
|
||||
pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
self.history.reset_chat_context(chat_id)
|
||||
}
|
||||
|
||||
/// 将消息写入内存与持久化层(使用当前 topic)
|
||||
pub fn append_persisted_message(
|
||||
&mut self,
|
||||
@ -502,8 +494,8 @@ impl SessionManager {
|
||||
disabled_tools: std::collections::HashSet<String>,
|
||||
task_config: crate::config::TaskConfig,
|
||||
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
mcp_config: crate::mcp::McpConfig,
|
||||
) -> Result<Self, AgentError> {
|
||||
super::runtime::build_session_manager(
|
||||
agent_prompt_reinject_every,
|
||||
@ -515,8 +507,8 @@ impl SessionManager {
|
||||
disabled_tools,
|
||||
task_config,
|
||||
maintenance_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
mcp_config,
|
||||
)
|
||||
.map(|(session_manager, _)| session_manager)
|
||||
}
|
||||
@ -731,7 +723,6 @@ mod tests {
|
||||
skills,
|
||||
store,
|
||||
100,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@ -779,7 +770,6 @@ mod tests {
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@ -813,7 +803,6 @@ mod tests {
|
||||
store
|
||||
.compact_active_history(
|
||||
&session_id,
|
||||
0,
|
||||
snapshot_end_seq,
|
||||
&[],
|
||||
&ChatMessage::system("[Compressed History]\n\nsummary"),
|
||||
@ -976,7 +965,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1029,7 +1017,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1041,7 +1028,6 @@ mod tests {
|
||||
"请规划今天工作",
|
||||
ScheduledAgentTaskOptions {
|
||||
agent: Some("planner".to_string()),
|
||||
fresh_session: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
@ -1057,7 +1043,6 @@ mod tests {
|
||||
"请规划今天工作",
|
||||
ScheduledAgentTaskOptions {
|
||||
agent: Some("default".to_string()),
|
||||
fresh_session: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
@ -1098,7 +1083,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1109,7 +1093,6 @@ mod tests {
|
||||
"chat-guard",
|
||||
"每小时执行以下流程:检查邮箱并同步待办",
|
||||
ScheduledAgentTaskOptions {
|
||||
fresh_session: true,
|
||||
system_prompt: Some("你是邮箱待办同步助手。".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
@ -1184,7 +1167,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1271,7 +1253,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1357,7 +1338,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1425,7 +1405,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1502,7 +1481,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1566,7 +1544,6 @@ mod tests {
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
.unwrap();
|
||||
@ -1773,7 +1750,6 @@ mod tests {
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@ -1813,7 +1789,6 @@ mod tests {
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@ -1886,7 +1861,6 @@ mod tests {
|
||||
skills,
|
||||
store.clone(),
|
||||
0,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@ -19,7 +19,6 @@ pub(crate) struct SessionFactory {
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
store: Arc<SessionStore>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
}
|
||||
|
||||
impl SessionFactory {
|
||||
@ -30,7 +29,6 @@ impl SessionFactory {
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
store: Arc<SessionStore>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_config,
|
||||
@ -39,7 +37,6 @@ impl SessionFactory {
|
||||
conversations,
|
||||
skill_events,
|
||||
store,
|
||||
chat_history_ttl_hours,
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,7 +53,6 @@ impl SessionFactory {
|
||||
self.agent_factory.clone(),
|
||||
self.conversations.clone(),
|
||||
self.skill_events.clone(),
|
||||
self.chat_history_ttl_hours,
|
||||
self.store.clone(),
|
||||
)
|
||||
.await
|
||||
|
||||
@ -15,13 +15,6 @@ fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
preview.replace('\n', "\\n")
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs() as i64
|
||||
}
|
||||
|
||||
pub(crate) struct SessionHistory {
|
||||
channel_name: String,
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
@ -30,7 +23,6 @@ pub(crate) struct SessionHistory {
|
||||
compression_in_flight: HashSet<String>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
}
|
||||
|
||||
impl SessionHistory {
|
||||
@ -38,7 +30,6 @@ impl SessionHistory {
|
||||
channel_name: impl Into<String>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel_name: channel_name.into(),
|
||||
@ -48,7 +39,6 @@ impl SessionHistory {
|
||||
compression_in_flight: HashSet::new(),
|
||||
conversations,
|
||||
skill_events,
|
||||
chat_history_ttl_hours,
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,29 +60,6 @@ impl SessionHistory {
|
||||
chat_id: &str,
|
||||
topic_id: Option<&str>,
|
||||
) -> Result<(), AgentError> {
|
||||
// 获取 session 记录(用于检查最后活跃时间)
|
||||
let session_record = self.ensure_persistent_session(chat_id)?;
|
||||
|
||||
// 检查是否超时
|
||||
if let Some(ttl_hours) = self.chat_history_ttl_hours {
|
||||
if ttl_hours > 0 {
|
||||
let now = current_timestamp();
|
||||
let elapsed_hours = (now - session_record.last_active_at) / 3600;
|
||||
if elapsed_hours >= ttl_hours as i64 {
|
||||
tracing::info!(
|
||||
channel = %self.channel_name,
|
||||
chat_id = %chat_id,
|
||||
elapsed_hours = elapsed_hours,
|
||||
ttl_hours = ttl_hours,
|
||||
"Chat history expired, resetting context"
|
||||
);
|
||||
// 重置会话上下文(清空内存历史,但保留数据库记录)
|
||||
self.reset_chat_context(chat_id)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 原有逻辑
|
||||
if self.chat_histories.contains_key(chat_id) {
|
||||
return Ok(());
|
||||
}
|
||||
@ -178,19 +145,6 @@ impl SessionHistory {
|
||||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||||
}
|
||||
|
||||
pub(crate) fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
|
||||
}
|
||||
|
||||
self.conversations
|
||||
.reset_session(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
|
||||
}
|
||||
|
||||
pub(crate) fn append_persisted_message(
|
||||
&mut self,
|
||||
chat_id: &str,
|
||||
|
||||
@ -2,6 +2,7 @@ use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::TaskConfig;
|
||||
use crate::mcp::McpClientManager;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||
use crate::tools::{
|
||||
@ -23,6 +24,7 @@ pub(crate) struct ToolRegistryFactory {
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||||
mcp_manager: Option<Arc<McpClientManager>>,
|
||||
}
|
||||
|
||||
impl ToolRegistryFactory {
|
||||
@ -48,6 +50,7 @@ impl ToolRegistryFactory {
|
||||
disabled_tools,
|
||||
task_config,
|
||||
subagent_runtime: None,
|
||||
mcp_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,6 +62,14 @@ impl ToolRegistryFactory {
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_mcp_manager(
|
||||
mut self,
|
||||
manager: Arc<McpClientManager>,
|
||||
) -> Self {
|
||||
self.mcp_manager = Some(manager);
|
||||
self
|
||||
}
|
||||
|
||||
fn is_enabled(&self, tool_name: &str) -> bool {
|
||||
!self.disabled_tools.contains(tool_name)
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ pub mod config;
|
||||
pub mod domain;
|
||||
pub mod gateway;
|
||||
pub mod logging;
|
||||
pub mod mcp;
|
||||
pub mod observability;
|
||||
pub mod platform;
|
||||
pub mod protocol;
|
||||
|
||||
262
src/mcp/client.rs
Normal file
262
src/mcp/client.rs
Normal file
@ -0,0 +1,262 @@
|
||||
//! MCP Client Manager - manages connections to MCP servers
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use rmcp::{
|
||||
model::{CallToolRequestParams, CallToolResult, ServerInfo, Tool},
|
||||
RoleClient, ServiceExt,
|
||||
service::RunningService,
|
||||
transport::TokioChildProcess,
|
||||
transport::streamable_http_client::{StreamableHttpClientTransport, StreamableHttpClientTransportConfig},
|
||||
};
|
||||
use http::{HeaderName, HeaderValue};
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::mcp::config::{McpServerConfig, McpTransportConfig};
|
||||
|
||||
/// Type alias for the MCP client service
|
||||
pub type McpClient = RunningService<RoleClient, ()>;
|
||||
|
||||
/// Information about a connected MCP server
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct McpServerInfo {
|
||||
/// Server name
|
||||
pub name: String,
|
||||
/// Server information from MCP protocol
|
||||
pub info: Option<ServerInfo>,
|
||||
/// Available tools
|
||||
pub tools: Vec<Tool>,
|
||||
}
|
||||
|
||||
/// Manager for MCP client connections
|
||||
pub struct McpClientManager {
|
||||
/// Connected clients keyed by server name
|
||||
clients: RwLock<HashMap<String, Arc<McpClient>>>,
|
||||
/// Server information cache
|
||||
server_info: RwLock<HashMap<String, McpServerInfo>>,
|
||||
}
|
||||
|
||||
impl McpClientManager {
|
||||
/// Create a new manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
server_info: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to all configured servers
|
||||
pub async fn connect_all(&self, servers: &[McpServerConfig]) -> anyhow::Result<()> {
|
||||
for server in servers {
|
||||
if !server.enabled {
|
||||
tracing::info!(name = %server.name, "Skipping disabled MCP server");
|
||||
continue;
|
||||
}
|
||||
|
||||
match self.connect_server(server).await {
|
||||
Ok(info) => {
|
||||
tracing::info!(
|
||||
name = %server.name,
|
||||
tools_count = info.tools.len(),
|
||||
"Connected to MCP server"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
name = %server.name,
|
||||
error = %e,
|
||||
"Failed to connect to MCP server"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect to a single MCP server
|
||||
pub async fn connect_server(&self, config: &McpServerConfig) -> anyhow::Result<McpServerInfo> {
|
||||
tracing::info!(name = %config.name, "Connecting to MCP server");
|
||||
|
||||
let client = match &config.transport {
|
||||
McpTransportConfig::Stdio { command, args, env } => {
|
||||
self.connect_stdio(command, args, env).await?
|
||||
}
|
||||
McpTransportConfig::Http { url, headers } => {
|
||||
self.connect_http(url, headers).await?
|
||||
}
|
||||
};
|
||||
|
||||
// Get server info (returns Option<&ServerInfo>)
|
||||
let info = client.peer_info().cloned();
|
||||
|
||||
// List available tools
|
||||
let tools = client.list_all_tools().await?;
|
||||
|
||||
let server_info = McpServerInfo {
|
||||
name: config.name.clone(),
|
||||
info,
|
||||
tools,
|
||||
};
|
||||
|
||||
// Store the client and info
|
||||
{
|
||||
let mut clients = self.clients.write().await;
|
||||
clients.insert(config.name.clone(), Arc::new(client));
|
||||
}
|
||||
{
|
||||
let mut info_map = self.server_info.write().await;
|
||||
info_map.insert(config.name.clone(), server_info.clone());
|
||||
}
|
||||
|
||||
Ok(server_info)
|
||||
}
|
||||
|
||||
/// Connect via stdio transport (spawn child process)
|
||||
async fn connect_stdio(
|
||||
&self,
|
||||
command: &str,
|
||||
args: &[String],
|
||||
env: &HashMap<String, String>,
|
||||
) -> anyhow::Result<McpClient> {
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(args);
|
||||
|
||||
// Set environment variables
|
||||
for (key, value) in env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
let transport = TokioChildProcess::new(cmd)?;
|
||||
|
||||
// Use default client handler (empty tuple)
|
||||
let client = ().serve(transport).await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Connect via HTTP transport (Streamable HTTP)
|
||||
async fn connect_http(
|
||||
&self,
|
||||
url: &str,
|
||||
headers: &HashMap<String, String>,
|
||||
) -> anyhow::Result<McpClient> {
|
||||
// Build custom headers
|
||||
let custom_headers: HashMap<HeaderName, HeaderValue> = headers
|
||||
.iter()
|
||||
.filter_map(|(key, value)| {
|
||||
// Try to parse header name and value
|
||||
HeaderName::try_from(key.clone())
|
||||
.ok()
|
||||
.and_then(|name| {
|
||||
HeaderValue::try_from(value.clone())
|
||||
.ok()
|
||||
.map(|val| (name, val))
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create transport config with custom headers
|
||||
let config = StreamableHttpClientTransportConfig::with_uri(url)
|
||||
.custom_headers(custom_headers);
|
||||
|
||||
// Create transport using reqwest client (default)
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
reqwest::Client::default(),
|
||||
config,
|
||||
);
|
||||
|
||||
// Connect
|
||||
let client = ().serve(transport).await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Get a client by server name
|
||||
pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
|
||||
let clients = self.clients.read().await;
|
||||
clients.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get server info by name
|
||||
pub async fn get_server_info(&self, name: &str) -> Option<McpServerInfo> {
|
||||
let info_map = self.server_info.read().await;
|
||||
info_map.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get all connected server names
|
||||
pub async fn connected_servers(&self) -> Vec<String> {
|
||||
let clients = self.clients.read().await;
|
||||
clients.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get all tools from all connected servers
|
||||
pub async fn all_tools(&self) -> Vec<(String, Tool)> {
|
||||
let info_map = self.server_info.read().await;
|
||||
info_map
|
||||
.values()
|
||||
.flat_map(|info| {
|
||||
info.tools.iter().map(|tool| (info.name.clone(), tool.clone()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Call a tool on a specific server
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
server_name: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
args: serde_json::Value,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
let server_name = server_name.into();
|
||||
let tool_name = tool_name.into();
|
||||
|
||||
let client = self
|
||||
.get_client(&server_name)
|
||||
.await
|
||||
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' not connected", server_name))?;
|
||||
|
||||
// Convert Value to JsonObject if it's an object
|
||||
let arguments = if args.is_object() {
|
||||
args.as_object().unwrap().clone()
|
||||
} else {
|
||||
// If not an object, wrap it or use empty object
|
||||
serde_json::Map::new()
|
||||
};
|
||||
|
||||
// Create params with owned String (converted to Cow<'static, str>)
|
||||
let params = CallToolRequestParams::new(tool_name).with_arguments(arguments);
|
||||
|
||||
let result = client.call_tool(params).await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Disconnect from a server
|
||||
pub async fn disconnect(&self, name: impl Into<String>) -> anyhow::Result<()> {
|
||||
let name = name.into();
|
||||
let mut clients = self.clients.write().await;
|
||||
if clients.remove(&name).is_some() {
|
||||
tracing::info!(name = %name, "Disconnected MCP server");
|
||||
}
|
||||
self.server_info.write().await.remove(&name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect from all servers
|
||||
pub async fn disconnect_all(&self) -> anyhow::Result<()> {
|
||||
let mut clients = self.clients.write().await;
|
||||
for (name, _client) in clients.drain() {
|
||||
tracing::info!(name = %name, "Disconnected MCP server");
|
||||
}
|
||||
self.server_info.write().await.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for McpClientManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
173
src/mcp/config.rs
Normal file
173
src/mcp/config.rs
Normal file
@ -0,0 +1,173 @@
|
||||
//! MCP Server configuration structures
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// MCP integration configuration
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct McpConfig {
|
||||
/// Whether MCP integration is enabled
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// List of MCP servers to connect
|
||||
#[serde(default)]
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
/// Configuration for a single MCP server
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpServerConfig {
|
||||
/// Unique name for this server (used in tool naming)
|
||||
pub name: String,
|
||||
|
||||
/// Transport configuration
|
||||
pub transport: McpTransportConfig,
|
||||
|
||||
/// Whether this server is enabled
|
||||
#[serde(default = "default_server_enabled")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Optional description for the server
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
fn default_server_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Transport configuration for connecting to MCP servers
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum McpTransportConfig {
|
||||
/// Stdio transport: spawn a child process
|
||||
Stdio {
|
||||
/// Command to execute (e.g., "npx", "cargo")
|
||||
command: String,
|
||||
/// Arguments to pass to the command
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
/// Optional environment variables to set
|
||||
#[serde(default)]
|
||||
env: HashMap<String, String>,
|
||||
},
|
||||
|
||||
/// HTTP transport: connect to a remote server
|
||||
Http {
|
||||
/// URL of the MCP server endpoint
|
||||
url: String,
|
||||
/// Optional headers to include in requests
|
||||
#[serde(default)]
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpServerConfig {
|
||||
/// Create a stdio server config
|
||||
pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
transport: McpTransportConfig::Stdio {
|
||||
command: command.into(),
|
||||
args,
|
||||
env: HashMap::new(),
|
||||
},
|
||||
enabled: true,
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an HTTP server config
|
||||
pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
transport: McpTransportConfig::Http {
|
||||
url: url.into(),
|
||||
headers: HashMap::new(),
|
||||
},
|
||||
enabled: true,
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl McpConfig {
|
||||
/// Get enabled servers
|
||||
pub fn enabled_servers(&self) -> Vec<&McpServerConfig> {
|
||||
self.servers.iter().filter(|s| s.enabled).collect()
|
||||
}
|
||||
|
||||
/// Check if there are any enabled servers
|
||||
pub fn has_enabled_servers(&self) -> bool {
|
||||
self.enabled && self.servers.iter().any(|s| s.enabled)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stdio_config_creation() {
|
||||
let config = McpServerConfig::stdio(
|
||||
"filesystem",
|
||||
"npx",
|
||||
vec!["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
|
||||
);
|
||||
|
||||
assert_eq!(config.name, "filesystem");
|
||||
assert!(config.enabled);
|
||||
assert!(matches!(config.transport, McpTransportConfig::Stdio { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_config_creation() {
|
||||
let config = McpServerConfig::http("custom", "http://localhost:8000/mcp");
|
||||
|
||||
assert_eq!(config.name, "custom");
|
||||
assert!(config.enabled);
|
||||
assert!(matches!(config.transport, McpTransportConfig::Http { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_deserialization() {
|
||||
let json = r#"{
|
||||
"enabled": true,
|
||||
"servers": [
|
||||
{
|
||||
"name": "filesystem",
|
||||
"transport": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "http-server",
|
||||
"enabled": false,
|
||||
"transport": {
|
||||
"type": "http",
|
||||
"url": "http://localhost:8000/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer token"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let config: McpConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(config.enabled);
|
||||
assert_eq!(config.servers.len(), 2);
|
||||
assert_eq!(config.enabled_servers().len(), 1);
|
||||
|
||||
let fs_server = &config.servers[0];
|
||||
assert_eq!(fs_server.name, "filesystem");
|
||||
assert!(fs_server.enabled);
|
||||
|
||||
let http_server = &config.servers[1];
|
||||
assert_eq!(http_server.name, "http-server");
|
||||
assert!(!http_server.enabled);
|
||||
}
|
||||
}
|
||||
12
src/mcp/mod.rs
Normal file
12
src/mcp/mod.rs
Normal file
@ -0,0 +1,12 @@
|
||||
//! MCP (Model Context Protocol) integration module
|
||||
//!
|
||||
//! This module provides MCP client functionality to connect to external MCP servers
|
||||
//! and expose their tools through PicoBot's Tool system.
|
||||
|
||||
pub mod config;
|
||||
pub mod client;
|
||||
pub mod tool_adapter;
|
||||
|
||||
pub use config::{McpConfig, McpServerConfig, McpTransportConfig};
|
||||
pub use client::{McpClientManager, McpClient, McpServerInfo};
|
||||
pub use tool_adapter::{McpToolWrapper, register_mcp_tools};
|
||||
186
src/mcp/tool_adapter.rs
Normal file
186
src/mcp/tool_adapter.rs
Normal file
@ -0,0 +1,186 @@
|
||||
//! MCP Tool Adapter - wraps MCP tools as PicoBot tools
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::model::Tool;
|
||||
|
||||
use crate::mcp::client::McpClientManager;
|
||||
use crate::tools::traits::{Tool as PicoBotTool, ToolResult};
|
||||
|
||||
/// Wrapper that adapts an MCP tool to PicoBot's Tool trait
|
||||
pub struct McpToolWrapper {
|
||||
/// The MCP client manager
|
||||
manager: Arc<McpClientManager>,
|
||||
/// The server name this tool belongs to
|
||||
server_name: String,
|
||||
/// The original tool name on the MCP server
|
||||
tool_name: String,
|
||||
/// The full tool name with namespace (mcp_{server}_{tool})
|
||||
full_name: String,
|
||||
/// Tool information from MCP server
|
||||
tool_info: Tool,
|
||||
}
|
||||
|
||||
impl McpToolWrapper {
|
||||
/// Create a new tool wrapper
|
||||
pub fn new(
|
||||
manager: Arc<McpClientManager>,
|
||||
server_name: String,
|
||||
tool_info: Tool,
|
||||
) -> Self {
|
||||
let tool_name = tool_info.name.clone().into_owned();
|
||||
let full_name = format!("mcp_{}_{}", server_name, tool_name);
|
||||
Self {
|
||||
manager,
|
||||
server_name,
|
||||
tool_name,
|
||||
full_name,
|
||||
tool_info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the server name
|
||||
pub fn server_name(&self) -> &str {
|
||||
&self.server_name
|
||||
}
|
||||
|
||||
/// Get the original tool name
|
||||
pub fn original_name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PicoBotTool for McpToolWrapper {
|
||||
fn name(&self) -> &str {
|
||||
&self.full_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
self.tool_info.description.as_deref().unwrap_or("MCP tool")
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
// Convert Arc<JsonObject> to serde_json::Value
|
||||
let schema = (*self.tool_info.input_schema).clone();
|
||||
serde_json::Value::Object(schema)
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
tracing::debug!(
|
||||
server = %self.server_name,
|
||||
tool = %self.tool_name,
|
||||
"Calling MCP tool"
|
||||
);
|
||||
|
||||
let result = self
|
||||
.manager
|
||||
.call_tool(&self.server_name, &self.tool_name, args)
|
||||
.await?;
|
||||
|
||||
// Convert MCP CallToolResult to PicoBot ToolResult
|
||||
let output = extract_text_content(&result);
|
||||
let is_error = result.is_error.unwrap_or(false);
|
||||
|
||||
Ok(ToolResult {
|
||||
success: !is_error,
|
||||
output,
|
||||
error: if is_error {
|
||||
Some("MCP tool returned error".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
// MCP tools may or may not be read-only; we default to false
|
||||
// This could be enhanced if MCP servers provide this info via annotations
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract text content from MCP CallToolResult
|
||||
fn extract_text_content(result: &rmcp::model::CallToolResult) -> String {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for content in &result.content {
|
||||
if let Some(text) = content.as_text() {
|
||||
text_parts.push(text.text.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if text_parts.is_empty() {
|
||||
// No text content found, try to serialize the whole result
|
||||
serde_json::to_string_pretty(&result).unwrap_or_else(|_| "Empty result".to_string())
|
||||
} else {
|
||||
text_parts.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
/// Register all MCP tools from connected servers into a tool registry
|
||||
pub async fn register_mcp_tools(
|
||||
manager: Arc<McpClientManager>,
|
||||
registry: &mut crate::tools::registry::ToolRegistry,
|
||||
) -> anyhow::Result<()> {
|
||||
let all_tools = manager.all_tools().await;
|
||||
|
||||
for (server_name, tool_info) in all_tools {
|
||||
let wrapper = McpToolWrapper::new(
|
||||
manager.clone(),
|
||||
server_name.clone(),
|
||||
tool_info,
|
||||
);
|
||||
|
||||
tracing::info!(
|
||||
name = %wrapper.name(),
|
||||
server = %server_name,
|
||||
"Registering MCP tool"
|
||||
);
|
||||
|
||||
registry.register(wrapper);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rmcp::model::{CallToolResult, Content};
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_content_from_text() {
|
||||
let result = CallToolResult::success(vec![
|
||||
Content::text("Hello"),
|
||||
Content::text("World"),
|
||||
]);
|
||||
|
||||
let text = extract_text_content(&result);
|
||||
assert_eq!(text, "Hello\nWorld");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_content_empty() {
|
||||
let result = CallToolResult::success(vec![]);
|
||||
let text = extract_text_content(&result);
|
||||
assert!(text.contains("Empty result"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_wrapper_name() {
|
||||
let manager = Arc::new(McpClientManager::new());
|
||||
let tool_info = Tool {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool".into()),
|
||||
input_schema: serde_json::json!({"type": "object"}).as_object().unwrap().clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let wrapper = McpToolWrapper::new(manager, "filesystem".to_string(), tool_info);
|
||||
assert_eq!(wrapper.name(), "mcp_filesystem_echo");
|
||||
assert_eq!(wrapper.original_name(), "echo");
|
||||
assert_eq!(wrapper.server_name(), "filesystem");
|
||||
}
|
||||
}
|
||||
@ -20,7 +20,6 @@ use crate::storage::{
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ScheduledAgentTaskOptions {
|
||||
pub sender_id: Option<String>,
|
||||
pub fresh_session: bool,
|
||||
pub system_prompt: Option<String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
pub agent: Option<String>,
|
||||
@ -1019,11 +1018,6 @@ fn parse_scheduled_agent_task_options(
|
||||
.get("sender_id")
|
||||
.and_then(|value| value.as_str())
|
||||
.map(ToString::to_string);
|
||||
let fresh_session = job
|
||||
.payload
|
||||
.get("fresh_session")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(false);
|
||||
let system_prompt = job
|
||||
.payload
|
||||
.get("system_prompt")
|
||||
@ -1038,7 +1032,6 @@ fn parse_scheduled_agent_task_options(
|
||||
|
||||
Ok(ScheduledAgentTaskOptions {
|
||||
sender_id,
|
||||
fresh_session,
|
||||
system_prompt,
|
||||
metadata,
|
||||
agent,
|
||||
@ -1219,7 +1212,6 @@ mod agent_task_tests {
|
||||
let options = parse_scheduled_agent_task_options(&job).unwrap();
|
||||
assert_eq!(options.agent.as_deref(), Some("planner"));
|
||||
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
||||
assert!(options.fresh_session);
|
||||
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
||||
assert_eq!(
|
||||
options.metadata.get("job_type").map(String::as_str),
|
||||
|
||||
@ -66,7 +66,6 @@ impl SessionStore {
|
||||
archived_at INTEGER,
|
||||
deleted_at INTEGER,
|
||||
message_count INTEGER NOT NULL DEFAULT 0,
|
||||
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0,
|
||||
user_turn_count INTEGER NOT NULL DEFAULT 0,
|
||||
agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
@ -248,8 +247,8 @@ impl SessionStore {
|
||||
INSERT INTO sessions (
|
||||
id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||||
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
|
||||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
|
||||
user_turn_count, agent_prompt_reinjection_count
|
||||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0)
|
||||
",
|
||||
params![&session_id, title, channel_name, id, now],
|
||||
)?;
|
||||
@ -291,8 +290,8 @@ impl SessionStore {
|
||||
INSERT INTO sessions (
|
||||
id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||||
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
|
||||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
|
||||
user_turn_count, agent_prompt_reinjection_count
|
||||
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0)
|
||||
",
|
||||
params![session_id, title, channel_name, chat_id, now],
|
||||
)?;
|
||||
@ -308,7 +307,7 @@ impl SessionStore {
|
||||
"
|
||||
SELECT id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at,
|
||||
archived_at, deleted_at, message_count, reset_cutoff_seq,
|
||||
archived_at, deleted_at, message_count,
|
||||
user_turn_count, agent_prompt_reinjection_count
|
||||
FROM sessions
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
@ -330,7 +329,7 @@ impl SessionStore {
|
||||
"
|
||||
SELECT id, title, channel_name, chat_id, summary,
|
||||
created_at, updated_at, last_active_at,
|
||||
archived_at, deleted_at, message_count, reset_cutoff_seq,
|
||||
archived_at, deleted_at, message_count,
|
||||
user_turn_count, agent_prompt_reinjection_count
|
||||
FROM sessions
|
||||
WHERE channel_name = ?1
|
||||
@ -493,7 +492,6 @@ impl SessionStore {
|
||||
SET message_count = 0,
|
||||
updated_at = ?2,
|
||||
last_active_at = ?2,
|
||||
reset_cutoff_seq = 0,
|
||||
user_turn_count = 0,
|
||||
agent_prompt_reinjection_count = 0
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
@ -503,35 +501,6 @@ impl SessionStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
let cutoff_seq: i64 = tx.query_row(
|
||||
"SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
|row| row.get(0),
|
||||
)?;
|
||||
|
||||
tx.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
SET reset_cutoff_seq = ?2,
|
||||
updated_at = ?3,
|
||||
last_active_at = ?3,
|
||||
archived_at = NULL,
|
||||
user_turn_count = 0,
|
||||
agent_prompt_reinjection_count = 0
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
params![session_id, cutoff_seq, now],
|
||||
)?;
|
||||
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append_message(
|
||||
&self,
|
||||
session_id: &str,
|
||||
@ -607,7 +576,6 @@ impl SessionStore {
|
||||
pub fn compact_active_history(
|
||||
&self,
|
||||
session_id: &str,
|
||||
expected_reset_cutoff_seq: i64,
|
||||
snapshot_end_seq: i64,
|
||||
preserved_system_messages: &[ChatMessage],
|
||||
summary_message: &ChatMessage,
|
||||
@ -616,18 +584,13 @@ impl SessionStore {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
let current_cutoff = active_reset_cutoff(&tx, session_id)?;
|
||||
if current_cutoff != expected_reset_cutoff_seq {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let current_max_seq: i64 = tx.query_row(
|
||||
"SELECT COALESCE(MAX(seq), 0) FROM messages WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
|row| row.get(0),
|
||||
)?;
|
||||
|
||||
if snapshot_end_seq <= current_cutoff || snapshot_end_seq > current_max_seq {
|
||||
if snapshot_end_seq > current_max_seq {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
@ -660,20 +623,24 @@ impl SessionStore {
|
||||
inserted_count += 1;
|
||||
}
|
||||
|
||||
// Delete all old messages (including delta messages that were just re-inserted)
|
||||
tx.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?1 AND seq <= ?2",
|
||||
params![session_id, current_max_seq],
|
||||
)?;
|
||||
|
||||
tx.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
SET reset_cutoff_seq = ?2,
|
||||
message_count = message_count + ?3,
|
||||
user_turn_count = ?4,
|
||||
updated_at = ?5,
|
||||
last_active_at = ?5,
|
||||
SET message_count = ?2,
|
||||
user_turn_count = ?3,
|
||||
updated_at = ?4,
|
||||
last_active_at = ?4,
|
||||
archived_at = NULL
|
||||
WHERE id = ?1 AND deleted_at IS NULL
|
||||
",
|
||||
params![
|
||||
session_id,
|
||||
current_max_seq,
|
||||
inserted_count,
|
||||
active_user_turn_count,
|
||||
now,
|
||||
@ -1309,8 +1276,7 @@ impl SessionStore {
|
||||
|
||||
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
|
||||
load_messages_after(&conn, session_id, cutoff_seq)
|
||||
load_messages_after(&conn, session_id, 0)
|
||||
}
|
||||
|
||||
pub fn load_messages_for_topic(&self, topic_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
@ -1381,14 +1347,13 @@ impl SessionStore {
|
||||
|
||||
pub fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
|
||||
conn.query_row(
|
||||
"
|
||||
SELECT COUNT(*)
|
||||
FROM messages
|
||||
WHERE session_id = ?1 AND seq > ?2 AND role = 'user'
|
||||
WHERE session_id = ?1 AND role = 'user'
|
||||
",
|
||||
params![session_id, cutoff_seq],
|
||||
params![session_id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(StorageError::from)
|
||||
@ -1422,9 +1387,8 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
|
||||
archived_at: row.get(8)?,
|
||||
deleted_at: row.get(9)?,
|
||||
message_count: row.get(10)?,
|
||||
reset_cutoff_seq: row.get(11)?,
|
||||
user_turn_count: row.get(12)?,
|
||||
agent_prompt_reinjection_count: row.get(13)?,
|
||||
user_turn_count: row.get(11)?,
|
||||
agent_prompt_reinjection_count: row.get(12)?,
|
||||
})
|
||||
}
|
||||
|
||||
@ -1510,13 +1474,6 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<Schedul
|
||||
}
|
||||
|
||||
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
if !has_column(conn, "sessions", "reset_cutoff_seq")? {
|
||||
add_column_if_missing(
|
||||
conn,
|
||||
"ALTER TABLE sessions ADD COLUMN reset_cutoff_seq INTEGER NOT NULL DEFAULT 0",
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "sessions", "user_turn_count")? {
|
||||
add_column_if_missing(
|
||||
conn,
|
||||
@ -1643,18 +1600,6 @@ fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageErro
|
||||
}
|
||||
}
|
||||
|
||||
fn active_reset_cutoff(conn: &Connection, session_id: &str) -> Result<i64, StorageError> {
|
||||
let cutoff = conn
|
||||
.query_row(
|
||||
"SELECT reset_cutoff_seq FROM sessions WHERE id = ?1 AND deleted_at IS NULL",
|
||||
params![session_id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
Ok(cutoff.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn insert_message_with_seq(
|
||||
conn: &rusqlite::Transaction<'_>,
|
||||
session_id: &str,
|
||||
@ -1875,7 +1820,6 @@ mod tests {
|
||||
assert_eq!(session.channel_name, "cli");
|
||||
assert_eq!(session.chat_id, session.id);
|
||||
assert_eq!(session.message_count, 0);
|
||||
assert_eq!(session.reset_cutoff_seq, 0);
|
||||
assert_eq!(session.user_turn_count, 0);
|
||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||
|
||||
@ -1887,7 +1831,6 @@ mod tests {
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.message_count, 2);
|
||||
assert!(stored.archived_at.is_none());
|
||||
assert_eq!(stored.reset_cutoff_seq, 0);
|
||||
assert_eq!(stored.user_turn_count, 1);
|
||||
assert_eq!(stored.agent_prompt_reinjection_count, 0);
|
||||
|
||||
@ -1982,44 +1925,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_session_preserves_full_history_and_hides_active_history() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("reset")).unwrap();
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("before"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::assistant("context"))
|
||||
.unwrap();
|
||||
store.reset_session(&session.id).unwrap();
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.reset_cutoff_seq, 2);
|
||||
assert_eq!(stored.user_turn_count, 0);
|
||||
assert_eq!(stored.agent_prompt_reinjection_count, 0);
|
||||
|
||||
let active_messages = store.load_messages(&session.id).unwrap();
|
||||
assert!(active_messages.is_empty());
|
||||
|
||||
let all_messages = store.load_all_messages(&session.id).unwrap();
|
||||
assert_eq!(all_messages.len(), 2);
|
||||
assert_eq!(all_messages[0].content, "before");
|
||||
assert_eq!(all_messages[1].content, "context");
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("after"))
|
||||
.unwrap();
|
||||
let active_messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(active_messages.len(), 1);
|
||||
assert_eq!(active_messages[0].content, "after");
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.user_turn_count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_schema_migration_adds_reset_cutoff_column() {
|
||||
fn test_schema_migration_adds_user_turn_and_reinjection_columns() {
|
||||
let conn = Connection::open_in_memory().unwrap();
|
||||
conn.execute_batch(
|
||||
"
|
||||
@ -2057,7 +1963,6 @@ mod tests {
|
||||
|
||||
let store = SessionStore::from_connection(conn).unwrap();
|
||||
let session = store.create_cli_session(Some("migrated")).unwrap();
|
||||
assert_eq!(session.reset_cutoff_seq, 0);
|
||||
assert_eq!(session.user_turn_count, 0);
|
||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||
}
|
||||
@ -2105,42 +2010,6 @@ mod tests {
|
||||
assert!(has_column(&conn, "messages", "reasoning_content").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_active_user_messages_respects_reset_cutoff_seq() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("count-users")).unwrap();
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::system("agent"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u1"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::assistant("a1"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u2"))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
|
||||
store.reset_session(&session.id).unwrap();
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::system("agent-again"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u3"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u4"))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compact_active_history_rebuilds_active_segment_with_delta_messages() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
@ -2185,7 +2054,6 @@ mod tests {
|
||||
let compacted = store
|
||||
.compact_active_history(
|
||||
&session.id,
|
||||
0,
|
||||
snapshot_end_seq,
|
||||
&preserved_system_messages,
|
||||
&summary_message,
|
||||
@ -2214,11 +2082,10 @@ mod tests {
|
||||
assert_eq!(active_messages[9].content, "a5");
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
assert_eq!(stored.reset_cutoff_seq, 11);
|
||||
assert_eq!(stored.user_turn_count, 4);
|
||||
|
||||
let all_messages = store.load_all_messages(&session.id).unwrap();
|
||||
assert_eq!(all_messages.len(), 21);
|
||||
assert_eq!(all_messages.len(), 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -35,12 +35,9 @@ pub trait ConversationRepository: Send + Sync + 'static {
|
||||
|
||||
fn clear_messages(&self, session_id: &str) -> Result<(), StorageError>;
|
||||
|
||||
fn reset_session(&self, session_id: &str) -> Result<(), StorageError>;
|
||||
|
||||
fn compact_active_history(
|
||||
&self,
|
||||
session_id: &str,
|
||||
expected_reset_cutoff_seq: i64,
|
||||
snapshot_end_seq: i64,
|
||||
preserved_system_messages: &[ChatMessage],
|
||||
summary_message: &ChatMessage,
|
||||
@ -185,14 +182,9 @@ impl ConversationRepository for super::SessionStore {
|
||||
super::SessionStore::clear_messages(self, session_id)
|
||||
}
|
||||
|
||||
fn reset_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
super::SessionStore::reset_session(self, session_id)
|
||||
}
|
||||
|
||||
fn compact_active_history(
|
||||
&self,
|
||||
session_id: &str,
|
||||
expected_reset_cutoff_seq: i64,
|
||||
snapshot_end_seq: i64,
|
||||
preserved_system_messages: &[ChatMessage],
|
||||
summary_message: &ChatMessage,
|
||||
@ -201,7 +193,6 @@ impl ConversationRepository for super::SessionStore {
|
||||
super::SessionStore::compact_active_history(
|
||||
self,
|
||||
session_id,
|
||||
expected_reset_cutoff_seq,
|
||||
snapshot_end_seq,
|
||||
preserved_system_messages,
|
||||
summary_message,
|
||||
|
||||
@ -23,7 +23,6 @@ pub struct SessionRecord {
|
||||
pub archived_at: Option<i64>,
|
||||
pub deleted_at: Option<i64>,
|
||||
pub message_count: i64,
|
||||
pub reset_cutoff_seq: i64,
|
||||
pub user_turn_count: i64,
|
||||
pub agent_prompt_reinjection_count: i64,
|
||||
}
|
||||
|
||||
@ -24,15 +24,18 @@ impl Tool for SkillManageTool {
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Manage PicoBot skills. Actions: list, get, create, update, delete, disable, reload.\n\n\
|
||||
IMPORTANT: To create or modify skills, ALWAYS use this tool (skill_manage), NOT the write tool.\n\n\
|
||||
Skill Structure:\n\
|
||||
- Folder name: kebab-case (lowercase with hyphens, e.g., 'my-cool-skill')\n\
|
||||
- Required: SKILL.md with YAML frontmatter + Markdown body\n\
|
||||
- Optional folders: scripts/, references/, assets/\n\
|
||||
- Storage: .picobot/skills/{name}/SKILL.md or ~/.picobot/skills/{name}/SKILL.md\n\n\
|
||||
- Storage paths (created automatically by this tool):\n\
|
||||
- Project scope: {current-dir}/.picobot/skills/{name}/SKILL.md\n\
|
||||
- User scope: ~/.picobot/skills/{name}/SKILL.md\n\n\
|
||||
Installing from Zip:\n\
|
||||
- Extract skill folders to skills/ directory\n\
|
||||
- Extract skill folders to .picobot/skills/ directory (NOT skills/)\n\
|
||||
- If zip contains multiple skills, extract each subfolder separately\n\
|
||||
- Final structure: skills/{skill-name}/SKILL.md"
|
||||
- Final structure: .picobot/skills/{skill-name}/SKILL.md"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -51,7 +54,7 @@ impl Tool for SkillManageTool {
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Skill folder name in kebab-case (e.g., 'my-cool-skill', 'code-review'). Must match the folder name under .picobot/skills/ or ~/.picobot/skills/"
|
||||
"description": "Skill folder name in kebab-case (e.g., 'my-cool-skill', 'code-review'). The skill_manage tool automatically creates files at .picobot/skills/{name}/SKILL.md (project scope) or ~/.picobot/skills/{name}/SKILL.md (user scope)."
|
||||
},
|
||||
"names": {
|
||||
"type": "array",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user