feat: 添加 memory_maintenance_timeout_secs 配置,优化内存维护超时设置
This commit is contained in:
parent
5a0c018ee7
commit
9d9fa1dc4b
@ -1128,6 +1128,7 @@ mod tests {
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: std::collections::HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
|
||||
@ -248,6 +248,8 @@ pub struct ProviderConfig {
|
||||
pub extra_headers: HashMap<String, String>,
|
||||
#[serde(default = "default_llm_timeout_secs")]
|
||||
pub llm_timeout_secs: u64,
|
||||
#[serde(default = "default_memory_maintenance_timeout_secs")]
|
||||
pub memory_maintenance_timeout_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -291,6 +293,10 @@ fn default_llm_timeout_secs() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
/// Serde default for the `memory_maintenance_timeout_secs` field of `ProviderConfig`.
///
/// Returns 600; the unit is seconds going by the field name.
fn default_memory_maintenance_timeout_secs() -> u64 {
|
||||
600
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GatewayConfig {
|
||||
#[serde(default = "default_gateway_host")]
|
||||
@ -623,6 +629,7 @@ pub struct LLMProviderConfig {
|
||||
pub api_key: String,
|
||||
pub extra_headers: HashMap<String, String>,
|
||||
pub llm_timeout_secs: u64,
|
||||
pub memory_maintenance_timeout_secs: u64,
|
||||
pub model_id: String,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
@ -712,6 +719,7 @@ impl Config {
|
||||
api_key: provider.api_key.clone(),
|
||||
extra_headers: provider.extra_headers.clone(),
|
||||
llm_timeout_secs: provider.llm_timeout_secs,
|
||||
memory_maintenance_timeout_secs: provider.memory_maintenance_timeout_secs,
|
||||
model_id: model.model_id.clone(),
|
||||
temperature: model.temperature,
|
||||
max_tokens: model.max_tokens,
|
||||
|
||||
@ -50,6 +50,7 @@ mod tests {
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
|
||||
@ -4,9 +4,9 @@ use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::agent::{AgentError, AgentRuntimeConfig};
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use crate::providers::{ChatCompletionRequest, Message, ProviderRuntimeConfig, create_provider};
|
||||
use crate::storage::{MemoryRecord, SessionStore};
|
||||
|
||||
use super::prompt::upsert_managed_agent_memory_summary;
|
||||
@ -17,14 +17,6 @@ const MEMORY_MAINTENANCE_STEP2_SYSTEM_PROMPT: &str =
|
||||
include_str!("memory_maintenance_step2_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
|
||||
/// Bucket that a memory record is assigned to, based on its namespace, when a
/// maintenance plan is built (see `memory_maintenance_category`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum MemoryMaintenanceCategory {
|
||||
// Namespaces "profile", "facts", "identity".
UserFacts,
|
||||
// Namespaces "preferences", "style", "likes".
Preferences,
|
||||
// Namespaces "patterns", "behavior", "habits", "workflow".
BehaviorPatterns,
|
||||
// Any namespace not listed above.
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceCandidate {
|
||||
pub(crate) id: String,
|
||||
@ -35,10 +27,7 @@ pub(crate) struct MemoryMaintenanceCandidate {
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenancePlan {
|
||||
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) candidates: Vec<MemoryMaintenanceCandidate>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
@ -94,6 +83,26 @@ impl MemoryMaintenanceService {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a provider dedicated to memory maintenance, using
/// `memory_maintenance_timeout_secs` as its timeout.
///
/// Every other setting is copied unchanged from `self.provider_config`.
/// Returns whatever `create_provider` returns for the assembled runtime config.
|
||||
fn create_maintenance_provider(
|
||||
&self,
|
||||
) -> Result<Box<dyn crate::providers::LLMProvider>, crate::providers::ProviderError> {
|
||||
let config = &self.provider_config;
|
||||
let runtime_config = ProviderRuntimeConfig {
|
||||
provider_type: config.provider_type.clone(),
|
||||
name: config.name.clone(),
|
||||
base_url: config.base_url.clone(),
|
||||
api_key: config.api_key.clone(),
|
||||
extra_headers: config.extra_headers.clone(),
|
||||
// Intentional substitution: the maintenance timeout is placed in the
// provider's `llm_timeout_secs` slot instead of `config.llm_timeout_secs`.
llm_timeout_secs: config.memory_maintenance_timeout_secs,
|
||||
model_id: config.model_id.clone(),
|
||||
temperature: config.temperature,
|
||||
max_tokens: config.max_tokens,
|
||||
model_extra: config.model_extra.clone(),
|
||||
};
|
||||
create_provider(runtime_config)
|
||||
}
|
||||
|
||||
pub(crate) fn build_plan_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
@ -126,8 +135,7 @@ impl MemoryMaintenanceService {
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
) -> Result<MemoryOrganizationOutput, AgentError> {
|
||||
let runtime_config = AgentRuntimeConfig::from(self.provider_config.clone());
|
||||
let provider = create_provider(runtime_config.provider).map_err(|err| {
|
||||
let provider = self.create_maintenance_provider().map_err(|err| {
|
||||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||
})?;
|
||||
|
||||
@ -258,8 +266,7 @@ impl MemoryMaintenanceService {
|
||||
scope_key: &str,
|
||||
remaining_memories: &[MemoryRecord],
|
||||
) -> Result<String, AgentError> {
|
||||
let runtime_config = AgentRuntimeConfig::from(self.provider_config.clone());
|
||||
let provider = create_provider(runtime_config.provider).map_err(|err| {
|
||||
let provider = self.create_maintenance_provider().map_err(|err| {
|
||||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||
})?;
|
||||
|
||||
@ -478,28 +485,12 @@ pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> Memory
|
||||
content: normalized_content.to_string(),
|
||||
};
|
||||
|
||||
match memory_maintenance_category(&memory.namespace) {
|
||||
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
||||
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
||||
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
||||
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
||||
}
|
||||
plan.candidates.push(candidate);
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
|
||||
/// Maps a memory namespace to its `MemoryMaintenanceCategory`.
///
/// The namespace is trimmed and ASCII-lowercased before matching, so
/// surrounding whitespace and ASCII letter case do not affect the result.
/// Namespaces that match none of the listed names map to `Other`.
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
||||
match namespace.trim().to_ascii_lowercase().as_str() {
|
||||
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
||||
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
||||
"patterns" | "behavior" | "habits" | "workflow" => {
|
||||
MemoryMaintenanceCategory::BehaviorPatterns
|
||||
}
|
||||
// Fallback for every unrecognised namespace.
_ => MemoryMaintenanceCategory::Other,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|
||||
let normalized = error.to_ascii_lowercase();
|
||||
normalized.contains("error sending request for url")
|
||||
@ -577,14 +568,7 @@ pub(crate) fn apply_memory_maintenance_output(
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryOrganizationOutput,
|
||||
) -> Result<(), AgentError> {
|
||||
let all_candidates = plan
|
||||
.user_facts
|
||||
.iter()
|
||||
.chain(plan.preferences.iter())
|
||||
.chain(plan.behavior_patterns.iter())
|
||||
.chain(plan.others.iter())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let all_candidates = plan.candidates.clone();
|
||||
|
||||
let candidates_by_id = all_candidates
|
||||
.iter()
|
||||
@ -671,6 +655,4 @@ fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
}
|
||||
mod tests {}
|
||||
|
||||
@ -19,14 +19,22 @@
|
||||
|
||||
- merges:对象数组。每个对象必须包含 source_ids、namespace、memory_key、content。
|
||||
- source_ids: 字符串数组,要合并的源记忆ID列表
|
||||
- namespace: 目标命名空间
|
||||
- memory_key: 目标记忆键
|
||||
- namespace: 目标命名空间(可以自由决定,不限于固定分类)
|
||||
- memory_key: 目标记忆键(可以自由决定)
|
||||
- content: 合并后的内容
|
||||
- conflicts:对象数组。每个对象必须包含 source_ids、note。
|
||||
- source_ids: 冲突的记忆ID列表
|
||||
- note: 冲突说明
|
||||
- low_value_ids:需要删除的低价值候选记忆 ID 数组
|
||||
|
||||
组织原则(由你自主决定):
|
||||
|
||||
- 根据记忆的语义内容自然分组,不必拘泥于预定义分类
|
||||
- 相似的、互补的记忆可以合并
|
||||
- 过期、重复、过细的记忆可以标记为低价值
|
||||
- namespace 和 memory_key 的命名应当简洁、有意义
|
||||
- 可以自由创建新的 namespace 来组织相关记忆
|
||||
|
||||
额外约束:
|
||||
|
||||
- 只能引用输入里出现过的候选 id。
|
||||
|
||||
@ -51,6 +51,7 @@ mod tests {
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
|
||||
@ -520,6 +520,7 @@ mod tests {
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
@ -786,6 +787,7 @@ mod tests {
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -827,6 +829,7 @@ mod tests {
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -900,6 +903,7 @@ mod tests {
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -982,6 +986,7 @@ mod tests {
|
||||
)]),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -1065,6 +1070,7 @@ mod tests {
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 1,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -1113,7 +1119,7 @@ mod tests {
|
||||
assert!(error.contains("provider=maintenance-provider"));
|
||||
assert!(error.contains("model=maintenance-model"));
|
||||
assert!(error.contains("url=https://example.invalid/v1/chat/completions"));
|
||||
assert!(error.contains("timeout_secs=1"));
|
||||
assert!(error.contains("timeout_secs=600"));
|
||||
assert!(error.contains("error sending request for url"));
|
||||
}
|
||||
|
||||
@ -1147,6 +1153,7 @@ mod tests {
|
||||
)]),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -1211,6 +1218,7 @@ mod tests {
|
||||
)]),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -1284,6 +1292,7 @@ mod tests {
|
||||
)]),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 30,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -1344,6 +1353,7 @@ mod tests {
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
llm_timeout_secs: 1,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
};
|
||||
@ -1749,12 +1759,12 @@ mod tests {
|
||||
];
|
||||
|
||||
let plan = build_memory_maintenance_plan(&memories);
|
||||
assert_eq!(plan.user_facts.len(), 1);
|
||||
assert_eq!(plan.preferences.len(), 1);
|
||||
assert_eq!(plan.behavior_patterns.len(), 1);
|
||||
assert!(plan.others.is_empty());
|
||||
assert_eq!(plan.user_facts[0].content, "用户在做AI产品");
|
||||
assert_eq!(plan.preferences[0].content, "偏好简洁表达");
|
||||
assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
|
||||
// 去重后应该有3条(第1、2条重复)
|
||||
assert_eq!(plan.candidates.len(), 3);
|
||||
// 验证内容包含所有唯一的记忆
|
||||
let contents: Vec<String> = plan.candidates.iter().map(|c| c.content.clone()).collect();
|
||||
assert!(contents.contains(&"用户在做AI产品".to_string()));
|
||||
assert!(contents.contains(&"偏好简洁表达".to_string()));
|
||||
assert!(contents.contains(&"习惯先问方案再要代码".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,6 +35,7 @@ fn load_config() -> Option<LLMProviderConfig> {
|
||||
api_key: openai_api_key,
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
|
||||
@ -37,6 +37,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
api_key: openai_api_key,
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user