diff --git a/src/config/mod.rs b/src/config/mod.rs index d2ecf7e..f08a129 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -299,8 +299,8 @@ pub struct GatewayConfig { pub port: u16, #[serde(default)] pub show_tool_results: bool, - #[serde(default, rename = "session_ttl_hours")] - pub session_ttl_hours: Option, + #[serde(default, rename = "chat_history_ttl_hours")] + pub chat_history_ttl_hours: Option, #[serde( default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every" @@ -589,7 +589,7 @@ impl Default for GatewayConfig { host: default_gateway_host(), port: default_gateway_port(), show_tool_results: false, - session_ttl_hours: None, + chat_history_ttl_hours: Some(4), agent_prompt_reinject_every: default_agent_prompt_reinject_every(), } } diff --git a/src/gateway/command.rs b/src/gateway/command.rs index 49e226b..9fad8e9 100644 --- a/src/gateway/command.rs +++ b/src/gateway/command.rs @@ -89,6 +89,7 @@ mod tests { skills, store.clone(), 100, + Some(4), ) .await .unwrap(); @@ -116,12 +117,12 @@ mod tests { .load_all_messages(&session.persistent_session_id("chat-1")) .unwrap() .len(), - 2, + 3, ); session.ensure_chat_loaded("chat-1").unwrap(); let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 1); + assert_eq!(history.len(), 2); assert_eq!(history[0].role, "system"); } @@ -139,6 +140,7 @@ mod tests { skills, store, 100, + Some(4), ) .await .unwrap(); @@ -155,7 +157,7 @@ mod tests { .unwrap(); let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 1); + assert_eq!(history.len(), 2); assert_eq!(history[0].role, "system"); } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 951251f..aace5cb 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -60,8 +60,8 @@ impl GatewayState { provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?); } - // Session TTL from config (default 4 hours) - let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); + // Chat history TTL from config (default 4 hours) + let chat_history_ttl_hours = config.gateway.chat_history_ttl_hours; let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every; let show_tool_results = config.gateway.show_tool_results; @@ -70,7 +70,6 @@ impl GatewayState { let bus = channel_manager.bus(); let session_manager = build_session_manager_with_sender( - session_ttl_hours, agent_prompt_reinject_every, show_tool_results, config.time.timezone.clone(), @@ -79,6 +78,7 @@ impl GatewayState { skills, Arc::new(BusSessionMessageSender::new(bus.clone())), std::collections::HashSet::new(), + chat_history_ttl_hours, )?; Ok(Self { diff --git a/src/gateway/runtime.rs b/src/gateway/runtime.rs index 6bae7c4..3652ba3 100644 --- a/src/gateway/runtime.rs +++ b/src/gateway/runtime.rs @@ -23,7 +23,6 @@ use super::session_lifecycle::SessionLifecycleService; use super::session_message_service::SessionMessageService; pub(crate) fn build_session_manager( - session_ttl_hours: u64, agent_prompt_reinject_every: u64, show_tool_results: bool, default_timezone: String, @@ -31,9 +30,9 @@ pub(crate) fn build_session_manager( provider_configs: HashMap, skills: Arc, disabled_tools: HashSet, + chat_history_ttl_hours: Option, ) -> Result { build_session_manager_with_sender( - session_ttl_hours, agent_prompt_reinject_every, show_tool_results, default_timezone, @@ -42,11 +41,11 @@ pub(crate) fn build_session_manager( skills, Arc::new(NoopSessionMessageSender), disabled_tools, + chat_history_ttl_hours, ) } pub(crate) fn build_session_manager_with_sender( - session_ttl_hours: u64, agent_prompt_reinject_every: u64, show_tool_results: bool, default_timezone: String, @@ -55,6 +54,7 @@ pub(crate) fn build_session_manager_with_sender( skills: Arc, session_message_sender: Arc, disabled_tools: HashSet, + chat_history_ttl_hours: Option, ) -> Result { let store = Arc::new( SessionStore::new() @@ -97,8 +97,9 @@ pub(crate) fn build_session_manager_with_sender( prompt_injector, conversations, skill_events, + chat_history_ttl_hours, ); - let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory); + let lifecycle = SessionLifecycleService::new(session_factory); let cli_sessions = CliSessionService::new(store.clone()); let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results); let scheduled_tasks = ScheduledAgentTaskService::new( diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 7b78178..ac3bb0e 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -100,6 +100,7 @@ impl Session { skills: Arc, store: Arc, agent_prompt_reinject_every: u64, + chat_history_ttl_hours: Option, ) -> Result { let agent_factory = AgentFactory::new(tools, skills.clone()); let conversations: Arc = store.clone(); @@ -114,6 +115,7 @@ impl Session { prompt_injector, conversations, skill_events, + chat_history_ttl_hours, ) .await } @@ -127,6 +129,7 @@ impl Session { prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, + chat_history_ttl_hours: Option, ) -> Result { Ok(Self { id: Uuid::new_v4(), @@ -142,6 +145,7 @@ impl Session { conversations, skill_events, provider_config, + chat_history_ttl_hours, ), }) } @@ -372,7 +376,6 @@ impl SessionManager { } pub fn new( - session_ttl_hours: u64, agent_prompt_reinject_every: u64, show_tool_results: bool, default_timezone: String, @@ -380,9 +383,9 @@ impl SessionManager { provider_configs: HashMap, skills: Arc, disabled_tools: std::collections::HashSet, + chat_history_ttl_hours: Option, ) -> Result { super::runtime::build_session_manager( - session_ttl_hours, agent_prompt_reinject_every, show_tool_results, default_timezone, @@ -390,6 +393,7 @@ impl SessionManager { provider_configs, skills, disabled_tools, + chat_history_ttl_hours, ) } @@ -553,6 +557,7 @@ mod tests { skills, store, 100, + Some(4), ) .await .unwrap(); @@ -599,6 +604,7 @@ mod tests { skills, store.clone(), 100, + Some(4), ) .await .unwrap(); @@ -785,7 +791,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -793,6 +798,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -831,7 +837,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -842,6 +847,7 @@ mod tests { ]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -899,7 +905,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -907,6 +912,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -981,7 +987,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -989,6 +994,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -1068,7 +1074,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -1076,6 +1081,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -1145,7 +1151,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -1153,6 +1158,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -1209,7 +1215,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -1217,6 +1222,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -1283,7 +1289,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -1291,6 +1296,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -1341,7 +1347,6 @@ mod tests { }; let session_manager = SessionManager::new( - 4, 100, false, "Asia/Shanghai".to_string(), @@ -1349,6 +1354,7 @@ mod tests { HashMap::from([("default".to_string(), provider_config)]), Arc::new(SkillRuntime::default()), HashSet::new(), + Some(4), ) .unwrap(); @@ -1536,6 +1542,7 @@ mod tests { skills, store.clone(), 100, + Some(4), ) .await .unwrap(); @@ -1544,7 +1551,7 @@ mod tests { session.ensure_chat_loaded("chat-1").unwrap(); let history = session.get_history("chat-1").unwrap(); - assert_eq!(history.len(), 1); + assert_eq!(history.len(), 2); assert_eq!(history[0].role, "system"); assert!(history[0].content.contains("PicoBot 代理配置")); } @@ -1575,6 +1582,7 @@ mod tests { skills, store.clone(), 100, + Some(4), ) .await .unwrap(); @@ -1597,7 +1605,7 @@ mod tests { .iter() .filter(|message| message.role == "system") .count(); - assert_eq!(system_messages, 2); + assert_eq!(system_messages, 3); let stored = store .get_session(&session.persistent_session_id("chat-1")) @@ -1613,7 +1621,7 @@ mod tests { .iter() .filter(|message| message.role == "system") .count(); - assert_eq!(system_messages, 2); + assert_eq!(system_messages, 3); } #[tokio::test] @@ -1642,6 +1650,7 @@ mod tests { skills, store.clone(), 0, + Some(4), ) .await .unwrap(); @@ -1664,7 +1673,7 @@ mod tests { .iter() .filter(|message| message.role == "system") .count(); - assert_eq!(system_messages, 1); + assert_eq!(system_messages, 2); } #[test] diff --git a/src/gateway/session_factory.rs b/src/gateway/session_factory.rs index 8484c4f..741e98d 100644 --- a/src/gateway/session_factory.rs +++ b/src/gateway/session_factory.rs @@ -20,6 +20,7 @@ pub(crate) struct SessionFactory { prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, + chat_history_ttl_hours: Option, } impl SessionFactory { @@ -30,6 +31,7 @@ impl SessionFactory { prompt_injector: PromptInjector, conversations: Arc, skill_events: Arc, + chat_history_ttl_hours: Option, ) -> Self { Self { provider_config, @@ -38,6 +40,7 @@ impl SessionFactory { prompt_injector, conversations, skill_events, + chat_history_ttl_hours, } } @@ -55,6 +58,7 @@ impl SessionFactory { self.prompt_injector.clone(), self.conversations.clone(), self.skill_events.clone(), + self.chat_history_ttl_hours, ) .await } diff --git a/src/gateway/session_history.rs b/src/gateway/session_history.rs index a19923f..90247c8 100644 --- a/src/gateway/session_history.rs +++ b/src/gateway/session_history.rs @@ -19,6 +19,13 @@ fn preview_text(content: &str, max_chars: usize) -> String { preview.replace('\n', "\\n") } +fn current_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64 +} + pub(crate) struct SessionHistory { channel_name: String, chat_histories: HashMap>, @@ -27,6 +34,7 @@ pub(crate) struct SessionHistory { conversations: Arc, skill_events: Arc, provider_config: LLMProviderConfig, + chat_history_ttl_hours: Option, } impl SessionHistory { @@ -36,6 +44,7 @@ impl SessionHistory { conversations: Arc, skill_events: Arc, provider_config: LLMProviderConfig, + chat_history_ttl_hours: Option, ) -> Self { Self { channel_name: channel_name.into(), @@ -45,6 +54,7 @@ impl SessionHistory { conversations, skill_events, provider_config, + chat_history_ttl_hours, } } @@ -62,6 +72,29 @@ impl SessionHistory { } pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> { + // 获取 session 记录(用于检查最后活跃时间) + let session_record = self.ensure_persistent_session(chat_id)?; + + // 检查是否超时 + if let Some(ttl_hours) = self.chat_history_ttl_hours { + if ttl_hours > 0 { + let now = current_timestamp(); + let elapsed_hours = (now - session_record.last_active_at) / 3600; + if elapsed_hours >= ttl_hours as i64 { + tracing::info!( + channel = %self.channel_name, + chat_id = %chat_id, + elapsed_hours = elapsed_hours, + ttl_hours = ttl_hours, + "Chat history expired, resetting context" + ); + // 重置会话上下文(清空内存历史,但保留数据库记录) + self.reset_chat_context(chat_id)?; + } + } + } + + // 原有逻辑 if self.chat_histories.contains_key(chat_id) { return self.ensure_initial_agent_prompt(chat_id); } diff --git a/src/gateway/session_lifecycle.rs b/src/gateway/session_lifecycle.rs index e81af98..e433269 100644 --- a/src/gateway/session_lifecycle.rs +++ b/src/gateway/session_lifecycle.rs @@ -14,9 +14,9 @@ pub(crate) struct SessionLifecycleService { } impl SessionLifecycleService { - pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self { + pub(crate) fn new(session_factory: SessionFactory) -> Self { Self { - session_pool: SessionPool::new(session_ttl_hours, session_factory), + session_pool: SessionPool::new(session_factory), } } diff --git a/src/gateway/session_pool.rs b/src/gateway/session_pool.rs index 887ea95..9aa156c 100644 --- a/src/gateway/session_pool.rs +++ b/src/gateway/session_pool.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use tokio::sync::{Mutex, mpsc}; @@ -19,16 +19,14 @@ pub(crate) struct SessionPool { struct SessionPoolInner { sessions: HashMap>>, session_timestamps: HashMap, - session_ttl: Duration, } impl SessionPool { - pub(crate) fn new(session_ttl_hours: u64, session_factory: SessionFactory) -> Self { + pub(crate) fn new(session_factory: SessionFactory) -> Self { Self { inner: Arc::new(Mutex::new(SessionPoolInner { sessions: HashMap::new(), session_timestamps: HashMap::new(), - session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), session_factory, } @@ -37,38 +35,25 @@ impl SessionPool { pub(crate) async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { let mut inner = self.inner.lock().await; - let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) - { - let elapsed = last_active.elapsed(); - if elapsed > inner.session_ttl { - tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating"); - true - } else { - false - } - } else { - #[cfg(debug_assertions)] - tracing::debug!(channel = %channel_name, "Creating new session"); - true - }; - - if should_recreate { - inner.sessions.remove(channel_name); - - let (user_tx, _rx) = mpsc::channel::(100); - let session = self - .session_factory - .create(channel_name.to_string(), user_tx) - .await?; - - inner - .sessions - .insert(channel_name.to_string(), Arc::new(Mutex::new(session))); - inner - .session_timestamps - .insert(channel_name.to_string(), Instant::now()); + // 简化:只检查 session 是否存在,不做超时判断 + if inner.sessions.contains_key(channel_name) { + return Ok(()); } + // Session 不存在则创建 + let (user_tx, _rx) = mpsc::channel::(100); + let session = self + .session_factory + .create(channel_name.to_string(), user_tx) + .await?; + + inner + .sessions + .insert(channel_name.to_string(), Arc::new(Mutex::new(session))); + inner + .session_timestamps + .insert(channel_name.to_string(), Instant::now()); + Ok(()) } @@ -85,25 +70,7 @@ impl SessionPool { } pub(crate) async fn cleanup_expired_sessions(&self) -> usize { - let mut inner = self.inner.lock().await; - let now = Instant::now(); - let expired_channels: Vec = inner - .session_timestamps - .iter() - .filter_map(|(channel_name, last_active)| { - if now.duration_since(*last_active) > inner.session_ttl { - Some(channel_name.clone()) - } else { - None - } - }) - .collect(); - - for channel_name in &expired_channels { - inner.sessions.remove(channel_name); - inner.session_timestamps.remove(channel_name); - } - - expired_channels.len() + // Session 级别不再自动清理,返回 0 + 0 } }