Compare commits
4 Commits
c36650c9aa
...
b4ef56803f
| Author | SHA1 | Date | |
|---|---|---|---|
| b4ef56803f | |||
| 44d9171b86 | |||
| a74c801945 | |||
| 4a24758262 |
38
README.md
38
README.md
@ -488,46 +488,52 @@ tools 配置示例:
|
||||
可用工具名称:
|
||||
- calculator - 数学计算器
|
||||
- get_time - 获取当前时间
|
||||
- file_read - 读取文件
|
||||
- file_write - 写入文件
|
||||
- file_edit - 编辑文件
|
||||
- read - 读取文件
|
||||
- write - 写入文件
|
||||
- edit - 编辑文件
|
||||
- memory_search - 搜索长期记忆
|
||||
- memory_manage - 管理长期记忆
|
||||
- session_send - 发送会话消息
|
||||
- send_session_message - 发送会话消息
|
||||
- scheduler_manage - 管理定时任务
|
||||
- skill_activate - 激活技能
|
||||
- skill_list - 列出技能
|
||||
- skill_manage - 管理技能
|
||||
- bash - 执行 shell 命令
|
||||
- skill_manage - 管理技能(含 list 功能)
|
||||
- bash - 执行 shell 命令(Unix/Linux/macOS)
|
||||
- shell - 执行 shell 命令(Windows PowerShell/Cmd)
|
||||
- http_request - HTTP 请求
|
||||
- web_fetch - 网页抓取
|
||||
- task - 创建和管理子代理
|
||||
|
||||
注意:bash 和 shell 是同一个工具在不同平台上的名称,运行时自动检测。
|
||||
|
||||
## 8. 工具机制
|
||||
|
||||
PicoBot 的 Agent 是围绕工具调用构建的。当前默认注册的工具包括:
|
||||
|
||||
- calculator:简单数学计算
|
||||
- time:获取当前时间与时区上下文
|
||||
- file_read:读取文件
|
||||
- file_write:写文件
|
||||
- file_edit:编辑文件
|
||||
- get_time:获取当前时间与时区上下文
|
||||
- read:读取文件
|
||||
- write:写文件
|
||||
- edit:编辑文件
|
||||
- memory_search:读取长期记忆
|
||||
- memory_manage:写入 / 更新 / 删除长期记忆
|
||||
- send_session_message:发送会话消息
|
||||
- scheduler_manage:管理调度任务
|
||||
- skill_activate:读取并激活某个技能内容
|
||||
- skill_list:列出技能
|
||||
- skill_manage:管理技能
|
||||
- bash:执行 shell 命令
|
||||
- skill_manage:管理技能(支持 list, get, create, update, delete, disable, reload)
|
||||
- bash / shell:执行 shell 命令(同一工具,Unix 下名称为 bash,Windows 下名称为 shell)
|
||||
- http_request:发起 HTTP 请求
|
||||
- web_fetch:抓取网页正文
|
||||
- task:创建和管理子代理
|
||||
|
||||
其中:
|
||||
|
||||
- 文件工具适合做代码库和文档操作
|
||||
- read / write / edit 文件工具适合做代码库和文档操作
|
||||
- 记忆工具适合维持长期用户画像
|
||||
- scheduler_manage 允许 Agent 自主创建后续计划任务
|
||||
- skill_activate 负责把具体技能正文注入当前任务上下文
|
||||
- bash / http_request / web_fetch 让 Agent 具备更强的外部交互能力
|
||||
- skill_manage 整合了技能列出与管理功能,支持运行时创建、更新、删除和批量禁用
|
||||
- bash / shell / http_request / web_fetch 让 Agent 具备更强的外部交互能力(bash 和 shell 是同一工具在不同平台的名称)
|
||||
- task 允许 Agent 创建独立上下文的子代理来处理复杂多步骤任务,支持 general 和 explore 两种类型
|
||||
|
||||
## 9. 调度器机制
|
||||
|
||||
|
||||
919
src/cli/init.rs
Normal file
919
src/cli/init.rs
Normal file
@ -0,0 +1,919 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
|
||||
use crate::config::{
|
||||
AgentConfig, ChannelConfig, Config, FeishuChannelConfig, GatewayConfig, ModelConfig,
|
||||
ProviderConfig, SchedulerConfig, TaggedChannelConfig, WechatChannelConfig,
|
||||
};
|
||||
|
||||
/// Interactive configuration wizard for PicoBot
|
||||
pub struct InitWizard {
|
||||
read: BufReader<tokio::io::Stdin>,
|
||||
write: tokio::io::Stdout,
|
||||
config_path: PathBuf,
|
||||
}
|
||||
|
||||
impl InitWizard {
|
||||
pub fn new() -> Self {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
Self {
|
||||
read: BufReader::new(tokio::io::stdin()),
|
||||
write: tokio::io::stdout(),
|
||||
config_path: home.join(".picobot").join("config.json"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&mut self, force: bool, skip_channels: bool) -> Result<(), InitError> {
|
||||
let existing = self.load_existing_config(force)?;
|
||||
|
||||
self.show_welcome_message();
|
||||
|
||||
// Step 1: Provider
|
||||
let providers = self.configure_provider(&existing).await?;
|
||||
|
||||
// Step 2: Model
|
||||
let models = self.configure_model(&existing).await?;
|
||||
|
||||
// Step 3: Agent
|
||||
let agents = self.configure_agent(&existing, &providers, &models).await?;
|
||||
|
||||
// Step 4: Channels
|
||||
let channels = if !skip_channels {
|
||||
self.configure_channels(&existing).await?
|
||||
} else {
|
||||
existing.channels.clone()
|
||||
};
|
||||
|
||||
let config = self.build_config(providers, models, agents, channels, &existing);
|
||||
|
||||
self.save_config(&config)?;
|
||||
|
||||
self.show_completion_message(&config);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_existing_config(&self, force: bool) -> Result<Config, InitError> {
|
||||
if self.config_path.exists() && !force {
|
||||
Config::load_default().or_else(|_| Ok(Self::empty_config()))
|
||||
} else {
|
||||
Ok(Self::empty_config())
|
||||
}
|
||||
}
|
||||
|
||||
fn empty_config() -> Config {
|
||||
Config {
|
||||
providers: HashMap::new(),
|
||||
models: HashMap::new(),
|
||||
agents: HashMap::new(),
|
||||
time: crate::config::TimeConfig::default(),
|
||||
gateway: GatewayConfig::default(),
|
||||
scheduler: SchedulerConfig::default(),
|
||||
client: crate::config::ClientConfig::default(),
|
||||
channels: HashMap::new(),
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
tools: crate::config::ToolsConfig::default(),
|
||||
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn show_welcome_message(&mut self) {
|
||||
println!();
|
||||
println!("╔══════════════════════════════════════════════════════╗");
|
||||
println!("║ PicoBot Configuration Wizard ║");
|
||||
println!("╚══════════════════════════════════════════════════════╝");
|
||||
println!();
|
||||
println!("This wizard will help you configure PicoBot.");
|
||||
println!("Press Enter to use default values where shown.");
|
||||
println!();
|
||||
}
|
||||
|
||||
// ==================== Prompt Helpers ====================
|
||||
|
||||
async fn prompt_with_default(
|
||||
&mut self,
|
||||
label: &str,
|
||||
default: &str,
|
||||
) -> Result<String, InitError> {
|
||||
if !default.is_empty() {
|
||||
println!("{} [default: {}]: ", label, default);
|
||||
} else {
|
||||
println!("{}: ", label);
|
||||
}
|
||||
self.write.flush().await?;
|
||||
|
||||
let mut line = String::new();
|
||||
let bytes_read = self.read.read_line(&mut line).await?;
|
||||
|
||||
if bytes_read == 0 {
|
||||
return Err(InitError::InputError("EOF reached".to_string()));
|
||||
}
|
||||
|
||||
let input = line.trim().to_string();
|
||||
Ok(if input.is_empty() { default.to_string() } else { input })
|
||||
}
|
||||
|
||||
async fn prompt_required(&mut self, label: &str) -> Result<String, InitError> {
|
||||
loop {
|
||||
println!("{}: ", label);
|
||||
self.write.flush().await?;
|
||||
|
||||
let mut line = String::new();
|
||||
let bytes_read = self.read.read_line(&mut line).await?;
|
||||
|
||||
if bytes_read == 0 {
|
||||
return Err(InitError::InputError("EOF reached".to_string()));
|
||||
}
|
||||
|
||||
let input = line.trim().to_string();
|
||||
if !input.is_empty() {
|
||||
return Ok(input);
|
||||
}
|
||||
println!("{} is required. Please enter a value.", label);
|
||||
}
|
||||
}
|
||||
|
||||
async fn prompt_select(
|
||||
&mut self,
|
||||
label: &str,
|
||||
options: &[String],
|
||||
default: usize,
|
||||
) -> Result<usize, InitError> {
|
||||
for (i, opt) in options.iter().enumerate() {
|
||||
println!(" {}. {}", i + 1, opt);
|
||||
}
|
||||
println!();
|
||||
|
||||
let default_str = (default + 1).to_string();
|
||||
let input = self.prompt_with_default(label, &default_str).await?;
|
||||
|
||||
let selected: usize = input.parse().map_err(|_| {
|
||||
InitError::InputError(format!("Invalid selection: {}", input))
|
||||
})?;
|
||||
|
||||
if selected == 0 || selected > options.len() {
|
||||
return Err(InitError::InputError(format!(
|
||||
"Selection must be between 1 and {}",
|
||||
options.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(selected - 1)
|
||||
}
|
||||
|
||||
// ==================== Step 1: Provider ====================
|
||||
|
||||
async fn configure_provider(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
) -> Result<HashMap<String, ProviderConfig>, InitError> {
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Step 1: Configure Provider (API Endpoint)");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
|
||||
let provider_names: Vec<String> = existing.providers.keys().cloned().collect();
|
||||
|
||||
if !provider_names.is_empty() {
|
||||
println!("Existing providers:");
|
||||
for (i, name) in provider_names.iter().enumerate() {
|
||||
println!(" {}. {}", i + 1, name);
|
||||
}
|
||||
println!();
|
||||
|
||||
println!("Options:");
|
||||
println!(" 1. Add new provider");
|
||||
println!(" 2. Modify existing provider");
|
||||
println!(" 3. Use existing provider");
|
||||
println!(" 4. Skip");
|
||||
println!();
|
||||
|
||||
let choice = self
|
||||
.prompt_with_default("Select option", "1")
|
||||
.await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => return self.add_provider(existing).await,
|
||||
"2" => return self.modify_provider(existing).await,
|
||||
"3" => {
|
||||
println!("Keeping existing providers.");
|
||||
return Ok(existing.providers.clone());
|
||||
}
|
||||
"4" => {
|
||||
println!("Skipping provider configuration.");
|
||||
return Ok(existing.providers.clone());
|
||||
}
|
||||
_ => {
|
||||
println!("Invalid option, adding new provider.");
|
||||
return self.add_provider(existing).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("No existing providers found.");
|
||||
println!();
|
||||
println!("Options:");
|
||||
println!(" 1. Add new provider");
|
||||
println!(" 2. Skip");
|
||||
println!();
|
||||
|
||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => self.add_provider(existing).await,
|
||||
"2" => {
|
||||
println!("Skipping provider configuration.");
|
||||
Ok(existing.providers.clone())
|
||||
}
|
||||
_ => {
|
||||
println!("Invalid option, adding new provider.");
|
||||
self.add_provider(existing).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_provider(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
) -> Result<HashMap<String, ProviderConfig>, InitError> {
|
||||
let provider_name = self
|
||||
.prompt_with_default("Provider name", "default")
|
||||
.await?;
|
||||
|
||||
println!("Provider type:");
|
||||
println!(" 1. openai");
|
||||
println!(" 2. anthropic");
|
||||
println!();
|
||||
let type_choice = self.prompt_with_default("Select type", "1").await?;
|
||||
let provider_type = match type_choice.as_str() {
|
||||
"1" => "openai",
|
||||
"2" => "anthropic",
|
||||
_ => "openai",
|
||||
};
|
||||
|
||||
let base_url = self.prompt_required("API base URL").await?;
|
||||
|
||||
println!();
|
||||
println!("API Key is required and will be stored in config.json.");
|
||||
let api_key = self.prompt_required("API Key").await?;
|
||||
|
||||
let provider = ProviderConfig {
|
||||
provider_type: provider_type.to_string(),
|
||||
base_url,
|
||||
api_key,
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
};
|
||||
|
||||
let mut providers = existing.providers.clone();
|
||||
providers.insert(provider_name.clone(), provider);
|
||||
|
||||
println!();
|
||||
println!("Provider '{}' configured.", provider_name);
|
||||
|
||||
Ok(providers)
|
||||
}
|
||||
|
||||
async fn modify_provider(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
) -> Result<HashMap<String, ProviderConfig>, InitError> {
|
||||
// Select which provider to modify
|
||||
let provider_names: Vec<String> = existing.providers.keys().cloned().collect();
|
||||
println!("Select provider to modify:");
|
||||
let provider_idx = self.prompt_select("", &provider_names, 0).await?;
|
||||
let selected_provider_name = &provider_names[provider_idx];
|
||||
let current_provider = existing.providers.get(selected_provider_name).unwrap();
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"Current config: type={}, base_url={}",
|
||||
current_provider.provider_type, current_provider.base_url
|
||||
);
|
||||
println!();
|
||||
|
||||
let current_type_idx = if current_provider.provider_type == "anthropic" {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let type_options = vec!["openai".to_string(), "anthropic".to_string()];
|
||||
println!("Provider type:");
|
||||
let type_idx = self.prompt_select("", &type_options, current_type_idx).await?;
|
||||
let provider_type = &type_options[type_idx];
|
||||
|
||||
let base_url = self
|
||||
.prompt_with_default("API base URL", ¤t_provider.base_url)
|
||||
.await?;
|
||||
|
||||
println!();
|
||||
println!("API Key (press Enter to keep current key):");
|
||||
let api_key = self
|
||||
.prompt_with_default("API Key", ¤t_provider.api_key)
|
||||
.await?;
|
||||
|
||||
let provider = ProviderConfig {
|
||||
provider_type: provider_type.clone(),
|
||||
base_url,
|
||||
api_key,
|
||||
extra_headers: current_provider.extra_headers.clone(),
|
||||
llm_timeout_secs: current_provider.llm_timeout_secs,
|
||||
memory_maintenance_timeout_secs: current_provider.memory_maintenance_timeout_secs,
|
||||
};
|
||||
|
||||
let mut providers = existing.providers.clone();
|
||||
providers.insert(selected_provider_name.clone(), provider);
|
||||
|
||||
println!();
|
||||
println!("Provider '{}' modified.", selected_provider_name);
|
||||
|
||||
Ok(providers)
|
||||
}
|
||||
|
||||
// ==================== Step 2: Model ====================
|
||||
|
||||
async fn configure_model(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
) -> Result<HashMap<String, ModelConfig>, InitError> {
|
||||
println!();
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Step 2: Configure Model");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
|
||||
let model_names: Vec<String> = existing.models.keys().cloned().collect();
|
||||
|
||||
if !model_names.is_empty() {
|
||||
println!("Existing models:");
|
||||
for (i, name) in model_names.iter().enumerate() {
|
||||
let model = existing.models.get(name).unwrap();
|
||||
println!(" {}. {} (model_id: {})", i + 1, name, model.model_id);
|
||||
}
|
||||
println!();
|
||||
|
||||
println!("Options:");
|
||||
println!(" 1. Use existing model");
|
||||
println!(" 2. Add new model");
|
||||
println!(" 3. Skip");
|
||||
println!();
|
||||
|
||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => {
|
||||
println!("Keeping existing models.");
|
||||
return Ok(existing.models.clone());
|
||||
}
|
||||
"2" => return self.add_model(existing).await,
|
||||
"3" => {
|
||||
println!("Skipping model configuration.");
|
||||
return Ok(existing.models.clone());
|
||||
}
|
||||
_ => {
|
||||
println!("Invalid option, keeping existing models.");
|
||||
return Ok(existing.models.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("No existing models found.");
|
||||
println!();
|
||||
println!("Options:");
|
||||
println!(" 1. Add new model");
|
||||
println!(" 2. Skip");
|
||||
println!();
|
||||
|
||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => self.add_model(existing).await,
|
||||
"2" => {
|
||||
println!("Skipping model configuration.");
|
||||
Ok(existing.models.clone())
|
||||
}
|
||||
_ => self.add_model(existing).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_model(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
) -> Result<HashMap<String, ModelConfig>, InitError> {
|
||||
let model_name = self
|
||||
.prompt_with_default("Model name (reference key)", "default")
|
||||
.await?;
|
||||
|
||||
let model_id = self.prompt_required("Model ID (actual identifier)").await?;
|
||||
|
||||
let temperature_str = self
|
||||
.prompt_with_default("Temperature (0.0-1.0)", "0.7")
|
||||
.await?;
|
||||
let temperature: Option<f32> = temperature_str
|
||||
.parse()
|
||||
.ok()
|
||||
.filter(|v| (0.0..=1.0).contains(v));
|
||||
|
||||
let max_tokens_str = self
|
||||
.prompt_with_default("Max tokens (optional)", "4096")
|
||||
.await?;
|
||||
let max_tokens: Option<u32> = max_tokens_str.parse().ok();
|
||||
|
||||
let model = ModelConfig {
|
||||
model_id,
|
||||
temperature,
|
||||
max_tokens,
|
||||
context_window_tokens: Some(128000),
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
|
||||
let mut models = existing.models.clone();
|
||||
models.insert(model_name.clone(), model);
|
||||
|
||||
println!();
|
||||
println!("Model '{}' configured.", model_name);
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
// ==================== Step 3: Agent ====================
|
||||
|
||||
async fn configure_agent(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
providers: &HashMap<String, ProviderConfig>,
|
||||
models: &HashMap<String, ModelConfig>,
|
||||
) -> Result<HashMap<String, AgentConfig>, InitError> {
|
||||
println!();
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Step 3: Configure Agent");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
|
||||
// Check if we have providers and models to select from
|
||||
if providers.is_empty() {
|
||||
println!("Warning: No providers configured. Please configure a provider first.");
|
||||
return Ok(existing.agents.clone());
|
||||
}
|
||||
if models.is_empty() {
|
||||
println!("Warning: No models configured. Please configure a model first.");
|
||||
return Ok(existing.agents.clone());
|
||||
}
|
||||
|
||||
let agent_names: Vec<String> = existing.agents.keys().cloned().collect();
|
||||
|
||||
if !agent_names.is_empty() {
|
||||
println!("Existing agents:");
|
||||
for (i, name) in agent_names.iter().enumerate() {
|
||||
let agent = existing.agents.get(name).unwrap();
|
||||
println!(
|
||||
" {}. {} (provider: {}, model: {})",
|
||||
i + 1,
|
||||
name,
|
||||
agent.provider,
|
||||
agent.model
|
||||
);
|
||||
}
|
||||
println!();
|
||||
|
||||
println!("Options:");
|
||||
println!(" 1. Add new agent");
|
||||
println!(" 2. Modify existing agent");
|
||||
println!(" 3. Use existing agent");
|
||||
println!(" 4. Skip");
|
||||
println!();
|
||||
|
||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => return self.add_agent(existing, providers, models).await,
|
||||
"2" => return self.modify_agent(existing, providers, models).await,
|
||||
"3" => {
|
||||
println!("Keeping existing agents.");
|
||||
return Ok(existing.agents.clone());
|
||||
}
|
||||
"4" => {
|
||||
println!("Skipping agent configuration.");
|
||||
return Ok(existing.agents.clone());
|
||||
}
|
||||
_ => {
|
||||
println!("Invalid option, adding new agent.");
|
||||
return self.add_agent(existing, providers, models).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("No existing agents found.");
|
||||
println!();
|
||||
println!("Options:");
|
||||
println!(" 1. Add new agent");
|
||||
println!(" 2. Skip");
|
||||
println!();
|
||||
|
||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => self.add_agent(existing, providers, models).await,
|
||||
"2" => {
|
||||
println!("Skipping agent configuration.");
|
||||
Ok(existing.agents.clone())
|
||||
}
|
||||
_ => self.add_agent(existing, providers, models).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_agent(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
providers: &HashMap<String, ProviderConfig>,
|
||||
models: &HashMap<String, ModelConfig>,
|
||||
) -> Result<HashMap<String, AgentConfig>, InitError> {
|
||||
let agent_name = self
|
||||
.prompt_with_default("Agent name", "default")
|
||||
.await?;
|
||||
|
||||
// Select provider
|
||||
let provider_names: Vec<String> = providers.keys().cloned().collect();
|
||||
println!("Select provider:");
|
||||
let provider_idx = self.prompt_select("", &provider_names, 0).await?;
|
||||
let selected_provider = &provider_names[provider_idx];
|
||||
|
||||
// Select model (independently)
|
||||
let model_names: Vec<String> = models.keys().cloned().collect();
|
||||
println!();
|
||||
println!("Select model:");
|
||||
let model_idx = self.prompt_select("", &model_names, 0).await?;
|
||||
let selected_model = &model_names[model_idx];
|
||||
|
||||
let agent = AgentConfig {
|
||||
provider: selected_provider.clone(),
|
||||
model: selected_model.clone(),
|
||||
max_tool_iterations: 100,
|
||||
tool_result_max_chars: 20000,
|
||||
context_tool_result_trim_chars: 2000,
|
||||
};
|
||||
|
||||
let mut agents = existing.agents.clone();
|
||||
agents.insert(agent_name.clone(), agent);
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"Agent '{}' configured (provider: {}, model: {}).",
|
||||
agent_name, selected_provider, selected_model
|
||||
);
|
||||
|
||||
Ok(agents)
|
||||
}
|
||||
|
||||
async fn modify_agent(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
providers: &HashMap<String, ProviderConfig>,
|
||||
models: &HashMap<String, ModelConfig>,
|
||||
) -> Result<HashMap<String, AgentConfig>, InitError> {
|
||||
// Select which agent to modify
|
||||
let agent_names: Vec<String> = existing.agents.keys().cloned().collect();
|
||||
println!("Select agent to modify:");
|
||||
let agent_idx = self.prompt_select("", &agent_names, 0).await?;
|
||||
let selected_agent_name = &agent_names[agent_idx];
|
||||
let current_agent = existing.agents.get(selected_agent_name).unwrap();
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"Current config: provider={}, model={}",
|
||||
current_agent.provider, current_agent.model
|
||||
);
|
||||
println!();
|
||||
|
||||
// Select new provider
|
||||
let provider_names: Vec<String> = providers.keys().cloned().collect();
|
||||
let current_provider_idx = provider_names
|
||||
.iter()
|
||||
.position(|p| p == ¤t_agent.provider)
|
||||
.unwrap_or(0);
|
||||
println!("Select provider:");
|
||||
let provider_idx = self.prompt_select("", &provider_names, current_provider_idx).await?;
|
||||
let selected_provider = &provider_names[provider_idx];
|
||||
|
||||
// Select new model
|
||||
let model_names: Vec<String> = models.keys().cloned().collect();
|
||||
let current_model_idx = model_names
|
||||
.iter()
|
||||
.position(|m| m == ¤t_agent.model)
|
||||
.unwrap_or(0);
|
||||
println!();
|
||||
println!("Select model:");
|
||||
let model_idx = self.prompt_select("", &model_names, current_model_idx).await?;
|
||||
let selected_model = &model_names[model_idx];
|
||||
|
||||
let agent = AgentConfig {
|
||||
provider: selected_provider.clone(),
|
||||
model: selected_model.clone(),
|
||||
max_tool_iterations: current_agent.max_tool_iterations,
|
||||
tool_result_max_chars: current_agent.tool_result_max_chars,
|
||||
context_tool_result_trim_chars: current_agent.context_tool_result_trim_chars,
|
||||
};
|
||||
|
||||
let mut agents = existing.agents.clone();
|
||||
agents.insert(selected_agent_name.clone(), agent);
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"Agent '{}' modified (provider: {}, model: {}).",
|
||||
selected_agent_name, selected_provider, selected_model
|
||||
);
|
||||
|
||||
Ok(agents)
|
||||
}
|
||||
|
||||
// ==================== Step 4: Channels ====================
|
||||
|
||||
async fn configure_channels(
|
||||
&mut self,
|
||||
existing: &Config,
|
||||
) -> Result<HashMap<String, ChannelConfig>, InitError> {
|
||||
println!();
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Step 4: Configure Channels");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
|
||||
// Show existing channels
|
||||
if !existing.channels.is_empty() {
|
||||
println!("Existing channels:");
|
||||
for (name, config) in &existing.channels {
|
||||
let status = if config.enabled() { "enabled" } else { "disabled" };
|
||||
println!(" - {} ({})", name, status);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
let mut channels = existing.channels.clone();
|
||||
|
||||
// Loop for configuring multiple channels
|
||||
loop {
|
||||
println!("Options:");
|
||||
println!(" 1. Add Feishu channel");
|
||||
println!(" 2. Add WeChat channel");
|
||||
println!(" 3. Skip / Done");
|
||||
println!();
|
||||
|
||||
let choice = self.prompt_with_default("Select option", "3").await?;
|
||||
|
||||
match choice.as_str() {
|
||||
"1" => {
|
||||
let channel_config = self.configure_feishu_channel(&channels).await?;
|
||||
for (name, config) in channel_config {
|
||||
channels.insert(name, config);
|
||||
}
|
||||
}
|
||||
"2" => {
|
||||
let channel_config = self.configure_wechat_channel(&channels).await?;
|
||||
for (name, config) in channel_config {
|
||||
channels.insert(name, config);
|
||||
}
|
||||
}
|
||||
"3" => break,
|
||||
_ => {
|
||||
println!("Invalid selection.");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(channels)
|
||||
}
|
||||
|
||||
async fn configure_feishu_channel(
|
||||
&mut self,
|
||||
existing: &HashMap<String, ChannelConfig>,
|
||||
) -> Result<HashMap<String, ChannelConfig>, InitError> {
|
||||
println!();
|
||||
println!("Configuring Feishu channel...");
|
||||
println!();
|
||||
|
||||
let channel_name = self
|
||||
.prompt_with_default("Channel name", "feishu")
|
||||
.await?;
|
||||
|
||||
let _existing_config = existing.get(&channel_name).and_then(|c| c.as_feishu());
|
||||
|
||||
let app_id = self.prompt_required("Feishu App ID").await?;
|
||||
let app_secret = self.prompt_required("Feishu App Secret").await?;
|
||||
|
||||
let config = ChannelConfig::Tagged(TaggedChannelConfig::Feishu(FeishuChannelConfig {
|
||||
enabled: true,
|
||||
app_id,
|
||||
app_secret,
|
||||
allow_from: vec!["*".to_string()],
|
||||
agent: "default".to_string(),
|
||||
media_dir: Self::default_feishu_media_dir(),
|
||||
reaction_emoji: "Typing".to_string(),
|
||||
max_message_chars: 20000,
|
||||
reply_context_max_chars: 20000,
|
||||
}));
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert(channel_name.clone(), config);
|
||||
|
||||
println!();
|
||||
println!("Feishu channel '{}' configured.", channel_name);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn configure_wechat_channel(
|
||||
&mut self,
|
||||
_existing: &HashMap<String, ChannelConfig>,
|
||||
) -> Result<HashMap<String, ChannelConfig>, InitError> {
|
||||
println!();
|
||||
println!("Configuring WeChat channel...");
|
||||
println!();
|
||||
println!("WeChat login requires scanning a QR code.");
|
||||
println!("The QR code URL will be displayed after configuration.");
|
||||
println!();
|
||||
|
||||
// Use default values directly
|
||||
let channel_name = "wechat";
|
||||
let base_url = "https://ilinkai.weixin.qq.com";
|
||||
let cred_path = Self::default_wechat_cred_path();
|
||||
let force_login = false;
|
||||
|
||||
let config = ChannelConfig::Tagged(TaggedChannelConfig::Wechat(WechatChannelConfig {
|
||||
enabled: true,
|
||||
allow_from: vec!["*".to_string()],
|
||||
agent: "default".to_string(),
|
||||
base_url: base_url.to_string(),
|
||||
cred_path,
|
||||
force_login,
|
||||
}));
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert(channel_name.to_string(), config);
|
||||
|
||||
println!();
|
||||
println!("WeChat channel '{}' configured.", channel_name);
|
||||
|
||||
// Auto login after configuration
|
||||
println!();
|
||||
println!("Starting WeChat login...");
|
||||
|
||||
self.do_wechat_login(base_url, &Self::default_wechat_cred_path()).await?;
|
||||
|
||||
println!();
|
||||
println!("WeChat login successful! Credentials saved.");
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn do_wechat_login(&mut self, base_url: &str, cred_path: &str) -> Result<(), InitError> {
|
||||
use wechatbot::{BotOptions, WeChatBot};
|
||||
|
||||
let bot = WeChatBot::new(BotOptions {
|
||||
base_url: Some(base_url.to_string()),
|
||||
cred_path: Some(cred_path.to_string()),
|
||||
on_qr_url: Some(Box::new(|url| {
|
||||
println!();
|
||||
println!("┌──────────────────────────────────────────────────────┐");
|
||||
println!("│ WeChat QR Code │");
|
||||
println!("└──────────────────────────────────────────────────────┘");
|
||||
println!();
|
||||
println!("Scan this URL in WeChat:");
|
||||
println!("{}", url);
|
||||
println!();
|
||||
println!("Waiting for confirmation...");
|
||||
})),
|
||||
on_error: Some(Box::new(|error| {
|
||||
eprintln!("WeChat login error: {}", error);
|
||||
})),
|
||||
});
|
||||
|
||||
let creds = bot.login(true).await.map_err(|e| {
|
||||
InitError::WeChatError(format!("WeChat login failed: {}", e))
|
||||
})?;
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"Logged in as: {} (account: {})",
|
||||
creds.user_id, creds.account_id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ==================== Build & Save ====================
|
||||
|
||||
fn build_config(
|
||||
&self,
|
||||
providers: HashMap<String, ProviderConfig>,
|
||||
models: HashMap<String, ModelConfig>,
|
||||
agents: HashMap<String, AgentConfig>,
|
||||
channels: HashMap<String, ChannelConfig>,
|
||||
existing: &Config,
|
||||
) -> Config {
|
||||
Config {
|
||||
providers,
|
||||
models,
|
||||
agents,
|
||||
channels,
|
||||
time: existing.time.clone(),
|
||||
gateway: existing.gateway.clone(),
|
||||
scheduler: existing.scheduler.clone(),
|
||||
client: existing.client.clone(),
|
||||
skills: existing.skills.clone(),
|
||||
tools: existing.tools.clone(),
|
||||
memory_maintenance: existing.memory_maintenance.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn save_config(&self, config: &Config) -> Result<(), InitError> {
|
||||
let dir = self
|
||||
.config_path
|
||||
.parent()
|
||||
.ok_or_else(|| InitError::IoError("Invalid config path".to_string()))?;
|
||||
std::fs::create_dir_all(dir)
|
||||
.map_err(|e| InitError::IoError(format!("Failed to create directory: {}", e)))?;
|
||||
|
||||
let content = serde_json::to_string_pretty(config)
|
||||
.map_err(|e| InitError::SerializeError(e.to_string()))?;
|
||||
|
||||
std::fs::write(&self.config_path, content)
|
||||
.map_err(|e| InitError::IoError(format!("Failed to write config: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn show_completion_message(&mut self, config: &Config) {
|
||||
println!();
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Configuration complete!");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
println!("Config saved to: {}", self.config_path.display());
|
||||
println!();
|
||||
|
||||
println!("Summary:");
|
||||
println!(" Providers: {}", config.providers.len());
|
||||
println!(" Models: {}", config.models.len());
|
||||
println!(" Agents: {}", config.agents.len());
|
||||
println!(" Channels: {}", config.channels.len());
|
||||
println!();
|
||||
|
||||
println!("Next steps:");
|
||||
println!(" 1. Start the gateway: picobot gateway");
|
||||
println!(" 2. Connect with CLI: picobot agent");
|
||||
println!();
|
||||
println!("For more options, run: picobot --help");
|
||||
println!();
|
||||
}
|
||||
|
||||
fn default_feishu_media_dir() -> String {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
home.join(".picobot/media/feishu")
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn default_wechat_cred_path() -> String {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
home.join(".picobot/wechat/credentials.json")
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InitWizard {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InitError {
|
||||
IoError(String),
|
||||
SerializeError(String),
|
||||
InputError(String),
|
||||
WeChatError(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for InitError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
InitError::IoError(msg) => write!(f, "IO error: {}", msg),
|
||||
InitError::SerializeError(msg) => write!(f, "Serialization error: {}", msg),
|
||||
InitError::InputError(msg) => write!(f, "Input error: {}", msg),
|
||||
InitError::WeChatError(msg) => write!(f, "WeChat error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for InitError {}
|
||||
|
||||
impl From<std::io::Error> for InitError {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
InitError::IoError(e.to_string())
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,7 @@
|
||||
pub mod channel;
|
||||
pub mod input;
|
||||
pub mod init;
|
||||
|
||||
pub use channel::CliChannel;
|
||||
pub use input::{InputCommand, InputEvent, InputHandler};
|
||||
pub use init::InitWizard;
|
||||
|
||||
@ -27,6 +27,8 @@ pub struct Config {
|
||||
pub skills: SkillsConfig,
|
||||
#[serde(default)]
|
||||
pub tools: ToolsConfig,
|
||||
#[serde(default)]
|
||||
pub memory_maintenance: MemoryMaintenanceConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -108,6 +110,41 @@ pub struct ToolsConfig {
|
||||
pub task: TaskConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MemoryMaintenanceConfig {
|
||||
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
|
||||
#[serde(default = "default_max_merge_ratio")]
|
||||
pub max_merge_ratio: f32,
|
||||
/// 最小保留记忆数量,默认 5
|
||||
#[serde(default = "default_min_memories_to_keep")]
|
||||
pub min_memories_to_keep: usize,
|
||||
/// 单次合并最大源记忆数,默认 3
|
||||
#[serde(default = "default_max_merge_per_group")]
|
||||
pub max_merge_per_group: usize,
|
||||
}
|
||||
|
||||
fn default_max_merge_ratio() -> f32 {
|
||||
0.3
|
||||
}
|
||||
|
||||
fn default_min_memories_to_keep() -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_max_merge_per_group() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
impl Default for MemoryMaintenanceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_merge_ratio: default_max_merge_ratio(),
|
||||
min_memories_to_keep: default_min_memories_to_keep(),
|
||||
max_merge_per_group: default_max_merge_per_group(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TaskConfig {
|
||||
#[serde(default = "default_task_enabled")]
|
||||
|
||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||
use crate::storage::{MemoryRecord, SessionStore};
|
||||
|
||||
@ -17,12 +17,16 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
||||
include_str!("memory_maintenance_step2_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
|
||||
const META_NAMESPACE: &str = "_meta";
|
||||
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceCandidate {
|
||||
pub(crate) id: String,
|
||||
pub(crate) namespace: String,
|
||||
pub(crate) key: String,
|
||||
pub(crate) content: String,
|
||||
pub(crate) updated_at: i64, // 记忆更新时间(Unix timestamp)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
@ -73,13 +77,19 @@ pub(crate) struct MemoryMaintenanceScopeResult {
|
||||
pub(crate) struct MemoryMaintenanceService {
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
}
|
||||
|
||||
impl MemoryMaintenanceService {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
pub(crate) fn new(
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
maintenance_config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,6 +117,12 @@ impl MemoryMaintenanceService {
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||
// 新增:检查是否有新记忆需要整理
|
||||
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
|
||||
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
@ -255,7 +271,15 @@ impl MemoryMaintenanceService {
|
||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||
|
||||
// 应用整理结果(merge和delete)
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||||
apply_memory_maintenance_output(
|
||||
self.store.as_ref(),
|
||||
scope_key,
|
||||
&plan,
|
||||
&organize_output,
|
||||
self.maintenance_config.max_merge_ratio,
|
||||
self.maintenance_config.min_memories_to_keep,
|
||||
self.maintenance_config.max_merge_per_group,
|
||||
)?;
|
||||
|
||||
// 步骤2:从数据库重新读取剩余的记忆
|
||||
let remaining_memories = self
|
||||
@ -470,17 +494,94 @@ impl MemoryMaintenanceService {
|
||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||
|
||||
// 应用整理结果(merge和delete)
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||||
apply_memory_maintenance_output(
|
||||
self.store.as_ref(),
|
||||
scope_key,
|
||||
&plan,
|
||||
&organize_output,
|
||||
self.maintenance_config.max_merge_ratio,
|
||||
self.maintenance_config.min_memories_to_keep,
|
||||
self.maintenance_config.max_merge_per_group,
|
||||
)?;
|
||||
|
||||
Ok(Some(organize_output))
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取上次整理时间
|
||||
fn get_last_maintenance_time(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<i64>, crate::storage::StorageError> {
|
||||
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
|
||||
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
|
||||
}
|
||||
|
||||
/// 记录本次整理时间
|
||||
fn set_last_maintenance_time(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
time: i64,
|
||||
) -> Result<(), crate::storage::StorageError> {
|
||||
store.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: META_NAMESPACE.to_string(),
|
||||
memory_key: LAST_MAINTENANCE_KEY.to_string(),
|
||||
content: time.to_string(),
|
||||
source_type: "memory_maintenance".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace)
|
||||
fn has_new_memories_since_last_maintenance(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
) -> Result<bool, AgentError> {
|
||||
let memories = store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
|
||||
|
||||
// 过滤掉 _meta namespace 的记忆
|
||||
let user_memories: Vec<_> = memories
|
||||
.iter()
|
||||
.filter(|m| m.namespace != META_NAMESPACE)
|
||||
.collect();
|
||||
|
||||
if user_memories.is_empty() {
|
||||
return Ok(false); // 没有记忆,跳过
|
||||
}
|
||||
|
||||
// 获取上次整理时间
|
||||
let last_time = get_last_maintenance_time(store, scope_key)
|
||||
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
|
||||
|
||||
match last_time {
|
||||
None => Ok(true), // 从未整理过,需要整理
|
||||
Some(last) => {
|
||||
// 检查是否有记忆在上次整理后更新
|
||||
let has_new = user_memories.iter().any(|m| m.updated_at > last);
|
||||
Ok(has_new)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||
let mut plan = MemoryMaintenancePlan::default();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for memory in memories {
|
||||
// 过滤掉 _meta namespace 的记忆
|
||||
if memory.namespace == META_NAMESPACE {
|
||||
continue;
|
||||
}
|
||||
|
||||
let normalized_content = memory.content.trim();
|
||||
if normalized_content.is_empty() {
|
||||
continue;
|
||||
@ -501,6 +602,7 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
|
||||
namespace: memory.namespace.clone(),
|
||||
key: memory.memory_key.clone(),
|
||||
content: normalized_content.to_string(),
|
||||
updated_at: memory.updated_at,
|
||||
};
|
||||
|
||||
plan.candidates.push(candidate);
|
||||
@ -580,12 +682,115 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
|
||||
None
|
||||
}
|
||||
|
||||
/// 验证记忆整理输出是否符合限制
|
||||
pub(crate) fn validate_memory_maintenance_output(
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryOrganizationOutput,
|
||||
max_merge_ratio: f32,
|
||||
min_memories_to_keep: usize,
|
||||
max_merge_per_group: usize,
|
||||
) -> Result<(), String> {
|
||||
let total = plan.candidates.len();
|
||||
|
||||
if total == 0 {
|
||||
return Ok(()); // 没有候选,无需验证
|
||||
}
|
||||
|
||||
// 验证 1: 单次合并数量限制
|
||||
for merge in &output.merges {
|
||||
if merge.source_ids.len() > max_merge_per_group {
|
||||
return Err(format!(
|
||||
"合并组过大: {} 条源记忆超过上限 {}",
|
||||
merge.source_ids.len(),
|
||||
max_merge_per_group
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 2: 跨 namespace 合并检测(完全禁止)
|
||||
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
|
||||
.candidates
|
||||
.iter()
|
||||
.map(|c| (c.id.as_str(), c))
|
||||
.collect();
|
||||
|
||||
for merge in &output.merges {
|
||||
let source_namespaces: HashSet<&str> = merge
|
||||
.source_ids
|
||||
.iter()
|
||||
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
|
||||
.collect();
|
||||
|
||||
// 检查是否跨越多个 namespace
|
||||
if source_namespaces.len() > 1 {
|
||||
return Err(format!(
|
||||
"跨 namespace 合并被禁止: 源来自 {}",
|
||||
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
// 检查目标 namespace 是否与源一致
|
||||
if let Some(src_ns) = source_namespaces.iter().next() {
|
||||
if *src_ns != merge.namespace {
|
||||
return Err(format!(
|
||||
"跨 namespace 合并被禁止: {} → {}",
|
||||
src_ns, merge.namespace
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 3: 总体合并比例
|
||||
let merged_ids: HashSet<&str> = output
|
||||
.merges
|
||||
.iter()
|
||||
.flat_map(|m| m.source_ids.iter())
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
let deleted_ids: HashSet<&str> = output
|
||||
.low_value_ids
|
||||
.iter()
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
let affected = merged_ids.len() + deleted_ids.len();
|
||||
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
|
||||
|
||||
if affected > max_allowed {
|
||||
return Err(format!(
|
||||
"合并比例超限: {} / {} > {:.0}%",
|
||||
affected,
|
||||
total,
|
||||
max_merge_ratio * 100.0
|
||||
));
|
||||
}
|
||||
|
||||
// 验证 4: 最小保留数
|
||||
let remaining = total - affected + output.merges.len();
|
||||
if remaining < min_memories_to_keep {
|
||||
return Err(format!(
|
||||
"保留数不足: {} < {}",
|
||||
remaining, min_memories_to_keep
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn apply_memory_maintenance_output(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryOrganizationOutput,
|
||||
max_merge_ratio: f32,
|
||||
min_memories_to_keep: usize,
|
||||
max_merge_per_group: usize,
|
||||
) -> Result<(), AgentError> {
|
||||
// 新增: 验证合并输出
|
||||
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
|
||||
.map_err(|e| AgentError::Other(e))?;
|
||||
|
||||
let all_candidates = plan.candidates.clone();
|
||||
|
||||
let candidates_by_id = all_candidates
|
||||
@ -661,6 +866,11 @@ pub(crate) fn apply_memory_maintenance_output(
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:记录整理完成时间
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
set_last_maintenance_time(store, scope_key, now)
|
||||
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ impl MemoryMaintenanceCoordinator {
|
||||
Ok(MemoryMaintenanceService::new(
|
||||
self.store.clone(),
|
||||
self.provider_configs.default_provider_config(),
|
||||
self.provider_configs.default_maintenance_config(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,15 +27,32 @@
|
||||
- note: 冲突说明
|
||||
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
||||
|
||||
组织原则(由你自主决定):
|
||||
组织原则:
|
||||
|
||||
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
|
||||
- 相似的、互补的记忆可以合并
|
||||
- 过期、重复、过细的记忆可以标记为低价值
|
||||
- 根据记忆的语义内容自然分组
|
||||
- **每次合并最多只能合并 2-3 条源记忆**
|
||||
- **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
|
||||
- 过期、重复、过细的记忆可以标记为低值
|
||||
- namespace 和 memory_key 的命名应当简洁、有意义
|
||||
- 可以自由创建新的 namespace 来组织相关记忆
|
||||
- **保守原则:宁可保留稍多,不可过度合并**
|
||||
- **必须保留足够数量的记忆,确保信息多样性**
|
||||
|
||||
时间权重原则(关键):
|
||||
|
||||
- 每个候选记忆包含 `updated_at` 时间戳(Unix timestamp,秒)
|
||||
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
|
||||
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
|
||||
- 如果新旧记忆内容完全相同,保留新的,删除旧的
|
||||
- 时间戳数值越大表示越新(离当前时间越近)
|
||||
|
||||
合并限制(硬性约束,由系统强制检查):
|
||||
|
||||
- 单次合并最多来自 3 条源记忆
|
||||
- 整理后保留的记忆数不得少于 5 条
|
||||
- 单次整理最多影响 30% 的记忆
|
||||
- 不同 namespace 的记忆不允许互相合并
|
||||
|
||||
额外约束:
|
||||
|
||||
- 只能引用输入里出现过的候选 id。
|
||||
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
|
||||
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
|
||||
@ -84,6 +84,7 @@ impl GatewayState {
|
||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||
std::collections::HashSet::new(),
|
||||
config.tools.task.clone(),
|
||||
config.memory_maintenance.clone(),
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)?;
|
||||
|
||||
@ -2,22 +2,25 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ProviderConfigService {
|
||||
default_provider_config: LLMProviderConfig,
|
||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
}
|
||||
|
||||
impl ProviderConfigService {
|
||||
pub(crate) fn new(
|
||||
default_provider_config: LLMProviderConfig,
|
||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
default_provider_config,
|
||||
provider_configs: Arc::new(provider_configs),
|
||||
maintenance_config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,6 +40,10 @@ impl ProviderConfigService {
|
||||
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
||||
self.default_provider_config.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
|
||||
self.maintenance_config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -72,6 +79,7 @@ mod tests {
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]),
|
||||
MemoryMaintenanceConfig::default(),
|
||||
);
|
||||
|
||||
let selected = service.select(Some("planner")).unwrap();
|
||||
@ -82,7 +90,11 @@ mod tests {
|
||||
#[test]
|
||||
fn test_select_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let service = ProviderConfigService::new(default_provider, HashMap::new());
|
||||
let service = ProviderConfigService::new(
|
||||
default_provider,
|
||||
HashMap::new(),
|
||||
MemoryMaintenanceConfig::default(),
|
||||
);
|
||||
|
||||
let selected = service.select(Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
|
||||
@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::{LLMProviderConfig, TaskConfig};
|
||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{
|
||||
@ -34,6 +34,7 @@ pub(crate) fn build_session_manager(
|
||||
skills: Arc<SkillRuntime>,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||
@ -47,6 +48,7 @@ pub(crate) fn build_session_manager(
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
disabled_tools,
|
||||
task_config,
|
||||
maintenance_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)
|
||||
@ -62,6 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
maintenance_config: MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||
@ -70,7 +73,11 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||
);
|
||||
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs);
|
||||
let provider_configs = ProviderConfigService::new(
|
||||
provider_config.clone(),
|
||||
provider_configs,
|
||||
maintenance_config,
|
||||
);
|
||||
|
||||
if let Err(err) =
|
||||
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
||||
|
||||
@ -501,6 +501,7 @@ impl SessionManager {
|
||||
skills: Arc<SkillRuntime>,
|
||||
disabled_tools: std::collections::HashSet<String>,
|
||||
task_config: crate::config::TaskConfig,
|
||||
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<Self, AgentError> {
|
||||
@ -513,6 +514,7 @@ impl SessionManager {
|
||||
skills,
|
||||
disabled_tools,
|
||||
task_config,
|
||||
maintenance_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)
|
||||
@ -973,6 +975,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1025,6 +1028,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1093,6 +1097,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1178,6 +1183,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1264,6 +1270,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1349,6 +1356,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1416,6 +1424,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1492,6 +1501,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1555,6 +1565,7 @@ mod tests {
|
||||
Arc::new(SkillRuntime::default()),
|
||||
HashSet::new(),
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(4),
|
||||
Some(24),
|
||||
)
|
||||
@ -1593,6 +1604,8 @@ mod tests {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let scope_key = "feishu:user-1";
|
||||
|
||||
// 创建足够的记忆(7条),让合并操作满足保护限制
|
||||
// 合并后需要保留至少 5 条(min_memories_to_keep)
|
||||
let work = store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
@ -1639,9 +1652,30 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
|
||||
for i in 0..4 {
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: "profile".to_string(),
|
||||
memory_key: format!("extra_{}", i),
|
||||
content: format!("额外记忆 {}", i),
|
||||
source_type: "message".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let plan = build_memory_maintenance_plan(
|
||||
&store.list_memories_for_scope("user", scope_key).unwrap(),
|
||||
);
|
||||
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
|
||||
|
||||
let output = MemoryOrganizationOutput {
|
||||
merges: vec![MemoryMaintenanceMerge {
|
||||
source_ids: vec![work.id.clone(), role.id.clone()],
|
||||
@ -1653,13 +1687,25 @@ mod tests {
|
||||
low_value_ids: vec![noise.id.clone()],
|
||||
};
|
||||
|
||||
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();
|
||||
// 使用默认配置进行验证
|
||||
apply_memory_maintenance_output(
|
||||
&store,
|
||||
scope_key,
|
||||
&plan,
|
||||
&output,
|
||||
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
|
||||
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
|
||||
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
|
||||
assert_eq!(all_memories.len(), 1);
|
||||
assert_eq!(all_memories[0].namespace, "profile");
|
||||
assert_eq!(all_memories[0].memory_key, "work");
|
||||
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
|
||||
// 过滤掉 _meta 记录
|
||||
let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
|
||||
// 合并 2 条为 1 条,删除 1 条,7 - 2 + 1 = 6 条(加上 _meta 记录)
|
||||
assert_eq!(user_memories.len(), 6);
|
||||
// 验证合并后的记忆存在
|
||||
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
15
src/main.rs
15
src/main.rs
@ -4,6 +4,15 @@ use clap::{CommandFactory, Parser};
|
||||
#[command(name = "picobot")]
|
||||
#[command(about = "A CLI chatbot", long_about = None)]
|
||||
enum Command {
|
||||
/// Interactive configuration wizard
|
||||
Init {
|
||||
/// Force overwrite existing config
|
||||
#[arg(short, long)]
|
||||
force: bool,
|
||||
/// Only configure provider, skip channels
|
||||
#[arg(long)]
|
||||
skip_channels: bool,
|
||||
},
|
||||
/// Connect to gateway
|
||||
Agent {
|
||||
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
|
||||
@ -31,10 +40,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
if std::env::args().len() <= 1 {
|
||||
cmd.print_help()?;
|
||||
println!();
|
||||
return Ok(());
|
||||
return Ok(())
|
||||
}
|
||||
|
||||
match Command::parse() {
|
||||
Command::Init { force, skip_channels } => {
|
||||
let mut wizard = picobot::cli::InitWizard::new();
|
||||
wizard.run(force, skip_channels).await?;
|
||||
}
|
||||
Command::Agent { gateway_url } => {
|
||||
let config = picobot::config::Config::load_default().ok();
|
||||
let url = gateway_url
|
||||
|
||||
@ -8,6 +8,8 @@ use super::traits::Usage;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use crate::domain::messages::ContentBlock;
|
||||
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
|
||||
|
||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||
let mut details = vec![error.to_string()];
|
||||
let mut current = error.source();
|
||||
@ -30,7 +32,65 @@ where
|
||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||
}
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
fn convert_content_blocks(
|
||||
supports_images: bool,
|
||||
provider_name: &str,
|
||||
model_id: &str,
|
||||
blocks: &[ContentBlock],
|
||||
message_idx: usize,
|
||||
) -> Vec<serde_json::Value> {
|
||||
// 检查是否有图片且模型不支持
|
||||
if !supports_images {
|
||||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||
|
||||
if has_images {
|
||||
let image_count = blocks
|
||||
.iter()
|
||||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||||
.count();
|
||||
|
||||
tracing::warn!(
|
||||
provider = %provider_name,
|
||||
model = %model_id,
|
||||
filtered_images = image_count,
|
||||
message_idx,
|
||||
"模型不支持图片;将图片转换为通知文本"
|
||||
);
|
||||
|
||||
// 复用通知格式,将图片转换为文本通知
|
||||
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
|
||||
let mut notices: Vec<String> = Vec::new();
|
||||
let mut image_idx = 0;
|
||||
|
||||
for block in blocks.iter() {
|
||||
match block {
|
||||
ContentBlock::Text { text } => {
|
||||
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentBlock::ImageUrl { .. } => {
|
||||
image_idx += 1;
|
||||
notices.push(format!(
|
||||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||||
image_idx
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加通知文本块
|
||||
if !notices.is_empty() {
|
||||
let notice_text = format!(
|
||||
"[系统提示] 以下图片未能成功入模:\n{}",
|
||||
notices.join("\n")
|
||||
);
|
||||
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
|
||||
}
|
||||
|
||||
return converted_blocks;
|
||||
}
|
||||
}
|
||||
|
||||
// 原有逻辑 - 模型支持图片,正常转换
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
@ -112,6 +172,32 @@ impl AnthropicProvider {
|
||||
model_extra,
|
||||
}
|
||||
}
|
||||
|
||||
/// 检查模型是否支持指定内容类型
|
||||
/// 默认支持所有类型(text, image)
|
||||
fn supports_content_type(&self, content_type: &str) -> bool {
|
||||
self.model_extra
|
||||
.get("supported_content_types")
|
||||
.and_then(|value| value.as_array())
|
||||
.map(|types| {
|
||||
types.iter().any(|t| t.as_str() == Some(content_type))
|
||||
})
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// 检查模型是否支持图片
|
||||
fn supports_images(&self) -> bool {
|
||||
self.supports_content_type("image")
|
||||
}
|
||||
|
||||
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
|
||||
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
|
||||
self.model_extra
|
||||
.iter()
|
||||
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -197,15 +283,22 @@ impl LLMProvider for AnthropicProvider {
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| AnthropicMessage {
|
||||
.enumerate()
|
||||
.map(|(i, m)| AnthropicMessage {
|
||||
role: m.role.clone(),
|
||||
content: convert_content_blocks(&m.content),
|
||||
content: convert_content_blocks(
|
||||
self.supports_images(),
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&m.content,
|
||||
i,
|
||||
),
|
||||
})
|
||||
.collect(),
|
||||
max_tokens,
|
||||
temperature: request.temperature.or(self.temperature),
|
||||
tools,
|
||||
extra: self.model_extra.clone(),
|
||||
extra: self.request_model_extra(),
|
||||
};
|
||||
|
||||
let mut req_builder = self
|
||||
|
||||
@ -10,7 +10,7 @@ use super::traits::Usage;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use crate::domain::messages::ContentBlock;
|
||||
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"];
|
||||
|
||||
/// 流式响应中的工具调用增量
|
||||
#[derive(Debug, Default)]
|
||||
@ -139,7 +139,75 @@ fn format_transport_error_context(
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
fn convert_content_blocks(
|
||||
supports_images: bool,
|
||||
provider_name: &str,
|
||||
model_id: &str,
|
||||
blocks: &[ContentBlock],
|
||||
message_idx: usize,
|
||||
) -> Value {
|
||||
// 检查是否有图片且模型不支持
|
||||
if !supports_images {
|
||||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||
|
||||
if has_images {
|
||||
let image_count = blocks.iter()
|
||||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
||||
.count();
|
||||
|
||||
tracing::warn!(
|
||||
provider = %provider_name,
|
||||
model = %model_id,
|
||||
filtered_images = image_count,
|
||||
message_idx,
|
||||
"模型不支持图片;将图片转换为通知文本"
|
||||
);
|
||||
|
||||
// 复用通知格式,将图片转换为文本通知
|
||||
let mut converted_blocks: Vec<Value> = Vec::new();
|
||||
let mut notices: Vec<String> = Vec::new();
|
||||
let mut image_idx = 0;
|
||||
|
||||
for block in blocks.iter() {
|
||||
match block {
|
||||
ContentBlock::Text { text } => {
|
||||
converted_blocks.push(json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentBlock::ImageUrl { .. } => {
|
||||
image_idx += 1;
|
||||
notices.push(format!(
|
||||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
||||
image_idx
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加通知文本块
|
||||
if !notices.is_empty() {
|
||||
let notice_text = format!(
|
||||
"[系统提示] 以下图片未能成功入模:\n{}",
|
||||
notices.join("\n")
|
||||
);
|
||||
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
|
||||
}
|
||||
|
||||
// 如果只有一个文本块且没有通知,返回字符串形式
|
||||
if converted_blocks.len() == 1 {
|
||||
if let Some(block) = converted_blocks.first() {
|
||||
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
|
||||
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
||||
return Value::String(text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Value::Array(converted_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
// 原有逻辑 - 模型支持图片,正常转换
|
||||
if blocks.len() == 1 {
|
||||
if let ContentBlock::Text { text } = &blocks[0] {
|
||||
return Value::String(text.clone());
|
||||
@ -224,6 +292,23 @@ impl OpenAIProvider {
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// 检查模型是否支持指定内容类型
|
||||
/// 默认支持所有类型(text, image)
|
||||
fn supports_content_type(&self, content_type: &str) -> bool {
|
||||
self.model_extra
|
||||
.get("supported_content_types")
|
||||
.and_then(|value| value.as_array())
|
||||
.map(|types| {
|
||||
types.iter().any(|t| t.as_str() == Some(content_type))
|
||||
})
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// 检查模型是否支持图片
|
||||
fn supports_images(&self) -> bool {
|
||||
self.supports_content_type("image")
|
||||
}
|
||||
|
||||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||
match arguments {
|
||||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||
@ -480,20 +565,21 @@ impl OpenAIProvider {
|
||||
}
|
||||
|
||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||
let supports_images = self.supports_images();
|
||||
let mut body = json!({
|
||||
"model": self.model_id,
|
||||
"messages": request.messages.iter().map(|m| {
|
||||
"messages": request.messages.iter().enumerate().map(|(i, m)| {
|
||||
if m.role == "tool" {
|
||||
json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||||
"tool_call_id": m.tool_call_id,
|
||||
"name": m.name,
|
||||
})
|
||||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||
let mut message = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
||||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||
calls.iter().map(|call| json!({
|
||||
"id": call.id,
|
||||
@ -514,7 +600,7 @@ impl OpenAIProvider {
|
||||
} else {
|
||||
let mut message = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content)
|
||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
|
||||
});
|
||||
|
||||
if m.role == "assistant" {
|
||||
@ -1076,4 +1162,115 @@ mod tests {
|
||||
assert_eq!(response.tool_calls[1].id, "call_2");
|
||||
assert_eq!(response.tool_calls[1].name, "get_time");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_images_default_true() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::new(),
|
||||
);
|
||||
|
||||
assert!(provider.supports_images());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_images_disabled_via_config() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::from([(
|
||||
"supported_content_types".to_string(),
|
||||
Value::Array(vec![Value::String("text".to_string())]),
|
||||
)]),
|
||||
);
|
||||
|
||||
assert!(!provider.supports_images());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
|
||||
let blocks = vec![
|
||||
ContentBlock::text("hello"),
|
||||
ContentBlock::image_url("data:image/png;base64,abc123"),
|
||||
ContentBlock::text("world"),
|
||||
];
|
||||
|
||||
let result = convert_content_blocks(false, "test", "test-model", &blocks, 0);
|
||||
|
||||
// 应该是数组形式
|
||||
let arr = result.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块
|
||||
|
||||
// 检查通知内容
|
||||
let notice_block = arr[2].as_object().unwrap();
|
||||
assert_eq!(notice_block["type"], "text");
|
||||
let notice_text = notice_block["text"].as_str().unwrap();
|
||||
assert!(notice_text.contains("[系统提示] 以下图片未能成功入模"));
|
||||
assert!(notice_text.contains("第 1 张图片"));
|
||||
assert!(notice_text.contains("当前模型不支持图片输入"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_content_blocks_keeps_images_when_enabled() {
|
||||
let blocks = vec![
|
||||
ContentBlock::text("hello"),
|
||||
ContentBlock::image_url("data:image/png;base64,abc123"),
|
||||
];
|
||||
|
||||
let result = convert_content_blocks(true, "test", "test-model", &blocks, 0);
|
||||
|
||||
// 应该是数组形式,包含文本和图片
|
||||
let arr = result.as_array().unwrap();
|
||||
assert_eq!(arr.len(), 2);
|
||||
assert_eq!(arr[0]["type"], "text");
|
||||
assert_eq!(arr[1]["type"], "image_url");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_request_body_omits_supported_content_types_from_api() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::from([
|
||||
(
|
||||
"supported_content_types".to_string(),
|
||||
Value::Array(vec![Value::String("text".to_string())]),
|
||||
),
|
||||
("custom_param".to_string(), Value::String("value".to_string())),
|
||||
]),
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::user("hello")],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let body = provider.build_request_body(&request);
|
||||
|
||||
// supported_content_types 不应该发送到 API
|
||||
assert!(body.get("supported_content_types").is_none());
|
||||
// custom_param 应该保留
|
||||
assert_eq!(body["custom_param"], Value::String("value".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -341,7 +341,9 @@ impl SkillSource {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SkillCatalog {
|
||||
skills: Vec<Skill>,
|
||||
#[allow(dead_code)]
|
||||
max_index_chars: usize,
|
||||
#[allow(dead_code)]
|
||||
max_listed_skills: usize,
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user