Compare commits

..

No commits in common. "b4ef56803f03e99a3fc1f857de88d1ed8d44c2b0" and "c36650c9aa66c67f4ee0dd5580f49b752dae8671" have entirely different histories.

15 changed files with 46 additions and 1609 deletions

View File

@ -488,52 +488,46 @@ tools 配置示例:
可用工具名称:
- calculator - 数学计算器
- get_time - 获取当前时间
- read - 读取文件
- write - 写入文件
- edit - 编辑文件
- file_read - 读取文件
- file_write - 写入文件
- file_edit - 编辑文件
- memory_search - 搜索长期记忆
- memory_manage - 管理长期记忆
- send_session_message - 发送会话消息
- session_send - 发送会话消息
- scheduler_manage - 管理定时任务
- skill_activate - 激活技能
- skill_manage - 管理技能(含 list 功能)
- bash - 执行 shell 命令Unix/Linux/macOS
- shell - 执行 shell 命令Windows PowerShell/Cmd
- skill_list - 列出技能
- skill_manage - 管理技能
- bash - 执行 shell 命令
- http_request - HTTP 请求
- web_fetch - 网页抓取
- task - 创建和管理子代理
注意bash 和 shell 是同一个工具在不同平台上的名称,运行时自动检测。
## 8. 工具机制
PicoBot 的 Agent 是围绕工具调用构建的。当前默认注册的工具包括:
- calculator简单数学计算
- get_time获取当前时间与时区上下文
- read读取文件
- write写文件
- edit编辑文件
- time获取当前时间与时区上下文
- file_read读取文件
- file_write写文件
- file_edit编辑文件
- memory_search读取长期记忆
- memory_manage写入 / 更新 / 删除长期记忆
- send_session_message发送会话消息
- scheduler_manage管理调度任务
- skill_activate读取并激活某个技能内容
- skill_manage管理技能支持 list, get, create, update, delete, disable, reload
- bash / shell执行 shell 命令同一工具Unix 下名称为 bashWindows 下名称为 shell
- skill_list列出技能
- skill_manage管理技能
- bash执行 shell 命令
- http_request发起 HTTP 请求
- web_fetch抓取网页正文
- task创建和管理子代理
其中:
- read / write / edit 文件工具适合做代码库和文档操作
- 文件工具适合做代码库和文档操作
- 记忆工具适合维持长期用户画像
- scheduler_manage 允许 Agent 自主创建后续计划任务
- skill_activate 负责把具体技能正文注入当前任务上下文
- skill_manage 整合了技能列出与管理功能,支持运行时创建、更新、删除和批量禁用
- bash / shell / http_request / web_fetch 让 Agent 具备更强的外部交互能力bash 和 shell 是同一工具在不同平台的名称)
- task 允许 Agent 创建独立上下文的子代理来处理复杂多步骤任务,支持 general 和 explore 两种类型
- bash / http_request / web_fetch 让 Agent 具备更强的外部交互能力
## 9. 调度器机制

View File

@ -1,919 +0,0 @@
use std::collections::HashMap;
use std::path::PathBuf;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use crate::config::{
AgentConfig, ChannelConfig, Config, FeishuChannelConfig, GatewayConfig, ModelConfig,
ProviderConfig, SchedulerConfig, TaggedChannelConfig, WechatChannelConfig,
};
/// Interactive configuration wizard for PicoBot
pub struct InitWizard {
read: BufReader<tokio::io::Stdin>,
write: tokio::io::Stdout,
config_path: PathBuf,
}
impl InitWizard {
pub fn new() -> Self {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
Self {
read: BufReader::new(tokio::io::stdin()),
write: tokio::io::stdout(),
config_path: home.join(".picobot").join("config.json"),
}
}
pub async fn run(&mut self, force: bool, skip_channels: bool) -> Result<(), InitError> {
let existing = self.load_existing_config(force)?;
self.show_welcome_message();
// Step 1: Provider
let providers = self.configure_provider(&existing).await?;
// Step 2: Model
let models = self.configure_model(&existing).await?;
// Step 3: Agent
let agents = self.configure_agent(&existing, &providers, &models).await?;
// Step 4: Channels
let channels = if !skip_channels {
self.configure_channels(&existing).await?
} else {
existing.channels.clone()
};
let config = self.build_config(providers, models, agents, channels, &existing);
self.save_config(&config)?;
self.show_completion_message(&config);
Ok(())
}
fn load_existing_config(&self, force: bool) -> Result<Config, InitError> {
if self.config_path.exists() && !force {
Config::load_default().or_else(|_| Ok(Self::empty_config()))
} else {
Ok(Self::empty_config())
}
}
fn empty_config() -> Config {
Config {
providers: HashMap::new(),
models: HashMap::new(),
agents: HashMap::new(),
time: crate::config::TimeConfig::default(),
gateway: GatewayConfig::default(),
scheduler: SchedulerConfig::default(),
client: crate::config::ClientConfig::default(),
channels: HashMap::new(),
skills: crate::config::SkillsConfig::default(),
tools: crate::config::ToolsConfig::default(),
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
}
}
fn show_welcome_message(&mut self) {
println!();
println!("╔══════════════════════════════════════════════════════╗");
println!("║ PicoBot Configuration Wizard ║");
println!("╚══════════════════════════════════════════════════════╝");
println!();
println!("This wizard will help you configure PicoBot.");
println!("Press Enter to use default values where shown.");
println!();
}
// ==================== Prompt Helpers ====================
async fn prompt_with_default(
&mut self,
label: &str,
default: &str,
) -> Result<String, InitError> {
if !default.is_empty() {
println!("{} [default: {}]: ", label, default);
} else {
println!("{}: ", label);
}
self.write.flush().await?;
let mut line = String::new();
let bytes_read = self.read.read_line(&mut line).await?;
if bytes_read == 0 {
return Err(InitError::InputError("EOF reached".to_string()));
}
let input = line.trim().to_string();
Ok(if input.is_empty() { default.to_string() } else { input })
}
async fn prompt_required(&mut self, label: &str) -> Result<String, InitError> {
loop {
println!("{}: ", label);
self.write.flush().await?;
let mut line = String::new();
let bytes_read = self.read.read_line(&mut line).await?;
if bytes_read == 0 {
return Err(InitError::InputError("EOF reached".to_string()));
}
let input = line.trim().to_string();
if !input.is_empty() {
return Ok(input);
}
println!("{} is required. Please enter a value.", label);
}
}
async fn prompt_select(
&mut self,
label: &str,
options: &[String],
default: usize,
) -> Result<usize, InitError> {
for (i, opt) in options.iter().enumerate() {
println!(" {}. {}", i + 1, opt);
}
println!();
let default_str = (default + 1).to_string();
let input = self.prompt_with_default(label, &default_str).await?;
let selected: usize = input.parse().map_err(|_| {
InitError::InputError(format!("Invalid selection: {}", input))
})?;
if selected == 0 || selected > options.len() {
return Err(InitError::InputError(format!(
"Selection must be between 1 and {}",
options.len()
)));
}
Ok(selected - 1)
}
// ==================== Step 1: Provider ====================
async fn configure_provider(
&mut self,
existing: &Config,
) -> Result<HashMap<String, ProviderConfig>, InitError> {
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Step 1: Configure Provider (API Endpoint)");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
let provider_names: Vec<String> = existing.providers.keys().cloned().collect();
if !provider_names.is_empty() {
println!("Existing providers:");
for (i, name) in provider_names.iter().enumerate() {
println!(" {}. {}", i + 1, name);
}
println!();
println!("Options:");
println!(" 1. Add new provider");
println!(" 2. Modify existing provider");
println!(" 3. Use existing provider");
println!(" 4. Skip");
println!();
let choice = self
.prompt_with_default("Select option", "1")
.await?;
match choice.as_str() {
"1" => return self.add_provider(existing).await,
"2" => return self.modify_provider(existing).await,
"3" => {
println!("Keeping existing providers.");
return Ok(existing.providers.clone());
}
"4" => {
println!("Skipping provider configuration.");
return Ok(existing.providers.clone());
}
_ => {
println!("Invalid option, adding new provider.");
return self.add_provider(existing).await;
}
}
} else {
println!("No existing providers found.");
println!();
println!("Options:");
println!(" 1. Add new provider");
println!(" 2. Skip");
println!();
let choice = self.prompt_with_default("Select option", "1").await?;
match choice.as_str() {
"1" => self.add_provider(existing).await,
"2" => {
println!("Skipping provider configuration.");
Ok(existing.providers.clone())
}
_ => {
println!("Invalid option, adding new provider.");
self.add_provider(existing).await
}
}
}
}
async fn add_provider(
&mut self,
existing: &Config,
) -> Result<HashMap<String, ProviderConfig>, InitError> {
let provider_name = self
.prompt_with_default("Provider name", "default")
.await?;
println!("Provider type:");
println!(" 1. openai");
println!(" 2. anthropic");
println!();
let type_choice = self.prompt_with_default("Select type", "1").await?;
let provider_type = match type_choice.as_str() {
"1" => "openai",
"2" => "anthropic",
_ => "openai",
};
let base_url = self.prompt_required("API base URL").await?;
println!();
println!("API Key is required and will be stored in config.json.");
let api_key = self.prompt_required("API Key").await?;
let provider = ProviderConfig {
provider_type: provider_type.to_string(),
base_url,
api_key,
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
memory_maintenance_timeout_secs: 600,
};
let mut providers = existing.providers.clone();
providers.insert(provider_name.clone(), provider);
println!();
println!("Provider '{}' configured.", provider_name);
Ok(providers)
}
async fn modify_provider(
&mut self,
existing: &Config,
) -> Result<HashMap<String, ProviderConfig>, InitError> {
// Select which provider to modify
let provider_names: Vec<String> = existing.providers.keys().cloned().collect();
println!("Select provider to modify:");
let provider_idx = self.prompt_select("", &provider_names, 0).await?;
let selected_provider_name = &provider_names[provider_idx];
let current_provider = existing.providers.get(selected_provider_name).unwrap();
println!();
println!(
"Current config: type={}, base_url={}",
current_provider.provider_type, current_provider.base_url
);
println!();
let current_type_idx = if current_provider.provider_type == "anthropic" {
1
} else {
0
};
let type_options = vec!["openai".to_string(), "anthropic".to_string()];
println!("Provider type:");
let type_idx = self.prompt_select("", &type_options, current_type_idx).await?;
let provider_type = &type_options[type_idx];
let base_url = self
.prompt_with_default("API base URL", &current_provider.base_url)
.await?;
println!();
println!("API Key (press Enter to keep current key):");
let api_key = self
.prompt_with_default("API Key", &current_provider.api_key)
.await?;
let provider = ProviderConfig {
provider_type: provider_type.clone(),
base_url,
api_key,
extra_headers: current_provider.extra_headers.clone(),
llm_timeout_secs: current_provider.llm_timeout_secs,
memory_maintenance_timeout_secs: current_provider.memory_maintenance_timeout_secs,
};
let mut providers = existing.providers.clone();
providers.insert(selected_provider_name.clone(), provider);
println!();
println!("Provider '{}' modified.", selected_provider_name);
Ok(providers)
}
// ==================== Step 2: Model ====================
async fn configure_model(
&mut self,
existing: &Config,
) -> Result<HashMap<String, ModelConfig>, InitError> {
println!();
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Step 2: Configure Model");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
let model_names: Vec<String> = existing.models.keys().cloned().collect();
if !model_names.is_empty() {
println!("Existing models:");
for (i, name) in model_names.iter().enumerate() {
let model = existing.models.get(name).unwrap();
println!(" {}. {} (model_id: {})", i + 1, name, model.model_id);
}
println!();
println!("Options:");
println!(" 1. Use existing model");
println!(" 2. Add new model");
println!(" 3. Skip");
println!();
let choice = self.prompt_with_default("Select option", "1").await?;
match choice.as_str() {
"1" => {
println!("Keeping existing models.");
return Ok(existing.models.clone());
}
"2" => return self.add_model(existing).await,
"3" => {
println!("Skipping model configuration.");
return Ok(existing.models.clone());
}
_ => {
println!("Invalid option, keeping existing models.");
return Ok(existing.models.clone());
}
}
} else {
println!("No existing models found.");
println!();
println!("Options:");
println!(" 1. Add new model");
println!(" 2. Skip");
println!();
let choice = self.prompt_with_default("Select option", "1").await?;
match choice.as_str() {
"1" => self.add_model(existing).await,
"2" => {
println!("Skipping model configuration.");
Ok(existing.models.clone())
}
_ => self.add_model(existing).await,
}
}
}
async fn add_model(
&mut self,
existing: &Config,
) -> Result<HashMap<String, ModelConfig>, InitError> {
let model_name = self
.prompt_with_default("Model name (reference key)", "default")
.await?;
let model_id = self.prompt_required("Model ID (actual identifier)").await?;
let temperature_str = self
.prompt_with_default("Temperature (0.0-1.0)", "0.7")
.await?;
let temperature: Option<f32> = temperature_str
.parse()
.ok()
.filter(|v| (0.0..=1.0).contains(v));
let max_tokens_str = self
.prompt_with_default("Max tokens (optional)", "4096")
.await?;
let max_tokens: Option<u32> = max_tokens_str.parse().ok();
let model = ModelConfig {
model_id,
temperature,
max_tokens,
context_window_tokens: Some(128000),
extra: HashMap::new(),
};
let mut models = existing.models.clone();
models.insert(model_name.clone(), model);
println!();
println!("Model '{}' configured.", model_name);
Ok(models)
}
// ==================== Step 3: Agent ====================
async fn configure_agent(
&mut self,
existing: &Config,
providers: &HashMap<String, ProviderConfig>,
models: &HashMap<String, ModelConfig>,
) -> Result<HashMap<String, AgentConfig>, InitError> {
println!();
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Step 3: Configure Agent");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
// Check if we have providers and models to select from
if providers.is_empty() {
println!("Warning: No providers configured. Please configure a provider first.");
return Ok(existing.agents.clone());
}
if models.is_empty() {
println!("Warning: No models configured. Please configure a model first.");
return Ok(existing.agents.clone());
}
let agent_names: Vec<String> = existing.agents.keys().cloned().collect();
if !agent_names.is_empty() {
println!("Existing agents:");
for (i, name) in agent_names.iter().enumerate() {
let agent = existing.agents.get(name).unwrap();
println!(
" {}. {} (provider: {}, model: {})",
i + 1,
name,
agent.provider,
agent.model
);
}
println!();
println!("Options:");
println!(" 1. Add new agent");
println!(" 2. Modify existing agent");
println!(" 3. Use existing agent");
println!(" 4. Skip");
println!();
let choice = self.prompt_with_default("Select option", "1").await?;
match choice.as_str() {
"1" => return self.add_agent(existing, providers, models).await,
"2" => return self.modify_agent(existing, providers, models).await,
"3" => {
println!("Keeping existing agents.");
return Ok(existing.agents.clone());
}
"4" => {
println!("Skipping agent configuration.");
return Ok(existing.agents.clone());
}
_ => {
println!("Invalid option, adding new agent.");
return self.add_agent(existing, providers, models).await;
}
}
} else {
println!("No existing agents found.");
println!();
println!("Options:");
println!(" 1. Add new agent");
println!(" 2. Skip");
println!();
let choice = self.prompt_with_default("Select option", "1").await?;
match choice.as_str() {
"1" => self.add_agent(existing, providers, models).await,
"2" => {
println!("Skipping agent configuration.");
Ok(existing.agents.clone())
}
_ => self.add_agent(existing, providers, models).await,
}
}
}
async fn add_agent(
&mut self,
existing: &Config,
providers: &HashMap<String, ProviderConfig>,
models: &HashMap<String, ModelConfig>,
) -> Result<HashMap<String, AgentConfig>, InitError> {
let agent_name = self
.prompt_with_default("Agent name", "default")
.await?;
// Select provider
let provider_names: Vec<String> = providers.keys().cloned().collect();
println!("Select provider:");
let provider_idx = self.prompt_select("", &provider_names, 0).await?;
let selected_provider = &provider_names[provider_idx];
// Select model (independently)
let model_names: Vec<String> = models.keys().cloned().collect();
println!();
println!("Select model:");
let model_idx = self.prompt_select("", &model_names, 0).await?;
let selected_model = &model_names[model_idx];
let agent = AgentConfig {
provider: selected_provider.clone(),
model: selected_model.clone(),
max_tool_iterations: 100,
tool_result_max_chars: 20000,
context_tool_result_trim_chars: 2000,
};
let mut agents = existing.agents.clone();
agents.insert(agent_name.clone(), agent);
println!();
println!(
"Agent '{}' configured (provider: {}, model: {}).",
agent_name, selected_provider, selected_model
);
Ok(agents)
}
async fn modify_agent(
&mut self,
existing: &Config,
providers: &HashMap<String, ProviderConfig>,
models: &HashMap<String, ModelConfig>,
) -> Result<HashMap<String, AgentConfig>, InitError> {
// Select which agent to modify
let agent_names: Vec<String> = existing.agents.keys().cloned().collect();
println!("Select agent to modify:");
let agent_idx = self.prompt_select("", &agent_names, 0).await?;
let selected_agent_name = &agent_names[agent_idx];
let current_agent = existing.agents.get(selected_agent_name).unwrap();
println!();
println!(
"Current config: provider={}, model={}",
current_agent.provider, current_agent.model
);
println!();
// Select new provider
let provider_names: Vec<String> = providers.keys().cloned().collect();
let current_provider_idx = provider_names
.iter()
.position(|p| p == &current_agent.provider)
.unwrap_or(0);
println!("Select provider:");
let provider_idx = self.prompt_select("", &provider_names, current_provider_idx).await?;
let selected_provider = &provider_names[provider_idx];
// Select new model
let model_names: Vec<String> = models.keys().cloned().collect();
let current_model_idx = model_names
.iter()
.position(|m| m == &current_agent.model)
.unwrap_or(0);
println!();
println!("Select model:");
let model_idx = self.prompt_select("", &model_names, current_model_idx).await?;
let selected_model = &model_names[model_idx];
let agent = AgentConfig {
provider: selected_provider.clone(),
model: selected_model.clone(),
max_tool_iterations: current_agent.max_tool_iterations,
tool_result_max_chars: current_agent.tool_result_max_chars,
context_tool_result_trim_chars: current_agent.context_tool_result_trim_chars,
};
let mut agents = existing.agents.clone();
agents.insert(selected_agent_name.clone(), agent);
println!();
println!(
"Agent '{}' modified (provider: {}, model: {}).",
selected_agent_name, selected_provider, selected_model
);
Ok(agents)
}
// ==================== Step 4: Channels ====================
async fn configure_channels(
&mut self,
existing: &Config,
) -> Result<HashMap<String, ChannelConfig>, InitError> {
println!();
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Step 4: Configure Channels");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
// Show existing channels
if !existing.channels.is_empty() {
println!("Existing channels:");
for (name, config) in &existing.channels {
let status = if config.enabled() { "enabled" } else { "disabled" };
println!(" - {} ({})", name, status);
}
println!();
}
let mut channels = existing.channels.clone();
// Loop for configuring multiple channels
loop {
println!("Options:");
println!(" 1. Add Feishu channel");
println!(" 2. Add WeChat channel");
println!(" 3. Skip / Done");
println!();
let choice = self.prompt_with_default("Select option", "3").await?;
match choice.as_str() {
"1" => {
let channel_config = self.configure_feishu_channel(&channels).await?;
for (name, config) in channel_config {
channels.insert(name, config);
}
}
"2" => {
let channel_config = self.configure_wechat_channel(&channels).await?;
for (name, config) in channel_config {
channels.insert(name, config);
}
}
"3" => break,
_ => {
println!("Invalid selection.");
continue;
}
}
}
Ok(channels)
}
async fn configure_feishu_channel(
&mut self,
existing: &HashMap<String, ChannelConfig>,
) -> Result<HashMap<String, ChannelConfig>, InitError> {
println!();
println!("Configuring Feishu channel...");
println!();
let channel_name = self
.prompt_with_default("Channel name", "feishu")
.await?;
let _existing_config = existing.get(&channel_name).and_then(|c| c.as_feishu());
let app_id = self.prompt_required("Feishu App ID").await?;
let app_secret = self.prompt_required("Feishu App Secret").await?;
let config = ChannelConfig::Tagged(TaggedChannelConfig::Feishu(FeishuChannelConfig {
enabled: true,
app_id,
app_secret,
allow_from: vec!["*".to_string()],
agent: "default".to_string(),
media_dir: Self::default_feishu_media_dir(),
reaction_emoji: "Typing".to_string(),
max_message_chars: 20000,
reply_context_max_chars: 20000,
}));
let mut result = HashMap::new();
result.insert(channel_name.clone(), config);
println!();
println!("Feishu channel '{}' configured.", channel_name);
Ok(result)
}
async fn configure_wechat_channel(
&mut self,
_existing: &HashMap<String, ChannelConfig>,
) -> Result<HashMap<String, ChannelConfig>, InitError> {
println!();
println!("Configuring WeChat channel...");
println!();
println!("WeChat login requires scanning a QR code.");
println!("The QR code URL will be displayed after configuration.");
println!();
// Use default values directly
let channel_name = "wechat";
let base_url = "https://ilinkai.weixin.qq.com";
let cred_path = Self::default_wechat_cred_path();
let force_login = false;
let config = ChannelConfig::Tagged(TaggedChannelConfig::Wechat(WechatChannelConfig {
enabled: true,
allow_from: vec!["*".to_string()],
agent: "default".to_string(),
base_url: base_url.to_string(),
cred_path,
force_login,
}));
let mut result = HashMap::new();
result.insert(channel_name.to_string(), config);
println!();
println!("WeChat channel '{}' configured.", channel_name);
// Auto login after configuration
println!();
println!("Starting WeChat login...");
self.do_wechat_login(base_url, &Self::default_wechat_cred_path()).await?;
println!();
println!("WeChat login successful! Credentials saved.");
Ok(result)
}
async fn do_wechat_login(&mut self, base_url: &str, cred_path: &str) -> Result<(), InitError> {
use wechatbot::{BotOptions, WeChatBot};
let bot = WeChatBot::new(BotOptions {
base_url: Some(base_url.to_string()),
cred_path: Some(cred_path.to_string()),
on_qr_url: Some(Box::new(|url| {
println!();
println!("┌──────────────────────────────────────────────────────┐");
println!("│ WeChat QR Code │");
println!("└──────────────────────────────────────────────────────┘");
println!();
println!("Scan this URL in WeChat:");
println!("{}", url);
println!();
println!("Waiting for confirmation...");
})),
on_error: Some(Box::new(|error| {
eprintln!("WeChat login error: {}", error);
})),
});
let creds = bot.login(true).await.map_err(|e| {
InitError::WeChatError(format!("WeChat login failed: {}", e))
})?;
println!();
println!(
"Logged in as: {} (account: {})",
creds.user_id, creds.account_id
);
Ok(())
}
// ==================== Build & Save ====================
fn build_config(
&self,
providers: HashMap<String, ProviderConfig>,
models: HashMap<String, ModelConfig>,
agents: HashMap<String, AgentConfig>,
channels: HashMap<String, ChannelConfig>,
existing: &Config,
) -> Config {
Config {
providers,
models,
agents,
channels,
time: existing.time.clone(),
gateway: existing.gateway.clone(),
scheduler: existing.scheduler.clone(),
client: existing.client.clone(),
skills: existing.skills.clone(),
tools: existing.tools.clone(),
memory_maintenance: existing.memory_maintenance.clone(),
}
}
fn save_config(&self, config: &Config) -> Result<(), InitError> {
let dir = self
.config_path
.parent()
.ok_or_else(|| InitError::IoError("Invalid config path".to_string()))?;
std::fs::create_dir_all(dir)
.map_err(|e| InitError::IoError(format!("Failed to create directory: {}", e)))?;
let content = serde_json::to_string_pretty(config)
.map_err(|e| InitError::SerializeError(e.to_string()))?;
std::fs::write(&self.config_path, content)
.map_err(|e| InitError::IoError(format!("Failed to write config: {}", e)))?;
Ok(())
}
fn show_completion_message(&mut self, config: &Config) {
println!();
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Configuration complete!");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!();
println!("Config saved to: {}", self.config_path.display());
println!();
println!("Summary:");
println!(" Providers: {}", config.providers.len());
println!(" Models: {}", config.models.len());
println!(" Agents: {}", config.agents.len());
println!(" Channels: {}", config.channels.len());
println!();
println!("Next steps:");
println!(" 1. Start the gateway: picobot gateway");
println!(" 2. Connect with CLI: picobot agent");
println!();
println!("For more options, run: picobot --help");
println!();
}
fn default_feishu_media_dir() -> String {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".picobot/media/feishu")
.to_string_lossy()
.to_string()
}
fn default_wechat_cred_path() -> String {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
home.join(".picobot/wechat/credentials.json")
.to_string_lossy()
.to_string()
}
}
impl Default for InitWizard {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub enum InitError {
IoError(String),
SerializeError(String),
InputError(String),
WeChatError(String),
}
impl std::fmt::Display for InitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InitError::IoError(msg) => write!(f, "IO error: {}", msg),
InitError::SerializeError(msg) => write!(f, "Serialization error: {}", msg),
InitError::InputError(msg) => write!(f, "Input error: {}", msg),
InitError::WeChatError(msg) => write!(f, "WeChat error: {}", msg),
}
}
}
impl std::error::Error for InitError {}
impl From<std::io::Error> for InitError {
fn from(e: std::io::Error) -> Self {
InitError::IoError(e.to_string())
}
}

View File

@ -1,7 +1,5 @@
pub mod channel;
pub mod input;
pub mod init;
pub use channel::CliChannel;
pub use input::{InputCommand, InputEvent, InputHandler};
pub use init::InitWizard;

View File

@ -27,8 +27,6 @@ pub struct Config {
pub skills: SkillsConfig,
#[serde(default)]
pub tools: ToolsConfig,
#[serde(default)]
pub memory_maintenance: MemoryMaintenanceConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -110,41 +108,6 @@ pub struct ToolsConfig {
pub task: TaskConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MemoryMaintenanceConfig {
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
#[serde(default = "default_max_merge_ratio")]
pub max_merge_ratio: f32,
/// 最小保留记忆数量,默认 5
#[serde(default = "default_min_memories_to_keep")]
pub min_memories_to_keep: usize,
/// 单次合并最大源记忆数,默认 3
#[serde(default = "default_max_merge_per_group")]
pub max_merge_per_group: usize,
}
fn default_max_merge_ratio() -> f32 {
0.3
}
fn default_min_memories_to_keep() -> usize {
5
}
fn default_max_merge_per_group() -> usize {
3
}
impl Default for MemoryMaintenanceConfig {
fn default() -> Self {
Self {
max_merge_ratio: default_max_merge_ratio(),
min_memories_to_keep: default_min_memories_to_keep(),
max_merge_per_group: default_max_merge_per_group(),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TaskConfig {
#[serde(default = "default_task_enabled")]

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
use crate::config::LLMProviderConfig;
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
use crate::storage::{MemoryRecord, SessionStore};
@ -17,16 +17,12 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
include_str!("memory_maintenance_step2_system_prompt.md");
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
const META_NAMESPACE: &str = "_meta";
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
pub(crate) id: String,
pub(crate) namespace: String,
pub(crate) key: String,
pub(crate) content: String,
pub(crate) updated_at: i64, // 记忆更新时间Unix timestamp
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
@ -77,19 +73,13 @@ pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) struct MemoryMaintenanceService {
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
maintenance_config: MemoryMaintenanceConfig,
}
impl MemoryMaintenanceService {
pub(crate) fn new(
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
maintenance_config: MemoryMaintenanceConfig,
) -> Self {
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
Self {
store,
provider_config,
maintenance_config,
}
}
@ -117,12 +107,6 @@ impl MemoryMaintenanceService {
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
// 新增:检查是否有新记忆需要整理
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
return Ok(None);
}
let memories = self
.store
.list_memories_for_scope("user", scope_key)
@ -271,15 +255,7 @@ impl MemoryMaintenanceService {
let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete
apply_memory_maintenance_output(
self.store.as_ref(),
scope_key,
&plan,
&organize_output,
self.maintenance_config.max_merge_ratio,
self.maintenance_config.min_memories_to_keep,
self.maintenance_config.max_merge_per_group,
)?;
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
// 步骤2从数据库重新读取剩余的记忆
let remaining_memories = self
@ -494,94 +470,17 @@ impl MemoryMaintenanceService {
let organize_output = self.organize_plan(scope_key, &plan).await?;
// 应用整理结果merge和delete
apply_memory_maintenance_output(
self.store.as_ref(),
scope_key,
&plan,
&organize_output,
self.maintenance_config.max_merge_ratio,
self.maintenance_config.min_memories_to_keep,
self.maintenance_config.max_merge_per_group,
)?;
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
Ok(Some(organize_output))
}
}
/// 获取上次整理时间
fn get_last_maintenance_time(
store: &SessionStore,
scope_key: &str,
) -> Result<Option<i64>, crate::storage::StorageError> {
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
}
/// 记录本次整理时间
fn set_last_maintenance_time(
store: &SessionStore,
scope_key: &str,
time: i64,
) -> Result<(), crate::storage::StorageError> {
store.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: META_NAMESPACE.to_string(),
memory_key: LAST_MAINTENANCE_KEY.to_string(),
content: time.to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})?;
Ok(())
}
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace
fn has_new_memories_since_last_maintenance(
store: &SessionStore,
scope_key: &str,
) -> Result<bool, AgentError> {
let memories = store
.list_memories_for_scope("user", scope_key)
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
// 过滤掉 _meta namespace 的记忆
let user_memories: Vec<_> = memories
.iter()
.filter(|m| m.namespace != META_NAMESPACE)
.collect();
if user_memories.is_empty() {
return Ok(false); // 没有记忆,跳过
}
// 获取上次整理时间
let last_time = get_last_maintenance_time(store, scope_key)
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
match last_time {
None => Ok(true), // 从未整理过,需要整理
Some(last) => {
// 检查是否有记忆在上次整理后更新
let has_new = user_memories.iter().any(|m| m.updated_at > last);
Ok(has_new)
}
}
}
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new();
for memory in memories {
// 过滤掉 _meta namespace 的记忆
if memory.namespace == META_NAMESPACE {
continue;
}
let normalized_content = memory.content.trim();
if normalized_content.is_empty() {
continue;
@ -602,7 +501,6 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
namespace: memory.namespace.clone(),
key: memory.memory_key.clone(),
content: normalized_content.to_string(),
updated_at: memory.updated_at,
};
plan.candidates.push(candidate);
@ -682,115 +580,12 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
None
}
/// 验证记忆整理输出是否符合限制
pub(crate) fn validate_memory_maintenance_output(
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
max_merge_ratio: f32,
min_memories_to_keep: usize,
max_merge_per_group: usize,
) -> Result<(), String> {
let total = plan.candidates.len();
if total == 0 {
return Ok(()); // 没有候选,无需验证
}
// 验证 1: 单次合并数量限制
for merge in &output.merges {
if merge.source_ids.len() > max_merge_per_group {
return Err(format!(
"合并组过大: {} 条源记忆超过上限 {}",
merge.source_ids.len(),
max_merge_per_group
));
}
}
// 验证 2: 跨 namespace 合并检测(完全禁止)
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
.candidates
.iter()
.map(|c| (c.id.as_str(), c))
.collect();
for merge in &output.merges {
let source_namespaces: HashSet<&str> = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
.collect();
// 检查是否跨越多个 namespace
if source_namespaces.len() > 1 {
return Err(format!(
"跨 namespace 合并被禁止: 源来自 {}",
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
));
}
// 检查目标 namespace 是否与源一致
if let Some(src_ns) = source_namespaces.iter().next() {
if *src_ns != merge.namespace {
return Err(format!(
"跨 namespace 合并被禁止: {} → {}",
src_ns, merge.namespace
));
}
}
}
// 验证 3: 总体合并比例
let merged_ids: HashSet<&str> = output
.merges
.iter()
.flat_map(|m| m.source_ids.iter())
.map(|s| s.as_str())
.collect();
let deleted_ids: HashSet<&str> = output
.low_value_ids
.iter()
.map(|s| s.as_str())
.collect();
let affected = merged_ids.len() + deleted_ids.len();
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
if affected > max_allowed {
return Err(format!(
"合并比例超限: {} / {} > {:.0}%",
affected,
total,
max_merge_ratio * 100.0
));
}
// 验证 4: 最小保留数
let remaining = total - affected + output.merges.len();
if remaining < min_memories_to_keep {
return Err(format!(
"保留数不足: {} < {}",
remaining, min_memories_to_keep
));
}
Ok(())
}
pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryOrganizationOutput,
max_merge_ratio: f32,
min_memories_to_keep: usize,
max_merge_per_group: usize,
) -> Result<(), AgentError> {
// 新增: 验证合并输出
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
.map_err(|e| AgentError::Other(e))?;
let all_candidates = plan.candidates.clone();
let candidates_by_id = all_candidates
@ -866,11 +661,6 @@ pub(crate) fn apply_memory_maintenance_output(
}
}
// 新增:记录整理完成时间
let now = chrono::Utc::now().timestamp();
set_last_maintenance_time(store, scope_key, now)
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
Ok(())
}

View File

@ -40,7 +40,6 @@ impl MemoryMaintenanceCoordinator {
Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_configs.default_provider_config(),
self.provider_configs.default_maintenance_config(),
))
}
}

View File

@ -27,32 +27,15 @@
- note: 冲突说明
- low_value_ids需要删除的低价值候选记忆 ID 数组
组织原则:
组织原则(由你自主决定)
- 根据记忆的语义内容自然分组
- **每次合并最多只能合并 2-3 条源记忆**
- **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
- 过期、重复、过细的记忆可以标记为低值
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
- 相似的、互补的记忆可以合并
- 过期、重复、过细的记忆可以标记为低价值
- namespace 和 memory_key 的命名应当简洁、有意义
- **保守原则:宁可保留稍多,不可过度合并**
- **必须保留足够数量的记忆,确保信息多样性**
时间权重原则(关键):
- 每个候选记忆包含 `updated_at` 时间戳Unix timestamp
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
- 如果新旧记忆内容完全相同,保留新的,删除旧的
- 时间戳数值越大表示越新(离当前时间越近)
合并限制(硬性约束,由系统强制检查):
- 单次合并最多来自 3 条源记忆
- 整理后保留的记忆数不得少于 5 条
- 单次整理最多影响 30% 的记忆
- 不同 namespace 的记忆不允许互相合并
- 可以自由创建新的 namespace 来组织相关记忆
额外约束:
- 只能引用输入里出现过的候选 id。
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。

View File

@ -84,7 +84,6 @@ impl GatewayState {
Arc::new(BusSessionMessageSender::new(bus.clone())),
std::collections::HashSet::new(),
config.tools.task.clone(),
config.memory_maintenance.clone(),
chat_history_ttl_hours,
session_ttl_hours,
)?;

View File

@ -2,25 +2,22 @@ use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
use crate::config::LLMProviderConfig;
#[derive(Clone)]
pub(crate) struct ProviderConfigService {
default_provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
maintenance_config: MemoryMaintenanceConfig,
}
impl ProviderConfigService {
pub(crate) fn new(
default_provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
maintenance_config: MemoryMaintenanceConfig,
) -> Self {
Self {
default_provider_config,
provider_configs: Arc::new(provider_configs),
maintenance_config,
}
}
@ -40,10 +37,6 @@ impl ProviderConfigService {
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
self.default_provider_config.clone()
}
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
self.maintenance_config.clone()
}
}
#[cfg(test)]
@ -79,7 +72,6 @@ mod tests {
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]),
MemoryMaintenanceConfig::default(),
);
let selected = service.select(Some("planner")).unwrap();
@ -90,11 +82,7 @@ mod tests {
#[test]
fn test_select_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(
default_provider,
HashMap::new(),
MemoryMaintenanceConfig::default(),
);
let service = ProviderConfigService::new(default_provider, HashMap::new());
let selected = service.select(Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");

View File

@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
use crate::config::{LLMProviderConfig, TaskConfig};
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::skills::SkillRuntime;
use crate::storage::{
@ -34,7 +34,6 @@ pub(crate) fn build_session_manager(
skills: Arc<SkillRuntime>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
@ -48,7 +47,6 @@ pub(crate) fn build_session_manager(
Arc::new(NoopSessionMessageSender),
disabled_tools,
task_config,
maintenance_config,
chat_history_ttl_hours,
session_ttl_hours,
)
@ -64,7 +62,6 @@ pub(crate) fn build_session_manager_with_sender(
session_message_sender: Arc<dyn SessionMessageSender>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
@ -73,11 +70,7 @@ pub(crate) fn build_session_manager_with_sender(
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
);
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
let provider_configs = ProviderConfigService::new(
provider_config.clone(),
provider_configs,
maintenance_config,
);
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs);
if let Err(err) =
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())

View File

@ -501,7 +501,6 @@ impl SessionManager {
skills: Arc<SkillRuntime>,
disabled_tools: std::collections::HashSet<String>,
task_config: crate::config::TaskConfig,
maintenance_config: crate::config::MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> {
@ -514,7 +513,6 @@ impl SessionManager {
skills,
disabled_tools,
task_config,
maintenance_config,
chat_history_ttl_hours,
session_ttl_hours,
)
@ -975,7 +973,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1028,7 +1025,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1097,7 +1093,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1183,7 +1178,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1270,7 +1264,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1356,7 +1349,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1424,7 +1416,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1501,7 +1492,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1565,7 +1555,6 @@ mod tests {
Arc::new(SkillRuntime::default()),
HashSet::new(),
crate::config::TaskConfig::default(),
crate::config::MemoryMaintenanceConfig::default(),
Some(4),
Some(24),
)
@ -1604,8 +1593,6 @@ mod tests {
let store = SessionStore::in_memory().unwrap();
let scope_key = "feishu:user-1";
// 创建足够的记忆7条让合并操作满足保护限制
// 合并后需要保留至少 5 条min_memories_to_keep
let work = store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
@ -1652,30 +1639,9 @@ mod tests {
})
.unwrap();
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
for i in 0..4 {
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: "profile".to_string(),
memory_key: format!("extra_{}", i),
content: format!("额外记忆 {}", i),
source_type: "message".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.unwrap();
}
let plan = build_memory_maintenance_plan(
&store.list_memories_for_scope("user", scope_key).unwrap(),
);
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
let output = MemoryOrganizationOutput {
merges: vec![MemoryMaintenanceMerge {
source_ids: vec![work.id.clone(), role.id.clone()],
@ -1687,25 +1653,13 @@ mod tests {
low_value_ids: vec![noise.id.clone()],
};
// 使用默认配置进行验证
apply_memory_maintenance_output(
&store,
scope_key,
&plan,
&output,
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
)
.unwrap();
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
// 过滤掉 _meta 记录
let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
// 合并 2 条为 1 条,删除 1 条7 - 2 + 1 = 6 条(加上 _meta 记录)
assert_eq!(user_memories.len(), 6);
// 验证合并后的记忆存在
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
assert_eq!(all_memories.len(), 1);
assert_eq!(all_memories[0].namespace, "profile");
assert_eq!(all_memories[0].memory_key, "work");
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
}
#[test]

View File

@ -4,15 +4,6 @@ use clap::{CommandFactory, Parser};
#[command(name = "picobot")]
#[command(about = "A CLI chatbot", long_about = None)]
enum Command {
/// Interactive configuration wizard
Init {
/// Force overwrite existing config
#[arg(short, long)]
force: bool,
/// Only configure provider, skip channels
#[arg(long)]
skip_channels: bool,
},
/// Connect to gateway
Agent {
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
@ -40,14 +31,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if std::env::args().len() <= 1 {
cmd.print_help()?;
println!();
return Ok(())
return Ok(());
}
match Command::parse() {
Command::Init { force, skip_channels } => {
let mut wizard = picobot::cli::InitWizard::new();
wizard.run(force, skip_channels).await?;
}
Command::Agent { gateway_url } => {
let config = picobot::config::Config::load_default().ok();
let url = gateway_url

View File

@ -8,8 +8,6 @@ use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::domain::messages::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()];
let mut current = error.source();
@ -32,65 +30,7 @@ where
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
}
fn convert_content_blocks(
supports_images: bool,
provider_name: &str,
model_id: &str,
blocks: &[ContentBlock],
message_idx: usize,
) -> Vec<serde_json::Value> {
// 检查是否有图片且模型不支持
if !supports_images {
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
if has_images {
let image_count = blocks
.iter()
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
.count();
tracing::warn!(
provider = %provider_name,
model = %model_id,
filtered_images = image_count,
message_idx,
"模型不支持图片;将图片转换为通知文本"
);
// 复用通知格式,将图片转换为文本通知
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
let mut notices: Vec<String> = Vec::new();
let mut image_idx = 0;
for block in blocks.iter() {
match block {
ContentBlock::Text { text } => {
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentBlock::ImageUrl { .. } => {
image_idx += 1;
notices.push(format!(
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
image_idx
));
}
}
}
// 添加通知文本块
if !notices.is_empty() {
let notice_text = format!(
"[系统提示] 以下图片未能成功入模:\n{}",
notices.join("\n")
);
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
}
return converted_blocks;
}
}
// 原有逻辑 - 模型支持图片,正常转换
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks
.iter()
.map(|b| match b {
@ -172,32 +112,6 @@ impl AnthropicProvider {
model_extra,
}
}
/// 检查模型是否支持指定内容类型
/// 默认支持所有类型text, image
fn supports_content_type(&self, content_type: &str) -> bool {
self.model_extra
.get("supported_content_types")
.and_then(|value| value.as_array())
.map(|types| {
types.iter().any(|t| t.as_str() == Some(content_type))
})
.unwrap_or(true)
}
/// 检查模型是否支持图片
fn supports_images(&self) -> bool {
self.supports_content_type("image")
}
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
self.model_extra
.iter()
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
}
#[derive(Serialize)]
@ -283,22 +197,15 @@ impl LLMProvider for AnthropicProvider {
messages: request
.messages
.iter()
.enumerate()
.map(|(i, m)| AnthropicMessage {
.map(|m| AnthropicMessage {
role: m.role.clone(),
content: convert_content_blocks(
self.supports_images(),
&self.name,
&self.model_id,
&m.content,
i,
),
content: convert_content_blocks(&m.content),
})
.collect(),
max_tokens,
temperature: request.temperature.or(self.temperature),
tools,
extra: self.request_model_extra(),
extra: self.model_extra.clone(),
};
let mut req_builder = self

View File

@ -10,7 +10,7 @@ use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::domain::messages::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"];
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
/// 流式响应中的工具调用增量
#[derive(Debug, Default)]
@ -139,75 +139,7 @@ fn format_transport_error_context(
)
}
fn convert_content_blocks(
supports_images: bool,
provider_name: &str,
model_id: &str,
blocks: &[ContentBlock],
message_idx: usize,
) -> Value {
// 检查是否有图片且模型不支持
if !supports_images {
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
if has_images {
let image_count = blocks.iter()
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
.count();
tracing::warn!(
provider = %provider_name,
model = %model_id,
filtered_images = image_count,
message_idx,
"模型不支持图片;将图片转换为通知文本"
);
// 复用通知格式,将图片转换为文本通知
let mut converted_blocks: Vec<Value> = Vec::new();
let mut notices: Vec<String> = Vec::new();
let mut image_idx = 0;
for block in blocks.iter() {
match block {
ContentBlock::Text { text } => {
converted_blocks.push(json!({ "type": "text", "text": text }));
}
ContentBlock::ImageUrl { .. } => {
image_idx += 1;
notices.push(format!(
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
image_idx
));
}
}
}
// 添加通知文本块
if !notices.is_empty() {
let notice_text = format!(
"[系统提示] 以下图片未能成功入模:\n{}",
notices.join("\n")
);
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
}
// 如果只有一个文本块且没有通知,返回字符串形式
if converted_blocks.len() == 1 {
if let Some(block) = converted_blocks.first() {
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
return Value::String(text.to_string());
}
}
}
}
return Value::Array(converted_blocks);
}
}
// 原有逻辑 - 模型支持图片,正常转换
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 {
if let ContentBlock::Text { text } = &blocks[0] {
return Value::String(text.clone());
@ -292,23 +224,6 @@ impl OpenAIProvider {
.unwrap_or(true)
}
/// 检查模型是否支持指定内容类型
/// 默认支持所有类型text, image
fn supports_content_type(&self, content_type: &str) -> bool {
self.model_extra
.get("supported_content_types")
.and_then(|value| value.as_array())
.map(|types| {
types.iter().any(|t| t.as_str() == Some(content_type))
})
.unwrap_or(true)
}
/// 检查模型是否支持图片
fn supports_images(&self) -> bool {
self.supports_content_type("image")
}
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
match arguments {
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
@ -565,21 +480,20 @@ impl OpenAIProvider {
}
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let supports_images = self.supports_images();
let mut body = json!({
"model": self.model_id,
"messages": request.messages.iter().enumerate().map(|(i, m)| {
"messages": request.messages.iter().map(|m| {
if m.role == "tool" {
json!({
"role": m.role,
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
"content": convert_content_blocks(&m.content),
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else if m.role == "assistant" && m.tool_calls.is_some() {
let mut message = json!({
"role": m.role,
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
"content": convert_content_blocks(&m.content),
"tool_calls": m.tool_calls.as_ref().map(|calls| {
calls.iter().map(|call| json!({
"id": call.id,
@ -600,7 +514,7 @@ impl OpenAIProvider {
} else {
let mut message = json!({
"role": m.role,
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
"content": convert_content_blocks(&m.content)
});
if m.role == "assistant" {
@ -1162,115 +1076,4 @@ mod tests {
assert_eq!(response.tool_calls[1].id, "call_2");
assert_eq!(response.tool_calls[1].name, "get_time");
}
#[test]
fn test_supports_images_default_true() {
let provider = OpenAIProvider::new(
"test".to_string(),
"key".to_string(),
"https://example.com/v1".to_string(),
HashMap::new(),
120,
"gpt-test".to_string(),
None,
None,
HashMap::new(),
);
assert!(provider.supports_images());
}
#[test]
fn test_supports_images_disabled_via_config() {
let provider = OpenAIProvider::new(
"test".to_string(),
"key".to_string(),
"https://example.com/v1".to_string(),
HashMap::new(),
120,
"gpt-test".to_string(),
None,
None,
HashMap::from([(
"supported_content_types".to_string(),
Value::Array(vec![Value::String("text".to_string())]),
)]),
);
assert!(!provider.supports_images());
}
#[test]
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
let blocks = vec![
ContentBlock::text("hello"),
ContentBlock::image_url("data:image/png;base64,abc123"),
ContentBlock::text("world"),
];
let result = convert_content_blocks(false, "test", "test-model", &blocks, 0);
// 应该是数组形式
let arr = result.as_array().unwrap();
assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块
// 检查通知内容
let notice_block = arr[2].as_object().unwrap();
assert_eq!(notice_block["type"], "text");
let notice_text = notice_block["text"].as_str().unwrap();
assert!(notice_text.contains("[系统提示] 以下图片未能成功入模"));
assert!(notice_text.contains("第 1 张图片"));
assert!(notice_text.contains("当前模型不支持图片输入"));
}
#[test]
fn test_convert_content_blocks_keeps_images_when_enabled() {
let blocks = vec![
ContentBlock::text("hello"),
ContentBlock::image_url("data:image/png;base64,abc123"),
];
let result = convert_content_blocks(true, "test", "test-model", &blocks, 0);
// 应该是数组形式,包含文本和图片
let arr = result.as_array().unwrap();
assert_eq!(arr.len(), 2);
assert_eq!(arr[0]["type"], "text");
assert_eq!(arr[1]["type"], "image_url");
}
#[test]
fn test_build_request_body_omits_supported_content_types_from_api() {
let provider = OpenAIProvider::new(
"test".to_string(),
"key".to_string(),
"https://example.com/v1".to_string(),
HashMap::new(),
120,
"gpt-test".to_string(),
None,
None,
HashMap::from([
(
"supported_content_types".to_string(),
Value::Array(vec![Value::String("text".to_string())]),
),
("custom_param".to_string(), Value::String("value".to_string())),
]),
);
let request = ChatCompletionRequest {
messages: vec![Message::user("hello")],
temperature: None,
max_tokens: None,
tools: None,
};
let body = provider.build_request_body(&request);
// supported_content_types 不应该发送到 API
assert!(body.get("supported_content_types").is_none());
// custom_param 应该保留
assert_eq!(body["custom_param"], Value::String("value".to_string()));
}
}

View File

@ -341,9 +341,7 @@ impl SkillSource {
#[derive(Debug, Clone)]
pub struct SkillCatalog {
skills: Vec<Skill>,
#[allow(dead_code)]
max_index_chars: usize,
#[allow(dead_code)]
max_listed_skills: usize,
}