Compare commits
No commits in common. "b4ef56803f03e99a3fc1f857de88d1ed8d44c2b0" and "c36650c9aa66c67f4ee0dd5580f49b752dae8671" have entirely different histories.
b4ef56803f
...
c36650c9aa
38
README.md
38
README.md
@ -488,52 +488,46 @@ tools 配置示例:
|
|||||||
可用工具名称:
|
可用工具名称:
|
||||||
- calculator - 数学计算器
|
- calculator - 数学计算器
|
||||||
- get_time - 获取当前时间
|
- get_time - 获取当前时间
|
||||||
- read - 读取文件
|
- file_read - 读取文件
|
||||||
- write - 写入文件
|
- file_write - 写入文件
|
||||||
- edit - 编辑文件
|
- file_edit - 编辑文件
|
||||||
- memory_search - 搜索长期记忆
|
- memory_search - 搜索长期记忆
|
||||||
- memory_manage - 管理长期记忆
|
- memory_manage - 管理长期记忆
|
||||||
- send_session_message - 发送会话消息
|
- session_send - 发送会话消息
|
||||||
- scheduler_manage - 管理定时任务
|
- scheduler_manage - 管理定时任务
|
||||||
- skill_activate - 激活技能
|
- skill_activate - 激活技能
|
||||||
- skill_manage - 管理技能(含 list 功能)
|
- skill_list - 列出技能
|
||||||
- bash - 执行 shell 命令(Unix/Linux/macOS)
|
- skill_manage - 管理技能
|
||||||
- shell - 执行 shell 命令(Windows PowerShell/Cmd)
|
- bash - 执行 shell 命令
|
||||||
- http_request - HTTP 请求
|
- http_request - HTTP 请求
|
||||||
- web_fetch - 网页抓取
|
- web_fetch - 网页抓取
|
||||||
- task - 创建和管理子代理
|
|
||||||
|
|
||||||
注意:bash 和 shell 是同一个工具在不同平台上的名称,运行时自动检测。
|
|
||||||
|
|
||||||
## 8. 工具机制
|
## 8. 工具机制
|
||||||
|
|
||||||
PicoBot 的 Agent 是围绕工具调用构建的。当前默认注册的工具包括:
|
PicoBot 的 Agent 是围绕工具调用构建的。当前默认注册的工具包括:
|
||||||
|
|
||||||
- calculator:简单数学计算
|
- calculator:简单数学计算
|
||||||
- get_time:获取当前时间与时区上下文
|
- time:获取当前时间与时区上下文
|
||||||
- read:读取文件
|
- file_read:读取文件
|
||||||
- write:写文件
|
- file_write:写文件
|
||||||
- edit:编辑文件
|
- file_edit:编辑文件
|
||||||
- memory_search:读取长期记忆
|
- memory_search:读取长期记忆
|
||||||
- memory_manage:写入 / 更新 / 删除长期记忆
|
- memory_manage:写入 / 更新 / 删除长期记忆
|
||||||
- send_session_message:发送会话消息
|
|
||||||
- scheduler_manage:管理调度任务
|
- scheduler_manage:管理调度任务
|
||||||
- skill_activate:读取并激活某个技能内容
|
- skill_activate:读取并激活某个技能内容
|
||||||
- skill_manage:管理技能(支持 list, get, create, update, delete, disable, reload)
|
- skill_list:列出技能
|
||||||
- bash / shell:执行 shell 命令(同一工具,Unix 下名称为 bash,Windows 下名称为 shell)
|
- skill_manage:管理技能
|
||||||
|
- bash:执行 shell 命令
|
||||||
- http_request:发起 HTTP 请求
|
- http_request:发起 HTTP 请求
|
||||||
- web_fetch:抓取网页正文
|
- web_fetch:抓取网页正文
|
||||||
- task:创建和管理子代理
|
|
||||||
|
|
||||||
其中:
|
其中:
|
||||||
|
|
||||||
- read / write / edit 文件工具适合做代码库和文档操作
|
- 文件工具适合做代码库和文档操作
|
||||||
- 记忆工具适合维持长期用户画像
|
- 记忆工具适合维持长期用户画像
|
||||||
- scheduler_manage 允许 Agent 自主创建后续计划任务
|
- scheduler_manage 允许 Agent 自主创建后续计划任务
|
||||||
- skill_activate 负责把具体技能正文注入当前任务上下文
|
- skill_activate 负责把具体技能正文注入当前任务上下文
|
||||||
- skill_manage 整合了技能列出与管理功能,支持运行时创建、更新、删除和批量禁用
|
- bash / http_request / web_fetch 让 Agent 具备更强的外部交互能力
|
||||||
- bash / shell / http_request / web_fetch 让 Agent 具备更强的外部交互能力(bash 和 shell 是同一工具在不同平台的名称)
|
|
||||||
- task 允许 Agent 创建独立上下文的子代理来处理复杂多步骤任务,支持 general 和 explore 两种类型
|
|
||||||
|
|
||||||
## 9. 调度器机制
|
## 9. 调度器机制
|
||||||
|
|
||||||
|
|||||||
919
src/cli/init.rs
919
src/cli/init.rs
@ -1,919 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
|
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
||||||
|
|
||||||
use crate::config::{
|
|
||||||
AgentConfig, ChannelConfig, Config, FeishuChannelConfig, GatewayConfig, ModelConfig,
|
|
||||||
ProviderConfig, SchedulerConfig, TaggedChannelConfig, WechatChannelConfig,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Interactive configuration wizard for PicoBot
|
|
||||||
pub struct InitWizard {
|
|
||||||
read: BufReader<tokio::io::Stdin>,
|
|
||||||
write: tokio::io::Stdout,
|
|
||||||
config_path: PathBuf,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InitWizard {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
Self {
|
|
||||||
read: BufReader::new(tokio::io::stdin()),
|
|
||||||
write: tokio::io::stdout(),
|
|
||||||
config_path: home.join(".picobot").join("config.json"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(&mut self, force: bool, skip_channels: bool) -> Result<(), InitError> {
|
|
||||||
let existing = self.load_existing_config(force)?;
|
|
||||||
|
|
||||||
self.show_welcome_message();
|
|
||||||
|
|
||||||
// Step 1: Provider
|
|
||||||
let providers = self.configure_provider(&existing).await?;
|
|
||||||
|
|
||||||
// Step 2: Model
|
|
||||||
let models = self.configure_model(&existing).await?;
|
|
||||||
|
|
||||||
// Step 3: Agent
|
|
||||||
let agents = self.configure_agent(&existing, &providers, &models).await?;
|
|
||||||
|
|
||||||
// Step 4: Channels
|
|
||||||
let channels = if !skip_channels {
|
|
||||||
self.configure_channels(&existing).await?
|
|
||||||
} else {
|
|
||||||
existing.channels.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let config = self.build_config(providers, models, agents, channels, &existing);
|
|
||||||
|
|
||||||
self.save_config(&config)?;
|
|
||||||
|
|
||||||
self.show_completion_message(&config);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_existing_config(&self, force: bool) -> Result<Config, InitError> {
|
|
||||||
if self.config_path.exists() && !force {
|
|
||||||
Config::load_default().or_else(|_| Ok(Self::empty_config()))
|
|
||||||
} else {
|
|
||||||
Ok(Self::empty_config())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn empty_config() -> Config {
|
|
||||||
Config {
|
|
||||||
providers: HashMap::new(),
|
|
||||||
models: HashMap::new(),
|
|
||||||
agents: HashMap::new(),
|
|
||||||
time: crate::config::TimeConfig::default(),
|
|
||||||
gateway: GatewayConfig::default(),
|
|
||||||
scheduler: SchedulerConfig::default(),
|
|
||||||
client: crate::config::ClientConfig::default(),
|
|
||||||
channels: HashMap::new(),
|
|
||||||
skills: crate::config::SkillsConfig::default(),
|
|
||||||
tools: crate::config::ToolsConfig::default(),
|
|
||||||
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn show_welcome_message(&mut self) {
|
|
||||||
println!();
|
|
||||||
println!("╔══════════════════════════════════════════════════════╗");
|
|
||||||
println!("║ PicoBot Configuration Wizard ║");
|
|
||||||
println!("╚══════════════════════════════════════════════════════╝");
|
|
||||||
println!();
|
|
||||||
println!("This wizard will help you configure PicoBot.");
|
|
||||||
println!("Press Enter to use default values where shown.");
|
|
||||||
println!();
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Prompt Helpers ====================
|
|
||||||
|
|
||||||
async fn prompt_with_default(
|
|
||||||
&mut self,
|
|
||||||
label: &str,
|
|
||||||
default: &str,
|
|
||||||
) -> Result<String, InitError> {
|
|
||||||
if !default.is_empty() {
|
|
||||||
println!("{} [default: {}]: ", label, default);
|
|
||||||
} else {
|
|
||||||
println!("{}: ", label);
|
|
||||||
}
|
|
||||||
self.write.flush().await?;
|
|
||||||
|
|
||||||
let mut line = String::new();
|
|
||||||
let bytes_read = self.read.read_line(&mut line).await?;
|
|
||||||
|
|
||||||
if bytes_read == 0 {
|
|
||||||
return Err(InitError::InputError("EOF reached".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let input = line.trim().to_string();
|
|
||||||
Ok(if input.is_empty() { default.to_string() } else { input })
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn prompt_required(&mut self, label: &str) -> Result<String, InitError> {
|
|
||||||
loop {
|
|
||||||
println!("{}: ", label);
|
|
||||||
self.write.flush().await?;
|
|
||||||
|
|
||||||
let mut line = String::new();
|
|
||||||
let bytes_read = self.read.read_line(&mut line).await?;
|
|
||||||
|
|
||||||
if bytes_read == 0 {
|
|
||||||
return Err(InitError::InputError("EOF reached".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let input = line.trim().to_string();
|
|
||||||
if !input.is_empty() {
|
|
||||||
return Ok(input);
|
|
||||||
}
|
|
||||||
println!("{} is required. Please enter a value.", label);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn prompt_select(
|
|
||||||
&mut self,
|
|
||||||
label: &str,
|
|
||||||
options: &[String],
|
|
||||||
default: usize,
|
|
||||||
) -> Result<usize, InitError> {
|
|
||||||
for (i, opt) in options.iter().enumerate() {
|
|
||||||
println!(" {}. {}", i + 1, opt);
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let default_str = (default + 1).to_string();
|
|
||||||
let input = self.prompt_with_default(label, &default_str).await?;
|
|
||||||
|
|
||||||
let selected: usize = input.parse().map_err(|_| {
|
|
||||||
InitError::InputError(format!("Invalid selection: {}", input))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if selected == 0 || selected > options.len() {
|
|
||||||
return Err(InitError::InputError(format!(
|
|
||||||
"Selection must be between 1 and {}",
|
|
||||||
options.len()
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(selected - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Step 1: Provider ====================
|
|
||||||
|
|
||||||
async fn configure_provider(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Result<HashMap<String, ProviderConfig>, InitError> {
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!("Step 1: Configure Provider (API Endpoint)");
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let provider_names: Vec<String> = existing.providers.keys().cloned().collect();
|
|
||||||
|
|
||||||
if !provider_names.is_empty() {
|
|
||||||
println!("Existing providers:");
|
|
||||||
for (i, name) in provider_names.iter().enumerate() {
|
|
||||||
println!(" {}. {}", i + 1, name);
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Add new provider");
|
|
||||||
println!(" 2. Modify existing provider");
|
|
||||||
println!(" 3. Use existing provider");
|
|
||||||
println!(" 4. Skip");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self
|
|
||||||
.prompt_with_default("Select option", "1")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => return self.add_provider(existing).await,
|
|
||||||
"2" => return self.modify_provider(existing).await,
|
|
||||||
"3" => {
|
|
||||||
println!("Keeping existing providers.");
|
|
||||||
return Ok(existing.providers.clone());
|
|
||||||
}
|
|
||||||
"4" => {
|
|
||||||
println!("Skipping provider configuration.");
|
|
||||||
return Ok(existing.providers.clone());
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
println!("Invalid option, adding new provider.");
|
|
||||||
return self.add_provider(existing).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
println!("No existing providers found.");
|
|
||||||
println!();
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Add new provider");
|
|
||||||
println!(" 2. Skip");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => self.add_provider(existing).await,
|
|
||||||
"2" => {
|
|
||||||
println!("Skipping provider configuration.");
|
|
||||||
Ok(existing.providers.clone())
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
println!("Invalid option, adding new provider.");
|
|
||||||
self.add_provider(existing).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn add_provider(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Result<HashMap<String, ProviderConfig>, InitError> {
|
|
||||||
let provider_name = self
|
|
||||||
.prompt_with_default("Provider name", "default")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
println!("Provider type:");
|
|
||||||
println!(" 1. openai");
|
|
||||||
println!(" 2. anthropic");
|
|
||||||
println!();
|
|
||||||
let type_choice = self.prompt_with_default("Select type", "1").await?;
|
|
||||||
let provider_type = match type_choice.as_str() {
|
|
||||||
"1" => "openai",
|
|
||||||
"2" => "anthropic",
|
|
||||||
_ => "openai",
|
|
||||||
};
|
|
||||||
|
|
||||||
let base_url = self.prompt_required("API base URL").await?;
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("API Key is required and will be stored in config.json.");
|
|
||||||
let api_key = self.prompt_required("API Key").await?;
|
|
||||||
|
|
||||||
let provider = ProviderConfig {
|
|
||||||
provider_type: provider_type.to_string(),
|
|
||||||
base_url,
|
|
||||||
api_key,
|
|
||||||
extra_headers: HashMap::new(),
|
|
||||||
llm_timeout_secs: 120,
|
|
||||||
memory_maintenance_timeout_secs: 600,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut providers = existing.providers.clone();
|
|
||||||
providers.insert(provider_name.clone(), provider);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("Provider '{}' configured.", provider_name);
|
|
||||||
|
|
||||||
Ok(providers)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn modify_provider(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Result<HashMap<String, ProviderConfig>, InitError> {
|
|
||||||
// Select which provider to modify
|
|
||||||
let provider_names: Vec<String> = existing.providers.keys().cloned().collect();
|
|
||||||
println!("Select provider to modify:");
|
|
||||||
let provider_idx = self.prompt_select("", &provider_names, 0).await?;
|
|
||||||
let selected_provider_name = &provider_names[provider_idx];
|
|
||||||
let current_provider = existing.providers.get(selected_provider_name).unwrap();
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!(
|
|
||||||
"Current config: type={}, base_url={}",
|
|
||||||
current_provider.provider_type, current_provider.base_url
|
|
||||||
);
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let current_type_idx = if current_provider.provider_type == "anthropic" {
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
let type_options = vec!["openai".to_string(), "anthropic".to_string()];
|
|
||||||
println!("Provider type:");
|
|
||||||
let type_idx = self.prompt_select("", &type_options, current_type_idx).await?;
|
|
||||||
let provider_type = &type_options[type_idx];
|
|
||||||
|
|
||||||
let base_url = self
|
|
||||||
.prompt_with_default("API base URL", ¤t_provider.base_url)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("API Key (press Enter to keep current key):");
|
|
||||||
let api_key = self
|
|
||||||
.prompt_with_default("API Key", ¤t_provider.api_key)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let provider = ProviderConfig {
|
|
||||||
provider_type: provider_type.clone(),
|
|
||||||
base_url,
|
|
||||||
api_key,
|
|
||||||
extra_headers: current_provider.extra_headers.clone(),
|
|
||||||
llm_timeout_secs: current_provider.llm_timeout_secs,
|
|
||||||
memory_maintenance_timeout_secs: current_provider.memory_maintenance_timeout_secs,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut providers = existing.providers.clone();
|
|
||||||
providers.insert(selected_provider_name.clone(), provider);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("Provider '{}' modified.", selected_provider_name);
|
|
||||||
|
|
||||||
Ok(providers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Step 2: Model ====================
|
|
||||||
|
|
||||||
async fn configure_model(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Result<HashMap<String, ModelConfig>, InitError> {
|
|
||||||
println!();
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!("Step 2: Configure Model");
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let model_names: Vec<String> = existing.models.keys().cloned().collect();
|
|
||||||
|
|
||||||
if !model_names.is_empty() {
|
|
||||||
println!("Existing models:");
|
|
||||||
for (i, name) in model_names.iter().enumerate() {
|
|
||||||
let model = existing.models.get(name).unwrap();
|
|
||||||
println!(" {}. {} (model_id: {})", i + 1, name, model.model_id);
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Use existing model");
|
|
||||||
println!(" 2. Add new model");
|
|
||||||
println!(" 3. Skip");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => {
|
|
||||||
println!("Keeping existing models.");
|
|
||||||
return Ok(existing.models.clone());
|
|
||||||
}
|
|
||||||
"2" => return self.add_model(existing).await,
|
|
||||||
"3" => {
|
|
||||||
println!("Skipping model configuration.");
|
|
||||||
return Ok(existing.models.clone());
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
println!("Invalid option, keeping existing models.");
|
|
||||||
return Ok(existing.models.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
println!("No existing models found.");
|
|
||||||
println!();
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Add new model");
|
|
||||||
println!(" 2. Skip");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => self.add_model(existing).await,
|
|
||||||
"2" => {
|
|
||||||
println!("Skipping model configuration.");
|
|
||||||
Ok(existing.models.clone())
|
|
||||||
}
|
|
||||||
_ => self.add_model(existing).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn add_model(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Result<HashMap<String, ModelConfig>, InitError> {
|
|
||||||
let model_name = self
|
|
||||||
.prompt_with_default("Model name (reference key)", "default")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let model_id = self.prompt_required("Model ID (actual identifier)").await?;
|
|
||||||
|
|
||||||
let temperature_str = self
|
|
||||||
.prompt_with_default("Temperature (0.0-1.0)", "0.7")
|
|
||||||
.await?;
|
|
||||||
let temperature: Option<f32> = temperature_str
|
|
||||||
.parse()
|
|
||||||
.ok()
|
|
||||||
.filter(|v| (0.0..=1.0).contains(v));
|
|
||||||
|
|
||||||
let max_tokens_str = self
|
|
||||||
.prompt_with_default("Max tokens (optional)", "4096")
|
|
||||||
.await?;
|
|
||||||
let max_tokens: Option<u32> = max_tokens_str.parse().ok();
|
|
||||||
|
|
||||||
let model = ModelConfig {
|
|
||||||
model_id,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
context_window_tokens: Some(128000),
|
|
||||||
extra: HashMap::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut models = existing.models.clone();
|
|
||||||
models.insert(model_name.clone(), model);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("Model '{}' configured.", model_name);
|
|
||||||
|
|
||||||
Ok(models)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Step 3: Agent ====================
|
|
||||||
|
|
||||||
async fn configure_agent(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
providers: &HashMap<String, ProviderConfig>,
|
|
||||||
models: &HashMap<String, ModelConfig>,
|
|
||||||
) -> Result<HashMap<String, AgentConfig>, InitError> {
|
|
||||||
println!();
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!("Step 3: Configure Agent");
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
// Check if we have providers and models to select from
|
|
||||||
if providers.is_empty() {
|
|
||||||
println!("Warning: No providers configured. Please configure a provider first.");
|
|
||||||
return Ok(existing.agents.clone());
|
|
||||||
}
|
|
||||||
if models.is_empty() {
|
|
||||||
println!("Warning: No models configured. Please configure a model first.");
|
|
||||||
return Ok(existing.agents.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let agent_names: Vec<String> = existing.agents.keys().cloned().collect();
|
|
||||||
|
|
||||||
if !agent_names.is_empty() {
|
|
||||||
println!("Existing agents:");
|
|
||||||
for (i, name) in agent_names.iter().enumerate() {
|
|
||||||
let agent = existing.agents.get(name).unwrap();
|
|
||||||
println!(
|
|
||||||
" {}. {} (provider: {}, model: {})",
|
|
||||||
i + 1,
|
|
||||||
name,
|
|
||||||
agent.provider,
|
|
||||||
agent.model
|
|
||||||
);
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Add new agent");
|
|
||||||
println!(" 2. Modify existing agent");
|
|
||||||
println!(" 3. Use existing agent");
|
|
||||||
println!(" 4. Skip");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => return self.add_agent(existing, providers, models).await,
|
|
||||||
"2" => return self.modify_agent(existing, providers, models).await,
|
|
||||||
"3" => {
|
|
||||||
println!("Keeping existing agents.");
|
|
||||||
return Ok(existing.agents.clone());
|
|
||||||
}
|
|
||||||
"4" => {
|
|
||||||
println!("Skipping agent configuration.");
|
|
||||||
return Ok(existing.agents.clone());
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
println!("Invalid option, adding new agent.");
|
|
||||||
return self.add_agent(existing, providers, models).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
println!("No existing agents found.");
|
|
||||||
println!();
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Add new agent");
|
|
||||||
println!(" 2. Skip");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self.prompt_with_default("Select option", "1").await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => self.add_agent(existing, providers, models).await,
|
|
||||||
"2" => {
|
|
||||||
println!("Skipping agent configuration.");
|
|
||||||
Ok(existing.agents.clone())
|
|
||||||
}
|
|
||||||
_ => self.add_agent(existing, providers, models).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn add_agent(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
providers: &HashMap<String, ProviderConfig>,
|
|
||||||
models: &HashMap<String, ModelConfig>,
|
|
||||||
) -> Result<HashMap<String, AgentConfig>, InitError> {
|
|
||||||
let agent_name = self
|
|
||||||
.prompt_with_default("Agent name", "default")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Select provider
|
|
||||||
let provider_names: Vec<String> = providers.keys().cloned().collect();
|
|
||||||
println!("Select provider:");
|
|
||||||
let provider_idx = self.prompt_select("", &provider_names, 0).await?;
|
|
||||||
let selected_provider = &provider_names[provider_idx];
|
|
||||||
|
|
||||||
// Select model (independently)
|
|
||||||
let model_names: Vec<String> = models.keys().cloned().collect();
|
|
||||||
println!();
|
|
||||||
println!("Select model:");
|
|
||||||
let model_idx = self.prompt_select("", &model_names, 0).await?;
|
|
||||||
let selected_model = &model_names[model_idx];
|
|
||||||
|
|
||||||
let agent = AgentConfig {
|
|
||||||
provider: selected_provider.clone(),
|
|
||||||
model: selected_model.clone(),
|
|
||||||
max_tool_iterations: 100,
|
|
||||||
tool_result_max_chars: 20000,
|
|
||||||
context_tool_result_trim_chars: 2000,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut agents = existing.agents.clone();
|
|
||||||
agents.insert(agent_name.clone(), agent);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!(
|
|
||||||
"Agent '{}' configured (provider: {}, model: {}).",
|
|
||||||
agent_name, selected_provider, selected_model
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(agents)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn modify_agent(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
providers: &HashMap<String, ProviderConfig>,
|
|
||||||
models: &HashMap<String, ModelConfig>,
|
|
||||||
) -> Result<HashMap<String, AgentConfig>, InitError> {
|
|
||||||
// Select which agent to modify
|
|
||||||
let agent_names: Vec<String> = existing.agents.keys().cloned().collect();
|
|
||||||
println!("Select agent to modify:");
|
|
||||||
let agent_idx = self.prompt_select("", &agent_names, 0).await?;
|
|
||||||
let selected_agent_name = &agent_names[agent_idx];
|
|
||||||
let current_agent = existing.agents.get(selected_agent_name).unwrap();
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!(
|
|
||||||
"Current config: provider={}, model={}",
|
|
||||||
current_agent.provider, current_agent.model
|
|
||||||
);
|
|
||||||
println!();
|
|
||||||
|
|
||||||
// Select new provider
|
|
||||||
let provider_names: Vec<String> = providers.keys().cloned().collect();
|
|
||||||
let current_provider_idx = provider_names
|
|
||||||
.iter()
|
|
||||||
.position(|p| p == ¤t_agent.provider)
|
|
||||||
.unwrap_or(0);
|
|
||||||
println!("Select provider:");
|
|
||||||
let provider_idx = self.prompt_select("", &provider_names, current_provider_idx).await?;
|
|
||||||
let selected_provider = &provider_names[provider_idx];
|
|
||||||
|
|
||||||
// Select new model
|
|
||||||
let model_names: Vec<String> = models.keys().cloned().collect();
|
|
||||||
let current_model_idx = model_names
|
|
||||||
.iter()
|
|
||||||
.position(|m| m == ¤t_agent.model)
|
|
||||||
.unwrap_or(0);
|
|
||||||
println!();
|
|
||||||
println!("Select model:");
|
|
||||||
let model_idx = self.prompt_select("", &model_names, current_model_idx).await?;
|
|
||||||
let selected_model = &model_names[model_idx];
|
|
||||||
|
|
||||||
let agent = AgentConfig {
|
|
||||||
provider: selected_provider.clone(),
|
|
||||||
model: selected_model.clone(),
|
|
||||||
max_tool_iterations: current_agent.max_tool_iterations,
|
|
||||||
tool_result_max_chars: current_agent.tool_result_max_chars,
|
|
||||||
context_tool_result_trim_chars: current_agent.context_tool_result_trim_chars,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut agents = existing.agents.clone();
|
|
||||||
agents.insert(selected_agent_name.clone(), agent);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!(
|
|
||||||
"Agent '{}' modified (provider: {}, model: {}).",
|
|
||||||
selected_agent_name, selected_provider, selected_model
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(agents)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Step 4: Channels ====================
|
|
||||||
|
|
||||||
async fn configure_channels(
|
|
||||||
&mut self,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Result<HashMap<String, ChannelConfig>, InitError> {
|
|
||||||
println!();
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!("Step 4: Configure Channels");
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
// Show existing channels
|
|
||||||
if !existing.channels.is_empty() {
|
|
||||||
println!("Existing channels:");
|
|
||||||
for (name, config) in &existing.channels {
|
|
||||||
let status = if config.enabled() { "enabled" } else { "disabled" };
|
|
||||||
println!(" - {} ({})", name, status);
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut channels = existing.channels.clone();
|
|
||||||
|
|
||||||
// Loop for configuring multiple channels
|
|
||||||
loop {
|
|
||||||
println!("Options:");
|
|
||||||
println!(" 1. Add Feishu channel");
|
|
||||||
println!(" 2. Add WeChat channel");
|
|
||||||
println!(" 3. Skip / Done");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let choice = self.prompt_with_default("Select option", "3").await?;
|
|
||||||
|
|
||||||
match choice.as_str() {
|
|
||||||
"1" => {
|
|
||||||
let channel_config = self.configure_feishu_channel(&channels).await?;
|
|
||||||
for (name, config) in channel_config {
|
|
||||||
channels.insert(name, config);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"2" => {
|
|
||||||
let channel_config = self.configure_wechat_channel(&channels).await?;
|
|
||||||
for (name, config) in channel_config {
|
|
||||||
channels.insert(name, config);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"3" => break,
|
|
||||||
_ => {
|
|
||||||
println!("Invalid selection.");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(channels)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn configure_feishu_channel(
|
|
||||||
&mut self,
|
|
||||||
existing: &HashMap<String, ChannelConfig>,
|
|
||||||
) -> Result<HashMap<String, ChannelConfig>, InitError> {
|
|
||||||
println!();
|
|
||||||
println!("Configuring Feishu channel...");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
let channel_name = self
|
|
||||||
.prompt_with_default("Channel name", "feishu")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let _existing_config = existing.get(&channel_name).and_then(|c| c.as_feishu());
|
|
||||||
|
|
||||||
let app_id = self.prompt_required("Feishu App ID").await?;
|
|
||||||
let app_secret = self.prompt_required("Feishu App Secret").await?;
|
|
||||||
|
|
||||||
let config = ChannelConfig::Tagged(TaggedChannelConfig::Feishu(FeishuChannelConfig {
|
|
||||||
enabled: true,
|
|
||||||
app_id,
|
|
||||||
app_secret,
|
|
||||||
allow_from: vec!["*".to_string()],
|
|
||||||
agent: "default".to_string(),
|
|
||||||
media_dir: Self::default_feishu_media_dir(),
|
|
||||||
reaction_emoji: "Typing".to_string(),
|
|
||||||
max_message_chars: 20000,
|
|
||||||
reply_context_max_chars: 20000,
|
|
||||||
}));
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert(channel_name.clone(), config);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("Feishu channel '{}' configured.", channel_name);
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn configure_wechat_channel(
|
|
||||||
&mut self,
|
|
||||||
_existing: &HashMap<String, ChannelConfig>,
|
|
||||||
) -> Result<HashMap<String, ChannelConfig>, InitError> {
|
|
||||||
println!();
|
|
||||||
println!("Configuring WeChat channel...");
|
|
||||||
println!();
|
|
||||||
println!("WeChat login requires scanning a QR code.");
|
|
||||||
println!("The QR code URL will be displayed after configuration.");
|
|
||||||
println!();
|
|
||||||
|
|
||||||
// Use default values directly
|
|
||||||
let channel_name = "wechat";
|
|
||||||
let base_url = "https://ilinkai.weixin.qq.com";
|
|
||||||
let cred_path = Self::default_wechat_cred_path();
|
|
||||||
let force_login = false;
|
|
||||||
|
|
||||||
let config = ChannelConfig::Tagged(TaggedChannelConfig::Wechat(WechatChannelConfig {
|
|
||||||
enabled: true,
|
|
||||||
allow_from: vec!["*".to_string()],
|
|
||||||
agent: "default".to_string(),
|
|
||||||
base_url: base_url.to_string(),
|
|
||||||
cred_path,
|
|
||||||
force_login,
|
|
||||||
}));
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert(channel_name.to_string(), config);
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("WeChat channel '{}' configured.", channel_name);
|
|
||||||
|
|
||||||
// Auto login after configuration
|
|
||||||
println!();
|
|
||||||
println!("Starting WeChat login...");
|
|
||||||
|
|
||||||
self.do_wechat_login(base_url, &Self::default_wechat_cred_path()).await?;
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!("WeChat login successful! Credentials saved.");
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn do_wechat_login(&mut self, base_url: &str, cred_path: &str) -> Result<(), InitError> {
|
|
||||||
use wechatbot::{BotOptions, WeChatBot};
|
|
||||||
|
|
||||||
let bot = WeChatBot::new(BotOptions {
|
|
||||||
base_url: Some(base_url.to_string()),
|
|
||||||
cred_path: Some(cred_path.to_string()),
|
|
||||||
on_qr_url: Some(Box::new(|url| {
|
|
||||||
println!();
|
|
||||||
println!("┌──────────────────────────────────────────────────────┐");
|
|
||||||
println!("│ WeChat QR Code │");
|
|
||||||
println!("└──────────────────────────────────────────────────────┘");
|
|
||||||
println!();
|
|
||||||
println!("Scan this URL in WeChat:");
|
|
||||||
println!("{}", url);
|
|
||||||
println!();
|
|
||||||
println!("Waiting for confirmation...");
|
|
||||||
})),
|
|
||||||
on_error: Some(Box::new(|error| {
|
|
||||||
eprintln!("WeChat login error: {}", error);
|
|
||||||
})),
|
|
||||||
});
|
|
||||||
|
|
||||||
let creds = bot.login(true).await.map_err(|e| {
|
|
||||||
InitError::WeChatError(format!("WeChat login failed: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
println!();
|
|
||||||
println!(
|
|
||||||
"Logged in as: {} (account: {})",
|
|
||||||
creds.user_id, creds.account_id
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== Build & Save ====================
|
|
||||||
|
|
||||||
fn build_config(
|
|
||||||
&self,
|
|
||||||
providers: HashMap<String, ProviderConfig>,
|
|
||||||
models: HashMap<String, ModelConfig>,
|
|
||||||
agents: HashMap<String, AgentConfig>,
|
|
||||||
channels: HashMap<String, ChannelConfig>,
|
|
||||||
existing: &Config,
|
|
||||||
) -> Config {
|
|
||||||
Config {
|
|
||||||
providers,
|
|
||||||
models,
|
|
||||||
agents,
|
|
||||||
channels,
|
|
||||||
time: existing.time.clone(),
|
|
||||||
gateway: existing.gateway.clone(),
|
|
||||||
scheduler: existing.scheduler.clone(),
|
|
||||||
client: existing.client.clone(),
|
|
||||||
skills: existing.skills.clone(),
|
|
||||||
tools: existing.tools.clone(),
|
|
||||||
memory_maintenance: existing.memory_maintenance.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn save_config(&self, config: &Config) -> Result<(), InitError> {
|
|
||||||
let dir = self
|
|
||||||
.config_path
|
|
||||||
.parent()
|
|
||||||
.ok_or_else(|| InitError::IoError("Invalid config path".to_string()))?;
|
|
||||||
std::fs::create_dir_all(dir)
|
|
||||||
.map_err(|e| InitError::IoError(format!("Failed to create directory: {}", e)))?;
|
|
||||||
|
|
||||||
let content = serde_json::to_string_pretty(config)
|
|
||||||
.map_err(|e| InitError::SerializeError(e.to_string()))?;
|
|
||||||
|
|
||||||
std::fs::write(&self.config_path, content)
|
|
||||||
.map_err(|e| InitError::IoError(format!("Failed to write config: {}", e)))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn show_completion_message(&mut self, config: &Config) {
|
|
||||||
println!();
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!("Configuration complete!");
|
|
||||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
|
||||||
println!();
|
|
||||||
println!("Config saved to: {}", self.config_path.display());
|
|
||||||
println!();
|
|
||||||
|
|
||||||
println!("Summary:");
|
|
||||||
println!(" Providers: {}", config.providers.len());
|
|
||||||
println!(" Models: {}", config.models.len());
|
|
||||||
println!(" Agents: {}", config.agents.len());
|
|
||||||
println!(" Channels: {}", config.channels.len());
|
|
||||||
println!();
|
|
||||||
|
|
||||||
println!("Next steps:");
|
|
||||||
println!(" 1. Start the gateway: picobot gateway");
|
|
||||||
println!(" 2. Connect with CLI: picobot agent");
|
|
||||||
println!();
|
|
||||||
println!("For more options, run: picobot --help");
|
|
||||||
println!();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_feishu_media_dir() -> String {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
home.join(".picobot/media/feishu")
|
|
||||||
.to_string_lossy()
|
|
||||||
.to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_wechat_cred_path() -> String {
|
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
|
|
||||||
home.join(".picobot/wechat/credentials.json")
|
|
||||||
.to_string_lossy()
|
|
||||||
.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for InitWizard {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum InitError {
|
|
||||||
IoError(String),
|
|
||||||
SerializeError(String),
|
|
||||||
InputError(String),
|
|
||||||
WeChatError(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for InitError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
InitError::IoError(msg) => write!(f, "IO error: {}", msg),
|
|
||||||
InitError::SerializeError(msg) => write!(f, "Serialization error: {}", msg),
|
|
||||||
InitError::InputError(msg) => write!(f, "Input error: {}", msg),
|
|
||||||
InitError::WeChatError(msg) => write!(f, "WeChat error: {}", msg),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for InitError {}
|
|
||||||
|
|
||||||
impl From<std::io::Error> for InitError {
|
|
||||||
fn from(e: std::io::Error) -> Self {
|
|
||||||
InitError::IoError(e.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,7 +1,5 @@
|
|||||||
pub mod channel;
|
pub mod channel;
|
||||||
pub mod input;
|
pub mod input;
|
||||||
pub mod init;
|
|
||||||
|
|
||||||
pub use channel::CliChannel;
|
pub use channel::CliChannel;
|
||||||
pub use input::{InputCommand, InputEvent, InputHandler};
|
pub use input::{InputCommand, InputEvent, InputHandler};
|
||||||
pub use init::InitWizard;
|
|
||||||
|
|||||||
@ -27,8 +27,6 @@ pub struct Config {
|
|||||||
pub skills: SkillsConfig,
|
pub skills: SkillsConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub tools: ToolsConfig,
|
pub tools: ToolsConfig,
|
||||||
#[serde(default)]
|
|
||||||
pub memory_maintenance: MemoryMaintenanceConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -110,41 +108,6 @@ pub struct ToolsConfig {
|
|||||||
pub task: TaskConfig,
|
pub task: TaskConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct MemoryMaintenanceConfig {
|
|
||||||
/// 单次最大合并/删除比例 (0.0-1.0),默认 0.3 (30%)
|
|
||||||
#[serde(default = "default_max_merge_ratio")]
|
|
||||||
pub max_merge_ratio: f32,
|
|
||||||
/// 最小保留记忆数量,默认 5
|
|
||||||
#[serde(default = "default_min_memories_to_keep")]
|
|
||||||
pub min_memories_to_keep: usize,
|
|
||||||
/// 单次合并最大源记忆数,默认 3
|
|
||||||
#[serde(default = "default_max_merge_per_group")]
|
|
||||||
pub max_merge_per_group: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_max_merge_ratio() -> f32 {
|
|
||||||
0.3
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_min_memories_to_keep() -> usize {
|
|
||||||
5
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_max_merge_per_group() -> usize {
|
|
||||||
3
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for MemoryMaintenanceConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
max_merge_ratio: default_max_merge_ratio(),
|
|
||||||
min_memories_to_keep: default_min_memories_to_keep(),
|
|
||||||
max_merge_per_group: default_max_merge_per_group(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct TaskConfig {
|
pub struct TaskConfig {
|
||||||
#[serde(default = "default_task_enabled")]
|
#[serde(default = "default_task_enabled")]
|
||||||
|
|||||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||||
use crate::storage::{MemoryRecord, SessionStore};
|
use crate::storage::{MemoryRecord, SessionStore};
|
||||||
|
|
||||||
@ -17,16 +17,12 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
|||||||
include_str!("memory_maintenance_step2_system_prompt.md");
|
include_str!("memory_maintenance_step2_system_prompt.md");
|
||||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||||
|
|
||||||
const META_NAMESPACE: &str = "_meta";
|
|
||||||
const LAST_MAINTENANCE_KEY: &str = "last_maintenance_at";
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
pub(crate) struct MemoryMaintenanceCandidate {
|
pub(crate) struct MemoryMaintenanceCandidate {
|
||||||
pub(crate) id: String,
|
pub(crate) id: String,
|
||||||
pub(crate) namespace: String,
|
pub(crate) namespace: String,
|
||||||
pub(crate) key: String,
|
pub(crate) key: String,
|
||||||
pub(crate) content: String,
|
pub(crate) content: String,
|
||||||
pub(crate) updated_at: i64, // 记忆更新时间(Unix timestamp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
@ -77,19 +73,13 @@ pub(crate) struct MemoryMaintenanceScopeResult {
|
|||||||
pub(crate) struct MemoryMaintenanceService {
|
pub(crate) struct MemoryMaintenanceService {
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MemoryMaintenanceService {
|
impl MemoryMaintenanceService {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||||
store: Arc<SessionStore>,
|
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
store,
|
store,
|
||||||
provider_config,
|
provider_config,
|
||||||
maintenance_config,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,12 +107,6 @@ impl MemoryMaintenanceService {
|
|||||||
&self,
|
&self,
|
||||||
scope_key: &str,
|
scope_key: &str,
|
||||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||||
// 新增:检查是否有新记忆需要整理
|
|
||||||
if !has_new_memories_since_last_maintenance(&self.store, scope_key)? {
|
|
||||||
tracing::info!(scope_key = %scope_key, "No new memories since last maintenance, skipping");
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
let memories = self
|
let memories = self
|
||||||
.store
|
.store
|
||||||
.list_memories_for_scope("user", scope_key)
|
.list_memories_for_scope("user", scope_key)
|
||||||
@ -271,15 +255,7 @@ impl MemoryMaintenanceService {
|
|||||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||||
|
|
||||||
// 应用整理结果(merge和delete)
|
// 应用整理结果(merge和delete)
|
||||||
apply_memory_maintenance_output(
|
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||||||
self.store.as_ref(),
|
|
||||||
scope_key,
|
|
||||||
&plan,
|
|
||||||
&organize_output,
|
|
||||||
self.maintenance_config.max_merge_ratio,
|
|
||||||
self.maintenance_config.min_memories_to_keep,
|
|
||||||
self.maintenance_config.max_merge_per_group,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// 步骤2:从数据库重新读取剩余的记忆
|
// 步骤2:从数据库重新读取剩余的记忆
|
||||||
let remaining_memories = self
|
let remaining_memories = self
|
||||||
@ -494,94 +470,17 @@ impl MemoryMaintenanceService {
|
|||||||
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
let organize_output = self.organize_plan(scope_key, &plan).await?;
|
||||||
|
|
||||||
// 应用整理结果(merge和delete)
|
// 应用整理结果(merge和delete)
|
||||||
apply_memory_maintenance_output(
|
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &organize_output)?;
|
||||||
self.store.as_ref(),
|
|
||||||
scope_key,
|
|
||||||
&plan,
|
|
||||||
&organize_output,
|
|
||||||
self.maintenance_config.max_merge_ratio,
|
|
||||||
self.maintenance_config.min_memories_to_keep,
|
|
||||||
self.maintenance_config.max_merge_per_group,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(organize_output))
|
Ok(Some(organize_output))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取上次整理时间
|
|
||||||
fn get_last_maintenance_time(
|
|
||||||
store: &SessionStore,
|
|
||||||
scope_key: &str,
|
|
||||||
) -> Result<Option<i64>, crate::storage::StorageError> {
|
|
||||||
let meta = store.get_memory("user", scope_key, META_NAMESPACE, LAST_MAINTENANCE_KEY)?;
|
|
||||||
Ok(meta.and_then(|m| m.content.parse::<i64>().ok()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 记录本次整理时间
|
|
||||||
fn set_last_maintenance_time(
|
|
||||||
store: &SessionStore,
|
|
||||||
scope_key: &str,
|
|
||||||
time: i64,
|
|
||||||
) -> Result<(), crate::storage::StorageError> {
|
|
||||||
store.put_memory(&crate::storage::MemoryUpsert {
|
|
||||||
scope_kind: "user".to_string(),
|
|
||||||
scope_key: scope_key.to_string(),
|
|
||||||
namespace: META_NAMESPACE.to_string(),
|
|
||||||
memory_key: LAST_MAINTENANCE_KEY.to_string(),
|
|
||||||
content: time.to_string(),
|
|
||||||
source_type: "memory_maintenance".to_string(),
|
|
||||||
source_session_id: None,
|
|
||||||
source_message_id: None,
|
|
||||||
source_message_seq: None,
|
|
||||||
source_channel_name: None,
|
|
||||||
source_chat_id: None,
|
|
||||||
})?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 检查是否有需要整理的新记忆(过滤掉 _meta namespace)
|
|
||||||
fn has_new_memories_since_last_maintenance(
|
|
||||||
store: &SessionStore,
|
|
||||||
scope_key: &str,
|
|
||||||
) -> Result<bool, AgentError> {
|
|
||||||
let memories = store
|
|
||||||
.list_memories_for_scope("user", scope_key)
|
|
||||||
.map_err(|e| AgentError::Other(format!("list memories error: {}", e)))?;
|
|
||||||
|
|
||||||
// 过滤掉 _meta namespace 的记忆
|
|
||||||
let user_memories: Vec<_> = memories
|
|
||||||
.iter()
|
|
||||||
.filter(|m| m.namespace != META_NAMESPACE)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if user_memories.is_empty() {
|
|
||||||
return Ok(false); // 没有记忆,跳过
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取上次整理时间
|
|
||||||
let last_time = get_last_maintenance_time(store, scope_key)
|
|
||||||
.map_err(|e| AgentError::Other(format!("get last maintenance time error: {}", e)))?;
|
|
||||||
|
|
||||||
match last_time {
|
|
||||||
None => Ok(true), // 从未整理过,需要整理
|
|
||||||
Some(last) => {
|
|
||||||
// 检查是否有记忆在上次整理后更新
|
|
||||||
let has_new = user_memories.iter().any(|m| m.updated_at > last);
|
|
||||||
Ok(has_new)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||||
let mut plan = MemoryMaintenancePlan::default();
|
let mut plan = MemoryMaintenancePlan::default();
|
||||||
let mut seen = HashSet::new();
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
for memory in memories {
|
for memory in memories {
|
||||||
// 过滤掉 _meta namespace 的记忆
|
|
||||||
if memory.namespace == META_NAMESPACE {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let normalized_content = memory.content.trim();
|
let normalized_content = memory.content.trim();
|
||||||
if normalized_content.is_empty() {
|
if normalized_content.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
@ -602,7 +501,6 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
|
|||||||
namespace: memory.namespace.clone(),
|
namespace: memory.namespace.clone(),
|
||||||
key: memory.memory_key.clone(),
|
key: memory.memory_key.clone(),
|
||||||
content: normalized_content.to_string(),
|
content: normalized_content.to_string(),
|
||||||
updated_at: memory.updated_at,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
plan.candidates.push(candidate);
|
plan.candidates.push(candidate);
|
||||||
@ -682,115 +580,12 @@ pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 验证记忆整理输出是否符合限制
|
|
||||||
pub(crate) fn validate_memory_maintenance_output(
|
|
||||||
plan: &MemoryMaintenancePlan,
|
|
||||||
output: &MemoryOrganizationOutput,
|
|
||||||
max_merge_ratio: f32,
|
|
||||||
min_memories_to_keep: usize,
|
|
||||||
max_merge_per_group: usize,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
let total = plan.candidates.len();
|
|
||||||
|
|
||||||
if total == 0 {
|
|
||||||
return Ok(()); // 没有候选,无需验证
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 1: 单次合并数量限制
|
|
||||||
for merge in &output.merges {
|
|
||||||
if merge.source_ids.len() > max_merge_per_group {
|
|
||||||
return Err(format!(
|
|
||||||
"合并组过大: {} 条源记忆超过上限 {}",
|
|
||||||
merge.source_ids.len(),
|
|
||||||
max_merge_per_group
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 2: 跨 namespace 合并检测(完全禁止)
|
|
||||||
let candidates_by_id: HashMap<&str, &MemoryMaintenanceCandidate> = plan
|
|
||||||
.candidates
|
|
||||||
.iter()
|
|
||||||
.map(|c| (c.id.as_str(), c))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
for merge in &output.merges {
|
|
||||||
let source_namespaces: HashSet<&str> = merge
|
|
||||||
.source_ids
|
|
||||||
.iter()
|
|
||||||
.filter_map(|id| candidates_by_id.get(id.as_str()).map(|c| c.namespace.as_str()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// 检查是否跨越多个 namespace
|
|
||||||
if source_namespaces.len() > 1 {
|
|
||||||
return Err(format!(
|
|
||||||
"跨 namespace 合并被禁止: 源来自 {}",
|
|
||||||
source_namespaces.iter().cloned().collect::<Vec<_>>().join(", ")
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查目标 namespace 是否与源一致
|
|
||||||
if let Some(src_ns) = source_namespaces.iter().next() {
|
|
||||||
if *src_ns != merge.namespace {
|
|
||||||
return Err(format!(
|
|
||||||
"跨 namespace 合并被禁止: {} → {}",
|
|
||||||
src_ns, merge.namespace
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 3: 总体合并比例
|
|
||||||
let merged_ids: HashSet<&str> = output
|
|
||||||
.merges
|
|
||||||
.iter()
|
|
||||||
.flat_map(|m| m.source_ids.iter())
|
|
||||||
.map(|s| s.as_str())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let deleted_ids: HashSet<&str> = output
|
|
||||||
.low_value_ids
|
|
||||||
.iter()
|
|
||||||
.map(|s| s.as_str())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let affected = merged_ids.len() + deleted_ids.len();
|
|
||||||
let max_allowed = (total as f32 * max_merge_ratio).ceil() as usize;
|
|
||||||
|
|
||||||
if affected > max_allowed {
|
|
||||||
return Err(format!(
|
|
||||||
"合并比例超限: {} / {} > {:.0}%",
|
|
||||||
affected,
|
|
||||||
total,
|
|
||||||
max_merge_ratio * 100.0
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 4: 最小保留数
|
|
||||||
let remaining = total - affected + output.merges.len();
|
|
||||||
if remaining < min_memories_to_keep {
|
|
||||||
return Err(format!(
|
|
||||||
"保留数不足: {} < {}",
|
|
||||||
remaining, min_memories_to_keep
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn apply_memory_maintenance_output(
|
pub(crate) fn apply_memory_maintenance_output(
|
||||||
store: &SessionStore,
|
store: &SessionStore,
|
||||||
scope_key: &str,
|
scope_key: &str,
|
||||||
plan: &MemoryMaintenancePlan,
|
plan: &MemoryMaintenancePlan,
|
||||||
output: &MemoryOrganizationOutput,
|
output: &MemoryOrganizationOutput,
|
||||||
max_merge_ratio: f32,
|
|
||||||
min_memories_to_keep: usize,
|
|
||||||
max_merge_per_group: usize,
|
|
||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
// 新增: 验证合并输出
|
|
||||||
validate_memory_maintenance_output(plan, output, max_merge_ratio, min_memories_to_keep, max_merge_per_group)
|
|
||||||
.map_err(|e| AgentError::Other(e))?;
|
|
||||||
|
|
||||||
let all_candidates = plan.candidates.clone();
|
let all_candidates = plan.candidates.clone();
|
||||||
|
|
||||||
let candidates_by_id = all_candidates
|
let candidates_by_id = all_candidates
|
||||||
@ -866,11 +661,6 @@ pub(crate) fn apply_memory_maintenance_output(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 新增:记录整理完成时间
|
|
||||||
let now = chrono::Utc::now().timestamp();
|
|
||||||
set_last_maintenance_time(store, scope_key, now)
|
|
||||||
.map_err(|err| AgentError::Other(format!("set last maintenance time error: {}", err)))?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -40,7 +40,6 @@ impl MemoryMaintenanceCoordinator {
|
|||||||
Ok(MemoryMaintenanceService::new(
|
Ok(MemoryMaintenanceService::new(
|
||||||
self.store.clone(),
|
self.store.clone(),
|
||||||
self.provider_configs.default_provider_config(),
|
self.provider_configs.default_provider_config(),
|
||||||
self.provider_configs.default_maintenance_config(),
|
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,32 +27,15 @@
|
|||||||
- note: 冲突说明
|
- note: 冲突说明
|
||||||
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
||||||
|
|
||||||
组织原则:
|
组织原则(由你自主决定):
|
||||||
|
|
||||||
- 根据记忆的语义内容自然分组
|
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
|
||||||
- **每次合并最多只能合并 2-3 条源记忆**
|
- 相似的、互补的记忆可以合并
|
||||||
- **禁止跨 namespace 合并**(不同 namespace 代表不同信息维度)
|
- 过期、重复、过细的记忆可以标记为低价值
|
||||||
- 过期、重复、过细的记忆可以标记为低值
|
|
||||||
- namespace 和 memory_key 的命名应当简洁、有意义
|
- namespace 和 memory_key 的命名应当简洁、有意义
|
||||||
- **保守原则:宁可保留稍多,不可过度合并**
|
- 可以自由创建新的 namespace 来组织相关记忆
|
||||||
- **必须保留足够数量的记忆,确保信息多样性**
|
|
||||||
|
|
||||||
时间权重原则(关键):
|
|
||||||
|
|
||||||
- 每个候选记忆包含 `updated_at` 时间戳(Unix timestamp,秒)
|
|
||||||
- **当多条记忆存在重复或冲突时,时间越新的权重越高**
|
|
||||||
- 合并时优先采用新记忆的内容,旧记忆作为补充或背景
|
|
||||||
- 如果新旧记忆内容完全相同,保留新的,删除旧的
|
|
||||||
- 时间戳数值越大表示越新(离当前时间越近)
|
|
||||||
|
|
||||||
合并限制(硬性约束,由系统强制检查):
|
|
||||||
|
|
||||||
- 单次合并最多来自 3 条源记忆
|
|
||||||
- 整理后保留的记忆数不得少于 5 条
|
|
||||||
- 单次整理最多影响 30% 的记忆
|
|
||||||
- 不同 namespace 的记忆不允许互相合并
|
|
||||||
|
|
||||||
额外约束:
|
额外约束:
|
||||||
|
|
||||||
- 只能引用输入里出现过的候选 id。
|
- 只能引用输入里出现过的候选 id。
|
||||||
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
|
- 不输出 user_facts、preferences、behavior_patterns、managed_markdown 等摘要字段。
|
||||||
|
|||||||
@ -84,7 +84,6 @@ impl GatewayState {
|
|||||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||||
std::collections::HashSet::new(),
|
std::collections::HashSet::new(),
|
||||||
config.tools.task.clone(),
|
config.tools.task.clone(),
|
||||||
config.memory_maintenance.clone(),
|
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)?;
|
)?;
|
||||||
|
|||||||
@ -2,25 +2,22 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig};
|
use crate::config::LLMProviderConfig;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct ProviderConfigService {
|
pub(crate) struct ProviderConfigService {
|
||||||
default_provider_config: LLMProviderConfig,
|
default_provider_config: LLMProviderConfig,
|
||||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProviderConfigService {
|
impl ProviderConfigService {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
default_provider_config: LLMProviderConfig,
|
default_provider_config: LLMProviderConfig,
|
||||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
default_provider_config,
|
default_provider_config,
|
||||||
provider_configs: Arc::new(provider_configs),
|
provider_configs: Arc::new(provider_configs),
|
||||||
maintenance_config,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,10 +37,6 @@ impl ProviderConfigService {
|
|||||||
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
||||||
self.default_provider_config.clone()
|
self.default_provider_config.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_maintenance_config(&self) -> MemoryMaintenanceConfig {
|
|
||||||
self.maintenance_config.clone()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -79,7 +72,6 @@ mod tests {
|
|||||||
"planner".to_string(),
|
"planner".to_string(),
|
||||||
test_provider_config_named("planner-provider", "planner-model"),
|
test_provider_config_named("planner-provider", "planner-model"),
|
||||||
)]),
|
)]),
|
||||||
MemoryMaintenanceConfig::default(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let selected = service.select(Some("planner")).unwrap();
|
let selected = service.select(Some("planner")).unwrap();
|
||||||
@ -90,11 +82,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_select_falls_back_to_default() {
|
fn test_select_falls_back_to_default() {
|
||||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
let service = ProviderConfigService::new(
|
let service = ProviderConfigService::new(default_provider, HashMap::new());
|
||||||
default_provider,
|
|
||||||
HashMap::new(),
|
|
||||||
MemoryMaintenanceConfig::default(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let selected = service.select(Some("default")).unwrap();
|
let selected = service.select(Some("default")).unwrap();
|
||||||
assert_eq!(selected.name, "default-provider");
|
assert_eq!(selected.name, "default-provider");
|
||||||
|
|||||||
@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
use crate::config::{LLMProviderConfig, TaskConfig};
|
||||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{
|
use crate::storage::{
|
||||||
@ -34,7 +34,6 @@ pub(crate) fn build_session_manager(
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||||
@ -48,7 +47,6 @@ pub(crate) fn build_session_manager(
|
|||||||
Arc::new(NoopSessionMessageSender),
|
Arc::new(NoopSessionMessageSender),
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
task_config,
|
||||||
maintenance_config,
|
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)
|
)
|
||||||
@ -64,7 +62,6 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||||
@ -73,11 +70,7 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||||
);
|
);
|
||||||
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||||
let provider_configs = ProviderConfigService::new(
|
let provider_configs = ProviderConfigService::new(provider_config.clone(), provider_configs);
|
||||||
provider_config.clone(),
|
|
||||||
provider_configs,
|
|
||||||
maintenance_config,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Err(err) =
|
if let Err(err) =
|
||||||
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
||||||
|
|||||||
@ -501,7 +501,6 @@ impl SessionManager {
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: std::collections::HashSet<String>,
|
disabled_tools: std::collections::HashSet<String>,
|
||||||
task_config: crate::config::TaskConfig,
|
task_config: crate::config::TaskConfig,
|
||||||
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
@ -514,7 +513,6 @@ impl SessionManager {
|
|||||||
skills,
|
skills,
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
task_config,
|
||||||
maintenance_config,
|
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)
|
)
|
||||||
@ -975,7 +973,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1028,7 +1025,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1097,7 +1093,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1183,7 +1178,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1270,7 +1264,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1356,7 +1349,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1424,7 +1416,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1501,7 +1492,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1565,7 +1555,6 @@ mod tests {
|
|||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
crate::config::TaskConfig::default(),
|
crate::config::TaskConfig::default(),
|
||||||
crate::config::MemoryMaintenanceConfig::default(),
|
|
||||||
Some(4),
|
Some(4),
|
||||||
Some(24),
|
Some(24),
|
||||||
)
|
)
|
||||||
@ -1604,8 +1593,6 @@ mod tests {
|
|||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
let scope_key = "feishu:user-1";
|
let scope_key = "feishu:user-1";
|
||||||
|
|
||||||
// 创建足够的记忆(7条),让合并操作满足保护限制
|
|
||||||
// 合并后需要保留至少 5 条(min_memories_to_keep)
|
|
||||||
let work = store
|
let work = store
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
@ -1652,30 +1639,9 @@ mod tests {
|
|||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// 添加额外的记忆以满足 min_memories_to_keep = 5 的要求
|
|
||||||
for i in 0..4 {
|
|
||||||
store
|
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
|
||||||
scope_kind: "user".to_string(),
|
|
||||||
scope_key: scope_key.to_string(),
|
|
||||||
namespace: "profile".to_string(),
|
|
||||||
memory_key: format!("extra_{}", i),
|
|
||||||
content: format!("额外记忆 {}", i),
|
|
||||||
source_type: "message".to_string(),
|
|
||||||
source_session_id: None,
|
|
||||||
source_message_id: None,
|
|
||||||
source_message_seq: None,
|
|
||||||
source_channel_name: None,
|
|
||||||
source_chat_id: None,
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
let plan = build_memory_maintenance_plan(
|
let plan = build_memory_maintenance_plan(
|
||||||
&store.list_memories_for_scope("user", scope_key).unwrap(),
|
&store.list_memories_for_scope("user", scope_key).unwrap(),
|
||||||
);
|
);
|
||||||
assert_eq!(plan.candidates.len(), 7); // 7 条候选记忆
|
|
||||||
|
|
||||||
let output = MemoryOrganizationOutput {
|
let output = MemoryOrganizationOutput {
|
||||||
merges: vec![MemoryMaintenanceMerge {
|
merges: vec![MemoryMaintenanceMerge {
|
||||||
source_ids: vec![work.id.clone(), role.id.clone()],
|
source_ids: vec![work.id.clone(), role.id.clone()],
|
||||||
@ -1687,25 +1653,13 @@ mod tests {
|
|||||||
low_value_ids: vec![noise.id.clone()],
|
low_value_ids: vec![noise.id.clone()],
|
||||||
};
|
};
|
||||||
|
|
||||||
// 使用默认配置进行验证
|
apply_memory_maintenance_output(&store, scope_key, &plan, &output).unwrap();
|
||||||
apply_memory_maintenance_output(
|
|
||||||
&store,
|
|
||||||
scope_key,
|
|
||||||
&plan,
|
|
||||||
&output,
|
|
||||||
crate::config::MemoryMaintenanceConfig::default().max_merge_ratio,
|
|
||||||
crate::config::MemoryMaintenanceConfig::default().min_memories_to_keep,
|
|
||||||
crate::config::MemoryMaintenanceConfig::default().max_merge_per_group,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
|
let all_memories = store.list_memories_for_scope("user", scope_key).unwrap();
|
||||||
// 过滤掉 _meta 记录
|
assert_eq!(all_memories.len(), 1);
|
||||||
let user_memories: Vec<_> = all_memories.iter().filter(|m| m.namespace != "_meta").collect();
|
assert_eq!(all_memories[0].namespace, "profile");
|
||||||
// 合并 2 条为 1 条,删除 1 条,7 - 2 + 1 = 6 条(加上 _meta 记录)
|
assert_eq!(all_memories[0].memory_key, "work");
|
||||||
assert_eq!(user_memories.len(), 6);
|
assert_eq!(all_memories[0].content, "用户主要在做AI产品设计与实现");
|
||||||
// 验证合并后的记忆存在
|
|
||||||
assert!(user_memories.iter().any(|m| m.namespace == "profile" && m.memory_key == "work"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
15
src/main.rs
15
src/main.rs
@ -4,15 +4,6 @@ use clap::{CommandFactory, Parser};
|
|||||||
#[command(name = "picobot")]
|
#[command(name = "picobot")]
|
||||||
#[command(about = "A CLI chatbot", long_about = None)]
|
#[command(about = "A CLI chatbot", long_about = None)]
|
||||||
enum Command {
|
enum Command {
|
||||||
/// Interactive configuration wizard
|
|
||||||
Init {
|
|
||||||
/// Force overwrite existing config
|
|
||||||
#[arg(short, long)]
|
|
||||||
force: bool,
|
|
||||||
/// Only configure provider, skip channels
|
|
||||||
#[arg(long)]
|
|
||||||
skip_channels: bool,
|
|
||||||
},
|
|
||||||
/// Connect to gateway
|
/// Connect to gateway
|
||||||
Agent {
|
Agent {
|
||||||
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
|
/// Gateway WebSocket URL (e.g., ws://127.0.0.1:19876/ws)
|
||||||
@ -40,14 +31,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
if std::env::args().len() <= 1 {
|
if std::env::args().len() <= 1 {
|
||||||
cmd.print_help()?;
|
cmd.print_help()?;
|
||||||
println!();
|
println!();
|
||||||
return Ok(())
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
match Command::parse() {
|
match Command::parse() {
|
||||||
Command::Init { force, skip_channels } => {
|
|
||||||
let mut wizard = picobot::cli::InitWizard::new();
|
|
||||||
wizard.run(force, skip_channels).await?;
|
|
||||||
}
|
|
||||||
Command::Agent { gateway_url } => {
|
Command::Agent { gateway_url } => {
|
||||||
let config = picobot::config::Config::load_default().ok();
|
let config = picobot::config::Config::load_default().ok();
|
||||||
let url = gateway_url
|
let url = gateway_url
|
||||||
|
|||||||
@ -8,8 +8,6 @@ use super::traits::Usage;
|
|||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
use crate::domain::messages::ContentBlock;
|
use crate::domain::messages::ContentBlock;
|
||||||
|
|
||||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["supported_content_types"];
|
|
||||||
|
|
||||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||||
let mut details = vec![error.to_string()];
|
let mut details = vec![error.to_string()];
|
||||||
let mut current = error.source();
|
let mut current = error.source();
|
||||||
@ -32,65 +30,7 @@ where
|
|||||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_content_blocks(
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||||
supports_images: bool,
|
|
||||||
provider_name: &str,
|
|
||||||
model_id: &str,
|
|
||||||
blocks: &[ContentBlock],
|
|
||||||
message_idx: usize,
|
|
||||||
) -> Vec<serde_json::Value> {
|
|
||||||
// 检查是否有图片且模型不支持
|
|
||||||
if !supports_images {
|
|
||||||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
|
||||||
|
|
||||||
if has_images {
|
|
||||||
let image_count = blocks
|
|
||||||
.iter()
|
|
||||||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
|
||||||
.count();
|
|
||||||
|
|
||||||
tracing::warn!(
|
|
||||||
provider = %provider_name,
|
|
||||||
model = %model_id,
|
|
||||||
filtered_images = image_count,
|
|
||||||
message_idx,
|
|
||||||
"模型不支持图片;将图片转换为通知文本"
|
|
||||||
);
|
|
||||||
|
|
||||||
// 复用通知格式,将图片转换为文本通知
|
|
||||||
let mut converted_blocks: Vec<serde_json::Value> = Vec::new();
|
|
||||||
let mut notices: Vec<String> = Vec::new();
|
|
||||||
let mut image_idx = 0;
|
|
||||||
|
|
||||||
for block in blocks.iter() {
|
|
||||||
match block {
|
|
||||||
ContentBlock::Text { text } => {
|
|
||||||
converted_blocks.push(serde_json::json!({ "type": "text", "text": text }));
|
|
||||||
}
|
|
||||||
ContentBlock::ImageUrl { .. } => {
|
|
||||||
image_idx += 1;
|
|
||||||
notices.push(format!(
|
|
||||||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
|
||||||
image_idx
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加通知文本块
|
|
||||||
if !notices.is_empty() {
|
|
||||||
let notice_text = format!(
|
|
||||||
"[系统提示] 以下图片未能成功入模:\n{}",
|
|
||||||
notices.join("\n")
|
|
||||||
);
|
|
||||||
converted_blocks.push(serde_json::json!({ "type": "text", "text": notice_text }));
|
|
||||||
}
|
|
||||||
|
|
||||||
return converted_blocks;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 原有逻辑 - 模型支持图片,正常转换
|
|
||||||
blocks
|
blocks
|
||||||
.iter()
|
.iter()
|
||||||
.map(|b| match b {
|
.map(|b| match b {
|
||||||
@ -172,32 +112,6 @@ impl AnthropicProvider {
|
|||||||
model_extra,
|
model_extra,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 检查模型是否支持指定内容类型
|
|
||||||
/// 默认支持所有类型(text, image)
|
|
||||||
fn supports_content_type(&self, content_type: &str) -> bool {
|
|
||||||
self.model_extra
|
|
||||||
.get("supported_content_types")
|
|
||||||
.and_then(|value| value.as_array())
|
|
||||||
.map(|types| {
|
|
||||||
types.iter().any(|t| t.as_str() == Some(content_type))
|
|
||||||
})
|
|
||||||
.unwrap_or(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 检查模型是否支持图片
|
|
||||||
fn supports_images(&self) -> bool {
|
|
||||||
self.supports_content_type("image")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 过滤掉内部字段,只返回需要发送到 API 的 extra 字段
|
|
||||||
fn request_model_extra(&self) -> HashMap<String, serde_json::Value> {
|
|
||||||
self.model_extra
|
|
||||||
.iter()
|
|
||||||
.filter(|(key, _)| !INTERNAL_MODEL_EXTRA_KEYS.contains(&key.as_str()))
|
|
||||||
.map(|(k, v)| (k.clone(), v.clone()))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@ -283,22 +197,15 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
messages: request
|
messages: request
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.map(|m| AnthropicMessage {
|
||||||
.map(|(i, m)| AnthropicMessage {
|
|
||||||
role: m.role.clone(),
|
role: m.role.clone(),
|
||||||
content: convert_content_blocks(
|
content: convert_content_blocks(&m.content),
|
||||||
self.supports_images(),
|
|
||||||
&self.name,
|
|
||||||
&self.model_id,
|
|
||||||
&m.content,
|
|
||||||
i,
|
|
||||||
),
|
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
max_tokens,
|
max_tokens,
|
||||||
temperature: request.temperature.or(self.temperature),
|
temperature: request.temperature.or(self.temperature),
|
||||||
tools,
|
tools,
|
||||||
extra: self.request_model_extra(),
|
extra: self.model_extra.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut req_builder = self
|
let mut req_builder = self
|
||||||
|
|||||||
@ -10,7 +10,7 @@ use super::traits::Usage;
|
|||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
use crate::domain::messages::ContentBlock;
|
use crate::domain::messages::ContentBlock;
|
||||||
|
|
||||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content", "supported_content_types"];
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||||
|
|
||||||
/// 流式响应中的工具调用增量
|
/// 流式响应中的工具调用增量
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
@ -139,75 +139,7 @@ fn format_transport_error_context(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_content_blocks(
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||||
supports_images: bool,
|
|
||||||
provider_name: &str,
|
|
||||||
model_id: &str,
|
|
||||||
blocks: &[ContentBlock],
|
|
||||||
message_idx: usize,
|
|
||||||
) -> Value {
|
|
||||||
// 检查是否有图片且模型不支持
|
|
||||||
if !supports_images {
|
|
||||||
let has_images = blocks.iter().any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
|
||||||
|
|
||||||
if has_images {
|
|
||||||
let image_count = blocks.iter()
|
|
||||||
.filter(|b| matches!(b, ContentBlock::ImageUrl { .. }))
|
|
||||||
.count();
|
|
||||||
|
|
||||||
tracing::warn!(
|
|
||||||
provider = %provider_name,
|
|
||||||
model = %model_id,
|
|
||||||
filtered_images = image_count,
|
|
||||||
message_idx,
|
|
||||||
"模型不支持图片;将图片转换为通知文本"
|
|
||||||
);
|
|
||||||
|
|
||||||
// 复用通知格式,将图片转换为文本通知
|
|
||||||
let mut converted_blocks: Vec<Value> = Vec::new();
|
|
||||||
let mut notices: Vec<String> = Vec::new();
|
|
||||||
let mut image_idx = 0;
|
|
||||||
|
|
||||||
for block in blocks.iter() {
|
|
||||||
match block {
|
|
||||||
ContentBlock::Text { text } => {
|
|
||||||
converted_blocks.push(json!({ "type": "text", "text": text }));
|
|
||||||
}
|
|
||||||
ContentBlock::ImageUrl { .. } => {
|
|
||||||
image_idx += 1;
|
|
||||||
notices.push(format!(
|
|
||||||
"- 第 {} 张图片:当前模型不支持图片输入,该图片未能成功入模,请直接告知用户。",
|
|
||||||
image_idx
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加通知文本块
|
|
||||||
if !notices.is_empty() {
|
|
||||||
let notice_text = format!(
|
|
||||||
"[系统提示] 以下图片未能成功入模:\n{}",
|
|
||||||
notices.join("\n")
|
|
||||||
);
|
|
||||||
converted_blocks.push(json!({ "type": "text", "text": notice_text }));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果只有一个文本块且没有通知,返回字符串形式
|
|
||||||
if converted_blocks.len() == 1 {
|
|
||||||
if let Some(block) = converted_blocks.first() {
|
|
||||||
if block.get("type").and_then(|t| t.as_str()) == Some("text") {
|
|
||||||
if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
|
|
||||||
return Value::String(text.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Value::Array(converted_blocks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 原有逻辑 - 模型支持图片,正常转换
|
|
||||||
if blocks.len() == 1 {
|
if blocks.len() == 1 {
|
||||||
if let ContentBlock::Text { text } = &blocks[0] {
|
if let ContentBlock::Text { text } = &blocks[0] {
|
||||||
return Value::String(text.clone());
|
return Value::String(text.clone());
|
||||||
@ -292,23 +224,6 @@ impl OpenAIProvider {
|
|||||||
.unwrap_or(true)
|
.unwrap_or(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 检查模型是否支持指定内容类型
|
|
||||||
/// 默认支持所有类型(text, image)
|
|
||||||
fn supports_content_type(&self, content_type: &str) -> bool {
|
|
||||||
self.model_extra
|
|
||||||
.get("supported_content_types")
|
|
||||||
.and_then(|value| value.as_array())
|
|
||||||
.map(|types| {
|
|
||||||
types.iter().any(|t| t.as_str() == Some(content_type))
|
|
||||||
})
|
|
||||||
.unwrap_or(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 检查模型是否支持图片
|
|
||||||
fn supports_images(&self) -> bool {
|
|
||||||
self.supports_content_type("image")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
match arguments {
|
match arguments {
|
||||||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||||
@ -565,21 +480,20 @@ impl OpenAIProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
let supports_images = self.supports_images();
|
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
"messages": request.messages.iter().enumerate().map(|(i, m)| {
|
"messages": request.messages.iter().map(|m| {
|
||||||
if m.role == "tool" {
|
if m.role == "tool" {
|
||||||
json!({
|
json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
"content": convert_content_blocks(&m.content),
|
||||||
"tool_call_id": m.tool_call_id,
|
"tool_call_id": m.tool_call_id,
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
})
|
})
|
||||||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||||
let mut message = json!({
|
let mut message = json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i),
|
"content": convert_content_blocks(&m.content),
|
||||||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||||
calls.iter().map(|call| json!({
|
calls.iter().map(|call| json!({
|
||||||
"id": call.id,
|
"id": call.id,
|
||||||
@ -600,7 +514,7 @@ impl OpenAIProvider {
|
|||||||
} else {
|
} else {
|
||||||
let mut message = json!({
|
let mut message = json!({
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": convert_content_blocks(supports_images, &self.name, &self.model_id, &m.content, i)
|
"content": convert_content_blocks(&m.content)
|
||||||
});
|
});
|
||||||
|
|
||||||
if m.role == "assistant" {
|
if m.role == "assistant" {
|
||||||
@ -1162,115 +1076,4 @@ mod tests {
|
|||||||
assert_eq!(response.tool_calls[1].id, "call_2");
|
assert_eq!(response.tool_calls[1].id, "call_2");
|
||||||
assert_eq!(response.tool_calls[1].name, "get_time");
|
assert_eq!(response.tool_calls[1].name, "get_time");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_supports_images_default_true() {
|
|
||||||
let provider = OpenAIProvider::new(
|
|
||||||
"test".to_string(),
|
|
||||||
"key".to_string(),
|
|
||||||
"https://example.com/v1".to_string(),
|
|
||||||
HashMap::new(),
|
|
||||||
120,
|
|
||||||
"gpt-test".to_string(),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
HashMap::new(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(provider.supports_images());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_supports_images_disabled_via_config() {
|
|
||||||
let provider = OpenAIProvider::new(
|
|
||||||
"test".to_string(),
|
|
||||||
"key".to_string(),
|
|
||||||
"https://example.com/v1".to_string(),
|
|
||||||
HashMap::new(),
|
|
||||||
120,
|
|
||||||
"gpt-test".to_string(),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
HashMap::from([(
|
|
||||||
"supported_content_types".to_string(),
|
|
||||||
Value::Array(vec![Value::String("text".to_string())]),
|
|
||||||
)]),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(!provider.supports_images());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_convert_content_blocks_converts_images_to_notice_when_disabled() {
|
|
||||||
let blocks = vec![
|
|
||||||
ContentBlock::text("hello"),
|
|
||||||
ContentBlock::image_url("data:image/png;base64,abc123"),
|
|
||||||
ContentBlock::text("world"),
|
|
||||||
];
|
|
||||||
|
|
||||||
let result = convert_content_blocks(false, "test", "test-model", &blocks, 0);
|
|
||||||
|
|
||||||
// 应该是数组形式
|
|
||||||
let arr = result.as_array().unwrap();
|
|
||||||
assert_eq!(arr.len(), 3); // 两个文本块 + 一个通知块
|
|
||||||
|
|
||||||
// 检查通知内容
|
|
||||||
let notice_block = arr[2].as_object().unwrap();
|
|
||||||
assert_eq!(notice_block["type"], "text");
|
|
||||||
let notice_text = notice_block["text"].as_str().unwrap();
|
|
||||||
assert!(notice_text.contains("[系统提示] 以下图片未能成功入模"));
|
|
||||||
assert!(notice_text.contains("第 1 张图片"));
|
|
||||||
assert!(notice_text.contains("当前模型不支持图片输入"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_convert_content_blocks_keeps_images_when_enabled() {
|
|
||||||
let blocks = vec![
|
|
||||||
ContentBlock::text("hello"),
|
|
||||||
ContentBlock::image_url("data:image/png;base64,abc123"),
|
|
||||||
];
|
|
||||||
|
|
||||||
let result = convert_content_blocks(true, "test", "test-model", &blocks, 0);
|
|
||||||
|
|
||||||
// 应该是数组形式,包含文本和图片
|
|
||||||
let arr = result.as_array().unwrap();
|
|
||||||
assert_eq!(arr.len(), 2);
|
|
||||||
assert_eq!(arr[0]["type"], "text");
|
|
||||||
assert_eq!(arr[1]["type"], "image_url");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_request_body_omits_supported_content_types_from_api() {
|
|
||||||
let provider = OpenAIProvider::new(
|
|
||||||
"test".to_string(),
|
|
||||||
"key".to_string(),
|
|
||||||
"https://example.com/v1".to_string(),
|
|
||||||
HashMap::new(),
|
|
||||||
120,
|
|
||||||
"gpt-test".to_string(),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
HashMap::from([
|
|
||||||
(
|
|
||||||
"supported_content_types".to_string(),
|
|
||||||
Value::Array(vec![Value::String("text".to_string())]),
|
|
||||||
),
|
|
||||||
("custom_param".to_string(), Value::String("value".to_string())),
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
messages: vec![Message::user("hello")],
|
|
||||||
temperature: None,
|
|
||||||
max_tokens: None,
|
|
||||||
tools: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let body = provider.build_request_body(&request);
|
|
||||||
|
|
||||||
// supported_content_types 不应该发送到 API
|
|
||||||
assert!(body.get("supported_content_types").is_none());
|
|
||||||
// custom_param 应该保留
|
|
||||||
assert_eq!(body["custom_param"], Value::String("value".to_string()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -341,9 +341,7 @@ impl SkillSource {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SkillCatalog {
|
pub struct SkillCatalog {
|
||||||
skills: Vec<Skill>,
|
skills: Vec<Skill>,
|
||||||
#[allow(dead_code)]
|
|
||||||
max_index_chars: usize,
|
max_index_chars: usize,
|
||||||
#[allow(dead_code)]
|
|
||||||
max_listed_skills: usize,
|
max_listed_skills: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user