feat: 更新会话配置,重命名 session_ttl_hours 为 chat_history_ttl_hours,并调整相关逻辑以支持聊天历史过期管理

This commit is contained in:
ooodc 2026-05-10 19:29:55 +08:00
parent 0ea98c6e8e
commit daec690f59
9 changed files with 100 additions and 84 deletions

View File

@ -299,8 +299,8 @@ pub struct GatewayConfig {
pub port: u16,
#[serde(default)]
pub show_tool_results: bool,
#[serde(default, rename = "session_ttl_hours")]
pub session_ttl_hours: Option<u64>,
#[serde(default, rename = "chat_history_ttl_hours")]
pub chat_history_ttl_hours: Option<u64>,
#[serde(
default = "default_agent_prompt_reinject_every",
rename = "agent_prompt_reinject_every"
@ -589,7 +589,7 @@ impl Default for GatewayConfig {
host: default_gateway_host(),
port: default_gateway_port(),
show_tool_results: false,
session_ttl_hours: None,
chat_history_ttl_hours: Some(4),
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
}
}

View File

@ -89,6 +89,7 @@ mod tests {
skills,
store.clone(),
100,
Some(4),
)
.await
.unwrap();
@ -116,12 +117,12 @@ mod tests {
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
2,
3,
);
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "system");
}
@ -139,6 +140,7 @@ mod tests {
skills,
store,
100,
Some(4),
)
.await
.unwrap();
@ -155,7 +157,7 @@ mod tests {
.unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "system");
}
}

View File

@ -60,8 +60,8 @@ impl GatewayState {
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
}
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
// Chat history TTL from config (default 4 hours)
let chat_history_ttl_hours = config.gateway.chat_history_ttl_hours;
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
let show_tool_results = config.gateway.show_tool_results;
@ -70,7 +70,6 @@ impl GatewayState {
let bus = channel_manager.bus();
let session_manager = build_session_manager_with_sender(
session_ttl_hours,
agent_prompt_reinject_every,
show_tool_results,
config.time.timezone.clone(),
@ -79,6 +78,7 @@ impl GatewayState {
skills,
Arc::new(BusSessionMessageSender::new(bus.clone())),
std::collections::HashSet::new(),
chat_history_ttl_hours,
)?;
Ok(Self {

View File

@ -23,7 +23,6 @@ use super::session_lifecycle::SessionLifecycleService;
use super::session_message_service::SessionMessageService;
pub(crate) fn build_session_manager(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
@ -31,9 +30,9 @@ pub(crate) fn build_session_manager(
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
disabled_tools: HashSet<String>,
chat_history_ttl_hours: Option<u64>,
) -> Result<SessionManager, AgentError> {
build_session_manager_with_sender(
session_ttl_hours,
agent_prompt_reinject_every,
show_tool_results,
default_timezone,
@ -42,11 +41,11 @@ pub(crate) fn build_session_manager(
skills,
Arc::new(NoopSessionMessageSender),
disabled_tools,
chat_history_ttl_hours,
)
}
pub(crate) fn build_session_manager_with_sender(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
@ -55,6 +54,7 @@ pub(crate) fn build_session_manager_with_sender(
skills: Arc<SkillRuntime>,
session_message_sender: Arc<dyn SessionMessageSender>,
disabled_tools: HashSet<String>,
chat_history_ttl_hours: Option<u64>,
) -> Result<SessionManager, AgentError> {
let store = Arc::new(
SessionStore::new()
@ -97,8 +97,9 @@ pub(crate) fn build_session_manager_with_sender(
prompt_injector,
conversations,
skill_events,
chat_history_ttl_hours,
);
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
let lifecycle = SessionLifecycleService::new(session_factory);
let cli_sessions = CliSessionService::new(store.clone());
let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
let scheduled_tasks = ScheduledAgentTaskService::new(

View File

@ -100,6 +100,7 @@ impl Session {
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
chat_history_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> {
let agent_factory = AgentFactory::new(tools, skills.clone());
let conversations: Arc<dyn ConversationRepository> = store.clone();
@ -114,6 +115,7 @@ impl Session {
prompt_injector,
conversations,
skill_events,
chat_history_ttl_hours,
)
.await
}
@ -127,6 +129,7 @@ impl Session {
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> {
Ok(Self {
id: Uuid::new_v4(),
@ -142,6 +145,7 @@ impl Session {
conversations,
skill_events,
provider_config,
chat_history_ttl_hours,
),
})
}
@ -372,7 +376,6 @@ impl SessionManager {
}
pub fn new(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
@ -380,9 +383,9 @@ impl SessionManager {
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
disabled_tools: std::collections::HashSet<String>,
chat_history_ttl_hours: Option<u64>,
) -> Result<Self, AgentError> {
super::runtime::build_session_manager(
session_ttl_hours,
agent_prompt_reinject_every,
show_tool_results,
default_timezone,
@ -390,6 +393,7 @@ impl SessionManager {
provider_configs,
skills,
disabled_tools,
chat_history_ttl_hours,
)
}
@ -553,6 +557,7 @@ mod tests {
skills,
store,
100,
Some(4),
)
.await
.unwrap();
@ -599,6 +604,7 @@ mod tests {
skills,
store.clone(),
100,
Some(4),
)
.await
.unwrap();
@ -785,7 +791,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -793,6 +798,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -831,7 +837,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -842,6 +847,7 @@ mod tests {
]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -899,7 +905,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -907,6 +912,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -981,7 +987,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -989,6 +994,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -1068,7 +1074,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -1076,6 +1081,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -1145,7 +1151,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -1153,6 +1158,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -1209,7 +1215,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -1217,6 +1222,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -1283,7 +1289,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -1291,6 +1296,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -1341,7 +1347,6 @@ mod tests {
};
let session_manager = SessionManager::new(
4,
100,
false,
"Asia/Shanghai".to_string(),
@ -1349,6 +1354,7 @@ mod tests {
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
HashSet::new(),
Some(4),
)
.unwrap();
@ -1536,6 +1542,7 @@ mod tests {
skills,
store.clone(),
100,
Some(4),
)
.await
.unwrap();
@ -1544,7 +1551,7 @@ mod tests {
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history.len(), 2);
assert_eq!(history[0].role, "system");
assert!(history[0].content.contains("PicoBot 代理配置"));
}
@ -1575,6 +1582,7 @@ mod tests {
skills,
store.clone(),
100,
Some(4),
)
.await
.unwrap();
@ -1597,7 +1605,7 @@ mod tests {
.iter()
.filter(|message| message.role == "system")
.count();
assert_eq!(system_messages, 2);
assert_eq!(system_messages, 3);
let stored = store
.get_session(&session.persistent_session_id("chat-1"))
@ -1613,7 +1621,7 @@ mod tests {
.iter()
.filter(|message| message.role == "system")
.count();
assert_eq!(system_messages, 2);
assert_eq!(system_messages, 3);
}
#[tokio::test]
@ -1642,6 +1650,7 @@ mod tests {
skills,
store.clone(),
0,
Some(4),
)
.await
.unwrap();
@ -1664,7 +1673,7 @@ mod tests {
.iter()
.filter(|message| message.role == "system")
.count();
assert_eq!(system_messages, 1);
assert_eq!(system_messages, 2);
}
#[test]

View File

@ -20,6 +20,7 @@ pub(crate) struct SessionFactory {
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>,
}
impl SessionFactory {
@ -30,6 +31,7 @@ impl SessionFactory {
prompt_injector: PromptInjector,
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
chat_history_ttl_hours: Option<u64>,
) -> Self {
Self {
provider_config,
@ -38,6 +40,7 @@ impl SessionFactory {
prompt_injector,
conversations,
skill_events,
chat_history_ttl_hours,
}
}
@ -55,6 +58,7 @@ impl SessionFactory {
self.prompt_injector.clone(),
self.conversations.clone(),
self.skill_events.clone(),
self.chat_history_ttl_hours,
)
.await
}

View File

@ -19,6 +19,13 @@ fn preview_text(content: &str, max_chars: usize) -> String {
preview.replace('\n', "\\n")
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64
}
pub(crate) struct SessionHistory {
channel_name: String,
chat_histories: HashMap<String, Vec<ChatMessage>>,
@ -27,6 +34,7 @@ pub(crate) struct SessionHistory {
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
provider_config: LLMProviderConfig,
chat_history_ttl_hours: Option<u64>,
}
impl SessionHistory {
@ -36,6 +44,7 @@ impl SessionHistory {
conversations: Arc<dyn ConversationRepository>,
skill_events: Arc<dyn SkillEventRepository>,
provider_config: LLMProviderConfig,
chat_history_ttl_hours: Option<u64>,
) -> Self {
Self {
channel_name: channel_name.into(),
@ -45,6 +54,7 @@ impl SessionHistory {
conversations,
skill_events,
provider_config,
chat_history_ttl_hours,
}
}
@ -62,6 +72,29 @@ impl SessionHistory {
}
pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
// 获取 session 记录(用于检查最后活跃时间)
let session_record = self.ensure_persistent_session(chat_id)?;
// 检查是否超时
if let Some(ttl_hours) = self.chat_history_ttl_hours {
if ttl_hours > 0 {
let now = current_timestamp();
let elapsed_hours = (now - session_record.last_active_at) / 3600;
if elapsed_hours >= ttl_hours as i64 {
tracing::info!(
channel = %self.channel_name,
chat_id = %chat_id,
elapsed_hours = elapsed_hours,
ttl_hours = ttl_hours,
"Chat history expired, resetting context"
);
// 重置会话上下文(清空内存历史,但保留数据库记录)
self.reset_chat_context(chat_id)?;
}
}
}
// 原有逻辑
if self.chat_histories.contains_key(chat_id) {
return self.ensure_initial_agent_prompt(chat_id);
}

View File

@ -14,9 +14,9 @@ pub(crate) struct SessionLifecycleService {
}
impl SessionLifecycleService {
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
pub(crate) fn new(session_factory: SessionFactory) -> Self {
Self {
session_pool: SessionPool::new(session_ttl_hours, session_factory),
session_pool: SessionPool::new(session_factory),
}
}

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Instant;
use tokio::sync::{Mutex, mpsc};
@ -19,16 +19,14 @@ pub(crate) struct SessionPool {
struct SessionPoolInner {
sessions: HashMap<String, Arc<Mutex<Session>>>,
session_timestamps: HashMap<String, Instant>,
session_ttl: Duration,
}
impl SessionPool {
pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self {
pub(crate) fn new(session_factory: SessionFactory) -> Self {
Self {
inner: Arc::new(Mutex::new(SessionPoolInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
session_factory,
}
@ -37,24 +35,12 @@ impl SessionPool {
pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
let mut inner = self.inner.lock().await;
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name)
{
let elapsed = last_active.elapsed();
if elapsed > inner.session_ttl {
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
true
} else {
false
// 简化:只检查 session 是否存在,不做超时判断
if inner.sessions.contains_key(channel_name) {
return Ok(());
}
} else {
#[cfg(debug_assertions)]
tracing::debug!(channel = %channel_name, "Creating new session");
true
};
if should_recreate {
inner.sessions.remove(channel_name);
// Session 不存在则创建
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = self
.session_factory
@ -67,7 +53,6 @@ impl SessionPool {
inner
.session_timestamps
.insert(channel_name.to_string(), Instant::now());
}
Ok(())
}
@ -85,25 +70,7 @@ impl SessionPool {
}
pub(crate) async fn cleanup_expired_sessions(&self) -> usize {
let mut inner = self.inner.lock().await;
let now = Instant::now();
let expired_channels: Vec<String> = inner
.session_timestamps
.iter()
.filter_map(|(channel_name, last_active)| {
if now.duration_since(*last_active) > inner.session_ttl {
Some(channel_name.clone())
} else {
None
}
})
.collect();
for channel_name in &expired_channels {
inner.sessions.remove(channel_name);
inner.session_timestamps.remove(channel_name);
}
expired_channels.len()
// Session 级别不再自动清理,返回 0
0
}
}