feat: 更新会话配置,重命名 session_ttl_hours 为 chat_history_ttl_hours,并调整相关逻辑以支持聊天历史过期管理
This commit is contained in:
parent
0ea98c6e8e
commit
daec690f59
@ -299,8 +299,8 @@ pub struct GatewayConfig {
|
|||||||
pub port: u16,
|
pub port: u16,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub show_tool_results: bool,
|
pub show_tool_results: bool,
|
||||||
#[serde(default, rename = "session_ttl_hours")]
|
#[serde(default, rename = "chat_history_ttl_hours")]
|
||||||
pub session_ttl_hours: Option<u64>,
|
pub chat_history_ttl_hours: Option<u64>,
|
||||||
#[serde(
|
#[serde(
|
||||||
default = "default_agent_prompt_reinject_every",
|
default = "default_agent_prompt_reinject_every",
|
||||||
rename = "agent_prompt_reinject_every"
|
rename = "agent_prompt_reinject_every"
|
||||||
@ -589,7 +589,7 @@ impl Default for GatewayConfig {
|
|||||||
host: default_gateway_host(),
|
host: default_gateway_host(),
|
||||||
port: default_gateway_port(),
|
port: default_gateway_port(),
|
||||||
show_tool_results: false,
|
show_tool_results: false,
|
||||||
session_ttl_hours: None,
|
chat_history_ttl_hours: Some(4),
|
||||||
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -89,6 +89,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
100,
|
100,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -116,12 +117,12 @@ mod tests {
|
|||||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.len(),
|
.len(),
|
||||||
2,
|
3,
|
||||||
);
|
);
|
||||||
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
let history = session.get_history("chat-1").unwrap();
|
let history = session.get_history("chat-1").unwrap();
|
||||||
assert_eq!(history.len(), 1);
|
assert_eq!(history.len(), 2);
|
||||||
assert_eq!(history[0].role, "system");
|
assert_eq!(history[0].role, "system");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,6 +140,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store,
|
store,
|
||||||
100,
|
100,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -155,7 +157,7 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let history = session.get_history("chat-1").unwrap();
|
let history = session.get_history("chat-1").unwrap();
|
||||||
assert_eq!(history.len(), 1);
|
assert_eq!(history.len(), 2);
|
||||||
assert_eq!(history[0].role, "system");
|
assert_eq!(history[0].role, "system");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -60,8 +60,8 @@ impl GatewayState {
|
|||||||
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session TTL from config (default 4 hours)
|
// Chat history TTL from config (default 4 hours)
|
||||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
let chat_history_ttl_hours = config.gateway.chat_history_ttl_hours;
|
||||||
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
||||||
let show_tool_results = config.gateway.show_tool_results;
|
let show_tool_results = config.gateway.show_tool_results;
|
||||||
|
|
||||||
@ -70,7 +70,6 @@ impl GatewayState {
|
|||||||
let bus = channel_manager.bus();
|
let bus = channel_manager.bus();
|
||||||
|
|
||||||
let session_manager = build_session_manager_with_sender(
|
let session_manager = build_session_manager_with_sender(
|
||||||
session_ttl_hours,
|
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
show_tool_results,
|
show_tool_results,
|
||||||
config.time.timezone.clone(),
|
config.time.timezone.clone(),
|
||||||
@ -79,6 +78,7 @@ impl GatewayState {
|
|||||||
skills,
|
skills,
|
||||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||||
std::collections::HashSet::new(),
|
std::collections::HashSet::new(),
|
||||||
|
chat_history_ttl_hours,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
|||||||
@ -23,7 +23,6 @@ use super::session_lifecycle::SessionLifecycleService;
|
|||||||
use super::session_message_service::SessionMessageService;
|
use super::session_message_service::SessionMessageService;
|
||||||
|
|
||||||
pub(crate) fn build_session_manager(
|
pub(crate) fn build_session_manager(
|
||||||
session_ttl_hours: u64,
|
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
show_tool_results: bool,
|
show_tool_results: bool,
|
||||||
default_timezone: String,
|
default_timezone: String,
|
||||||
@ -31,9 +30,9 @@ pub(crate) fn build_session_manager(
|
|||||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Result<SessionManager, AgentError> {
|
) -> Result<SessionManager, AgentError> {
|
||||||
build_session_manager_with_sender(
|
build_session_manager_with_sender(
|
||||||
session_ttl_hours,
|
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
show_tool_results,
|
show_tool_results,
|
||||||
default_timezone,
|
default_timezone,
|
||||||
@ -42,11 +41,11 @@ pub(crate) fn build_session_manager(
|
|||||||
skills,
|
skills,
|
||||||
Arc::new(NoopSessionMessageSender),
|
Arc::new(NoopSessionMessageSender),
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
|
chat_history_ttl_hours,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn build_session_manager_with_sender(
|
pub(crate) fn build_session_manager_with_sender(
|
||||||
session_ttl_hours: u64,
|
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
show_tool_results: bool,
|
show_tool_results: bool,
|
||||||
default_timezone: String,
|
default_timezone: String,
|
||||||
@ -55,6 +54,7 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Result<SessionManager, AgentError> {
|
) -> Result<SessionManager, AgentError> {
|
||||||
let store = Arc::new(
|
let store = Arc::new(
|
||||||
SessionStore::new()
|
SessionStore::new()
|
||||||
@ -97,8 +97,9 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
prompt_injector,
|
prompt_injector,
|
||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
|
chat_history_ttl_hours,
|
||||||
);
|
);
|
||||||
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
|
let lifecycle = SessionLifecycleService::new(session_factory);
|
||||||
let cli_sessions = CliSessionService::new(store.clone());
|
let cli_sessions = CliSessionService::new(store.clone());
|
||||||
let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
|
let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
|
||||||
let scheduled_tasks = ScheduledAgentTaskService::new(
|
let scheduled_tasks = ScheduledAgentTaskService::new(
|
||||||
|
|||||||
@ -100,6 +100,7 @@ impl Session {
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let agent_factory = AgentFactory::new(tools, skills.clone());
|
let agent_factory = AgentFactory::new(tools, skills.clone());
|
||||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||||
@ -114,6 +115,7 @@ impl Session {
|
|||||||
prompt_injector,
|
prompt_injector,
|
||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
|
chat_history_ttl_hours,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@ -127,6 +129,7 @@ impl Session {
|
|||||||
prompt_injector: PromptInjector,
|
prompt_injector: PromptInjector,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
@ -142,6 +145,7 @@ impl Session {
|
|||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
provider_config,
|
provider_config,
|
||||||
|
chat_history_ttl_hours,
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -372,7 +376,6 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
session_ttl_hours: u64,
|
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
show_tool_results: bool,
|
show_tool_results: bool,
|
||||||
default_timezone: String,
|
default_timezone: String,
|
||||||
@ -380,9 +383,9 @@ impl SessionManager {
|
|||||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: std::collections::HashSet<String>,
|
disabled_tools: std::collections::HashSet<String>,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
super::runtime::build_session_manager(
|
super::runtime::build_session_manager(
|
||||||
session_ttl_hours,
|
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
show_tool_results,
|
show_tool_results,
|
||||||
default_timezone,
|
default_timezone,
|
||||||
@ -390,6 +393,7 @@ impl SessionManager {
|
|||||||
provider_configs,
|
provider_configs,
|
||||||
skills,
|
skills,
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
|
chat_history_ttl_hours,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -553,6 +557,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store,
|
store,
|
||||||
100,
|
100,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -599,6 +604,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
100,
|
100,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -785,7 +791,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -793,6 +798,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -831,7 +837,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -842,6 +847,7 @@ mod tests {
|
|||||||
]),
|
]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -899,7 +905,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -907,6 +912,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -981,7 +987,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -989,6 +994,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1068,7 +1074,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -1076,6 +1081,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1145,7 +1151,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -1153,6 +1158,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1209,7 +1215,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -1217,6 +1222,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1283,7 +1289,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -1291,6 +1296,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1341,7 +1347,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
4,
|
|
||||||
100,
|
100,
|
||||||
false,
|
false,
|
||||||
"Asia/Shanghai".to_string(),
|
"Asia/Shanghai".to_string(),
|
||||||
@ -1349,6 +1354,7 @@ mod tests {
|
|||||||
HashMap::from([("default".to_string(), provider_config)]),
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
Arc::new(SkillRuntime::default()),
|
Arc::new(SkillRuntime::default()),
|
||||||
HashSet::new(),
|
HashSet::new(),
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1536,6 +1542,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
100,
|
100,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1544,7 +1551,7 @@ mod tests {
|
|||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
|
||||||
let history = session.get_history("chat-1").unwrap();
|
let history = session.get_history("chat-1").unwrap();
|
||||||
assert_eq!(history.len(), 1);
|
assert_eq!(history.len(), 2);
|
||||||
assert_eq!(history[0].role, "system");
|
assert_eq!(history[0].role, "system");
|
||||||
assert!(history[0].content.contains("PicoBot 代理配置"));
|
assert!(history[0].content.contains("PicoBot 代理配置"));
|
||||||
}
|
}
|
||||||
@ -1575,6 +1582,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
100,
|
100,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1597,7 +1605,7 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.filter(|message| message.role == "system")
|
.filter(|message| message.role == "system")
|
||||||
.count();
|
.count();
|
||||||
assert_eq!(system_messages, 2);
|
assert_eq!(system_messages, 3);
|
||||||
|
|
||||||
let stored = store
|
let stored = store
|
||||||
.get_session(&session.persistent_session_id("chat-1"))
|
.get_session(&session.persistent_session_id("chat-1"))
|
||||||
@ -1613,7 +1621,7 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.filter(|message| message.role == "system")
|
.filter(|message| message.role == "system")
|
||||||
.count();
|
.count();
|
||||||
assert_eq!(system_messages, 2);
|
assert_eq!(system_messages, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -1642,6 +1650,7 @@ mod tests {
|
|||||||
skills,
|
skills,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
0,
|
0,
|
||||||
|
Some(4),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1664,7 +1673,7 @@ mod tests {
|
|||||||
.iter()
|
.iter()
|
||||||
.filter(|message| message.role == "system")
|
.filter(|message| message.role == "system")
|
||||||
.count();
|
.count();
|
||||||
assert_eq!(system_messages, 1);
|
assert_eq!(system_messages, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -20,6 +20,7 @@ pub(crate) struct SessionFactory {
|
|||||||
prompt_injector: PromptInjector,
|
prompt_injector: PromptInjector,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionFactory {
|
impl SessionFactory {
|
||||||
@ -30,6 +31,7 @@ impl SessionFactory {
|
|||||||
prompt_injector: PromptInjector,
|
prompt_injector: PromptInjector,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
provider_config,
|
provider_config,
|
||||||
@ -38,6 +40,7 @@ impl SessionFactory {
|
|||||||
prompt_injector,
|
prompt_injector,
|
||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
|
chat_history_ttl_hours,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,6 +58,7 @@ impl SessionFactory {
|
|||||||
self.prompt_injector.clone(),
|
self.prompt_injector.clone(),
|
||||||
self.conversations.clone(),
|
self.conversations.clone(),
|
||||||
self.skill_events.clone(),
|
self.skill_events.clone(),
|
||||||
|
self.chat_history_ttl_hours,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,13 @@ fn preview_text(content: &str, max_chars: usize) -> String {
|
|||||||
preview.replace('\n', "\\n")
|
preview.replace('\n', "\\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn current_timestamp() -> i64 {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as i64
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) struct SessionHistory {
|
pub(crate) struct SessionHistory {
|
||||||
channel_name: String,
|
channel_name: String,
|
||||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||||
@ -27,6 +34,7 @@ pub(crate) struct SessionHistory {
|
|||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionHistory {
|
impl SessionHistory {
|
||||||
@ -36,6 +44,7 @@ impl SessionHistory {
|
|||||||
conversations: Arc<dyn ConversationRepository>,
|
conversations: Arc<dyn ConversationRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
|
chat_history_ttl_hours: Option<u64>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
channel_name: channel_name.into(),
|
channel_name: channel_name.into(),
|
||||||
@ -45,6 +54,7 @@ impl SessionHistory {
|
|||||||
conversations,
|
conversations,
|
||||||
skill_events,
|
skill_events,
|
||||||
provider_config,
|
provider_config,
|
||||||
|
chat_history_ttl_hours,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,6 +72,29 @@ impl SessionHistory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||||
|
// 获取 session 记录(用于检查最后活跃时间)
|
||||||
|
let session_record = self.ensure_persistent_session(chat_id)?;
|
||||||
|
|
||||||
|
// 检查是否超时
|
||||||
|
if let Some(ttl_hours) = self.chat_history_ttl_hours {
|
||||||
|
if ttl_hours > 0 {
|
||||||
|
let now = current_timestamp();
|
||||||
|
let elapsed_hours = (now - session_record.last_active_at) / 3600;
|
||||||
|
if elapsed_hours >= ttl_hours as i64 {
|
||||||
|
tracing::info!(
|
||||||
|
channel = %self.channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
elapsed_hours = elapsed_hours,
|
||||||
|
ttl_hours = ttl_hours,
|
||||||
|
"Chat history expired, resetting context"
|
||||||
|
);
|
||||||
|
// 重置会话上下文(清空内存历史,但保留数据库记录)
|
||||||
|
self.reset_chat_context(chat_id)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有逻辑
|
||||||
if self.chat_histories.contains_key(chat_id) {
|
if self.chat_histories.contains_key(chat_id) {
|
||||||
return self.ensure_initial_agent_prompt(chat_id);
|
return self.ensure_initial_agent_prompt(chat_id);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,9 +14,9 @@ pub(crate) struct SessionLifecycleService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SessionLifecycleService {
|
impl SessionLifecycleService {
|
||||||
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
|
pub(crate) fn new(session_factory: SessionFactory) -> Self {
|
||||||
Self {
|
Self {
|
||||||
session_pool: SessionPool::new(session_ttl_hours, session_factory),
|
session_pool: SessionPool::new(session_factory),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::Instant;
|
||||||
|
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
|
|
||||||
@ -19,16 +19,14 @@ pub(crate) struct SessionPool {
|
|||||||
struct SessionPoolInner {
|
struct SessionPoolInner {
|
||||||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||||||
session_timestamps: HashMap<String, Instant>,
|
session_timestamps: HashMap<String, Instant>,
|
||||||
session_ttl: Duration,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SessionPool {
|
impl SessionPool {
|
||||||
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
|
pub(crate) fn new(session_factory: SessionFactory) -> Self {
|
||||||
Self {
|
Self {
|
||||||
inner: Arc::new(Mutex::new(SessionPoolInner {
|
inner: Arc::new(Mutex::new(SessionPoolInner {
|
||||||
sessions: HashMap::new(),
|
sessions: HashMap::new(),
|
||||||
session_timestamps: HashMap::new(),
|
session_timestamps: HashMap::new(),
|
||||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
|
||||||
})),
|
})),
|
||||||
session_factory,
|
session_factory,
|
||||||
}
|
}
|
||||||
@ -37,24 +35,12 @@ impl SessionPool {
|
|||||||
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||||
let mut inner = self.inner.lock().await;
|
let mut inner = self.inner.lock().await;
|
||||||
|
|
||||||
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
|
// 简化:只检查 session 是否存在,不做超时判断
|
||||||
{
|
if inner.sessions.contains_key(channel_name) {
|
||||||
let elapsed = last_active.elapsed();
|
return Ok(());
|
||||||
if elapsed > inner.session_ttl {
|
|
||||||
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
tracing::debug!(channel = %channel_name, "Creating new session");
|
|
||||||
true
|
|
||||||
};
|
|
||||||
|
|
||||||
if should_recreate {
|
|
||||||
inner.sessions.remove(channel_name);
|
|
||||||
|
|
||||||
|
// Session 不存在则创建
|
||||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||||
let session = self
|
let session = self
|
||||||
.session_factory
|
.session_factory
|
||||||
@ -67,7 +53,6 @@ impl SessionPool {
|
|||||||
inner
|
inner
|
||||||
.session_timestamps
|
.session_timestamps
|
||||||
.insert(channel_name.to_string(), Instant::now());
|
.insert(channel_name.to_string(), Instant::now());
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -85,25 +70,7 @@ impl SessionPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
|
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
|
||||||
let mut inner = self.inner.lock().await;
|
// Session 级别不再自动清理,返回 0
|
||||||
let now = Instant::now();
|
0
|
||||||
let expired_channels: Vec<String> = inner
|
|
||||||
.session_timestamps
|
|
||||||
.iter()
|
|
||||||
.filter_map(|(channel_name, last_active)| {
|
|
||||||
if now.duration_since(*last_active) > inner.session_ttl {
|
|
||||||
Some(channel_name.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
for channel_name in &expired_channels {
|
|
||||||
inner.sessions.remove(channel_name);
|
|
||||||
inner.session_timestamps.remove(channel_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
expired_channels.len()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user