From 11a8e93b77c4f9c3d10f16d7e44958f3617cea49 Mon Sep 17 00:00:00 2001 From: xiaoxixi Date: Sun, 10 May 2026 15:12:30 +0800 Subject: [PATCH] feat(chat-manager): enhance message retrieval with pagination and time range filtering --- src/agent/system_prompt.rs | 9 +- src/storage/mod.rs | 71 ++++++++++++++++ src/tools/chat_manager.rs | 165 +++++++++++++++++++++++++++++++++++-- 3 files changed, 235 insertions(+), 10 deletions(-) diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index a23611e..116850a 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -261,7 +261,14 @@ impl PromptSection for CrossChannelSection { 管理会话和查看消息。参数: - action = "list_sessions" — 列出最近活跃的会话 - action = "list_channels" — 列出所有可用渠道 -- action = "list_messages" — 查看指定 session 的最近消息,需提供 session_id 和 count"#, +- action = "list_messages" — 查看指定 session 的历史消息,支持以下参数: + - session_id (必填): 会话 ID + - count (可选): 返回数量,默认 20,最大 100 + - offset (可选): 跳过前 N 条,用于翻页查看更早历史,默认 0 + - before_time (可选): Unix 时间戳(秒),只返回该时间之前的消息 + - after_time (可选): Unix 时间戳(秒),只返回该时间之后的消息 + +当用户要求回顾历史、查找之前的消息、或你记不清之前的对话内容时,可以使用此工具的 list_messages 动作,通过调整 offset 或指定时间范围来查询具体的历史消息。"#, session_line ) } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 276ebb9..17d7fd8 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -692,6 +692,77 @@ impl Storage { Ok(messages) } + pub async fn query_messages_range( + &self, + session_id: &str, + before_time: Option, + after_time: Option, + offset: i64, + limit: i64, + ) -> Result<(Vec, i64), StorageError> { + let mut where_extra = String::new(); + if before_time.is_some() { + where_extra.push_str(" AND created_at < ?"); + } + if after_time.is_some() { + where_extra.push_str(" AND created_at > ?"); + } + + let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra); + let select_sql = format!( + r#" + SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at + FROM messages + WHERE session_id = ?{} + ORDER BY seq ASC + LIMIT ? OFFSET ? + "#, + where_extra + ); + + let mut count_query = sqlx::query(&count_sql).bind(session_id); + if let Some(bt) = before_time { + count_query = count_query.bind(bt); + } + if let Some(at) = after_time { + count_query = count_query.bind(at); + } + let count_row = count_query.fetch_one(self.pool()).await?; + let total: i64 = count_row.get("total"); + + let mut select_query = sqlx::query(&select_sql).bind(session_id); + if let Some(bt) = before_time { + select_query = select_query.bind(bt); + } + if let Some(at) = after_time { + select_query = select_query.bind(at); + } + let rows = select_query + .bind(limit) + .bind(offset) + .fetch_all(self.pool()) + .await?; + + let messages: Vec<_> = rows + .into_iter() + .map(|row| crate::storage::message::MessageMeta { + id: row.get("id"), + session_id: row.get("session_id"), + seq: row.get("seq"), + role: row.get("role"), + content: row.get("content"), + media_refs: row.get("media_refs"), + tool_call_id: row.get("tool_call_id"), + tool_name: row.get("tool_name"), + tool_calls: row.get("tool_calls"), + source: row.get("source"), + created_at: row.get("created_at"), + }) + .collect(); + + Ok((messages, total)) + } + pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#) .bind(session_id) diff --git a/src/tools/chat_manager.rs b/src/tools/chat_manager.rs index 51cdcfb..50a4508 100644 --- a/src/tools/chat_manager.rs +++ b/src/tools/chat_manager.rs @@ -27,8 +27,8 @@ impl Tool for ChatManagerTool { } fn description(&self) -> &str { - "聊天管理工具。可以列出当前活跃的 session、可用的 channel、以及查看指定 session 的最近消息内容。\ -action 可选值: list_sessions (列出最近活跃会话), list_channels (列出可用渠道), list_messages (查看最近消息)" + "聊天管理工具。可以列出当前活跃的 session、可用的 channel,以及查看指定 session 的消息内容,支持时间范围筛选和分页翻页。\ +action 可选值: list_sessions (列出最近活跃会话), list_channels (列出可用渠道), list_messages (查看消息)" } fn parameters_schema(&self) -> serde_json::Value { @@ -38,7 +38,7 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列 "action": { "type": "string", "enum": ["list_sessions", "list_channels", "list_messages"], - "description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的最近消息" + "description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的消息" }, "session_id": { "type": "string", @@ -46,7 +46,19 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列 }, "count": { "type": "integer", - "description": "获取最近消息的数量,仅在 action 为 list_messages 时有效,默认 20" + "description": "获取消息的数量,仅在 action 为 list_messages 时有效,默认 20,最大 100" + }, + "offset": { + "type": "integer", + "description": "跳过前 N 条消息(用于翻页),仅在 action 为 list_messages 时有效,默认 0" + }, + "before_time": { + "type": "integer", + "description": "Unix 时间戳(秒),仅返回此时间之前的消息,仅在 action 为 list_messages 时有效" + }, + "after_time": { + "type": "integer", + "description": "Unix 时间戳(秒),仅返回此时间之后的消息,仅在 action 为 list_messages 时有效" } }, "required": ["action"] @@ -131,6 +143,10 @@ impl ChatManagerTool { .ok_or_else(|| anyhow::anyhow!("missing required parameter: session_id"))?; let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100); + let offset = args["offset"].as_i64().unwrap_or(0).max(0); + + let before_time = args["before_time"].as_i64().map(|t| t * 1000); + let after_time = args["after_time"].as_i64().map(|t| t * 1000); let session = self .storage @@ -138,15 +154,31 @@ impl ChatManagerTool { .await .map_err(|e| anyhow::anyhow!("Session not found: {}", e))?; - let messages = self + let (messages, total) = self .storage - .list_recent_messages(session_id, count) + .query_messages_range(session_id, before_time, after_time, offset, count) .await .map_err(|e| anyhow::anyhow!("Failed to load messages: {}", e))?; + let start_num = offset + 1; + let end_num = offset + messages.len() as i64; + + let range_desc = if messages.is_empty() { + "无消息".to_string() + } else { + format!("第 {}-{} 条", start_num, end_num) + }; + + let filter_desc = match (before_time, after_time) { + (Some(_), Some(_)) => "(已按时间范围筛选)", + (Some(_), None) => "(已按截止时间筛选)", + (None, Some(_)) => "(已按起始时间筛选)", + (None, None) => "", + }; + let mut output = format!( - "会话: {} ({})\n--- 最近 {} 条消息 (共 {} 条) ---\n", - session_id, session.title, messages.len(), session.message_count + "会话: {} ({})\n--- 消息 {} / 共 {} 条 {} ---\n", + session_id, session.title, range_desc, total, filter_desc ); if messages.is_empty() { @@ -282,7 +314,7 @@ mod tests { } #[tokio::test] - async fn test_list_messages() { + async fn test_list_messages_default() { let (storage, _dir) = create_test_storage().await; let now = chrono::Utc::now().timestamp_millis(); @@ -330,6 +362,121 @@ mod tests { assert!(result.output.contains("消息内容 0")); assert!(result.output.contains("消息内容 2")); assert!(result.output.contains("测试会话")); + assert!(result.output.contains("共 3 条")); + // Verify ascending order: seq 1 before seq 3 + let pos_0 = result.output.find("消息内容 0").unwrap(); + let pos_2 = result.output.find("消息内容 2").unwrap(); + assert!(pos_0 < pos_2, "Messages should be in ascending order"); + } + + #[tokio::test] + async fn test_list_messages_with_pagination() { + let (storage, _dir) = create_test_storage().await; + + let now = chrono::Utc::now().timestamp_millis(); + let session_id = "cli_chat:sid0:dialog0"; + let meta = crate::storage::session::SessionMeta { + id: session_id.to_string(), + channel: "cli_chat".to_string(), + chat_id: "sid0".to_string(), + dialog_id: "dialog0".to_string(), + title: "分页测试".to_string(), + created_at: now, + last_active_at: now, + message_count: 5, + routing_info: None, + deleted_at: None, + last_consolidated_at: None, + last_compressed_message_at: None, + }; + storage.upsert_session(&meta).await.unwrap(); + + for i in 0..5 { + let msg = crate::storage::message::MessageMeta { + id: format!("msg{}", i), + session_id: session_id.to_string(), + seq: i as i64 + 1, + role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() }, + content: format!("消息内容 {}", i), + media_refs: None, + tool_call_id: None, + tool_name: None, + tool_calls: None, + source: None, + created_at: now + i * 1000, + }; + storage.append_message(session_id, &msg).await.unwrap(); + } + + let tool = ChatManagerTool::new(storage, vec![]); + + // offset=2, count=2 => should return messages 2,3 (seq 3,4) + let result = tool + .execute(json!({ "action": "list_messages", "session_id": session_id, "offset": 2, "count": 2 })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("第 3-4 条")); + assert!(result.output.contains("消息内容 2")); + assert!(result.output.contains("消息内容 3")); + assert!(!result.output.contains("消息内容 0")); + assert!(result.output.contains("共 5 条")); + } + + #[tokio::test] + async fn test_list_messages_with_time_range() { + let (storage, _dir) = create_test_storage().await; + + let now = chrono::Utc::now().timestamp_millis(); + let session_id = "cli_chat:sid0:dialog0"; + let meta = crate::storage::session::SessionMeta { + id: session_id.to_string(), + channel: "cli_chat".to_string(), + chat_id: "sid0".to_string(), + dialog_id: "dialog0".to_string(), + title: "时间范围测试".to_string(), + created_at: now, + last_active_at: now, + message_count: 5, + routing_info: None, + deleted_at: None, + last_consolidated_at: None, + last_compressed_message_at: None, + }; + storage.upsert_session(&meta).await.unwrap(); + + for i in 0..5 { + let msg = crate::storage::message::MessageMeta { + id: format!("msg{}", i), + session_id: session_id.to_string(), + seq: i as i64 + 1, + role: "user".to_string(), + content: format!("消息内容 {}", i), + media_refs: None, + tool_call_id: None, + tool_name: None, + tool_calls: None, + source: None, + created_at: now + i * 1000, + }; + storage.append_message(session_id, &msg).await.unwrap(); + } + + let tool = ChatManagerTool::new(storage, vec![]); + + // after_time: filter to messages after msg2's timestamp + let after_ts = (now + 1500) / 1000; + let result = tool + .execute(json!({ "action": "list_messages", "session_id": session_id, "after_time": after_ts })) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("已按起始时间筛选")); + assert!(result.output.contains("消息内容 2")); + assert!(result.output.contains("消息内容 3")); + assert!(result.output.contains("消息内容 4")); + assert!(!result.output.contains("消息内容 0")); + assert!(!result.output.contains("消息内容 1")); } #[tokio::test]