feat(chat-manager): enhance message retrieval with pagination and time range filtering
This commit is contained in:
parent
8c0c76a232
commit
11a8e93b77
@ -261,7 +261,14 @@ impl PromptSection for CrossChannelSection {
|
|||||||
管理会话和查看消息。参数:
|
管理会话和查看消息。参数:
|
||||||
- action = "list_sessions" — 列出最近活跃的会话
|
- action = "list_sessions" — 列出最近活跃的会话
|
||||||
- action = "list_channels" — 列出所有可用渠道
|
- action = "list_channels" — 列出所有可用渠道
|
||||||
- action = "list_messages" — 查看指定 session 的最近消息,需提供 session_id 和 count"#,
|
- action = "list_messages" — 查看指定 session 的历史消息,支持以下参数:
|
||||||
|
- session_id (必填): 会话 ID
|
||||||
|
- count (可选): 返回数量,默认 20,最大 100
|
||||||
|
- offset (可选): 跳过前 N 条,用于翻页查看更早历史,默认 0
|
||||||
|
- before_time (可选): Unix 时间戳(秒),只返回该时间之前的消息
|
||||||
|
- after_time (可选): Unix 时间戳(秒),只返回该时间之后的消息
|
||||||
|
|
||||||
|
当用户要求回顾历史、查找之前的消息、或你记不清之前的对话内容时,可以使用此工具的 list_messages 动作,通过调整 offset 或指定时间范围来查询具体的历史消息。"#,
|
||||||
session_line
|
session_line
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -692,6 +692,77 @@ impl Storage {
|
|||||||
Ok(messages)
|
Ok(messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn query_messages_range(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
before_time: Option<i64>,
|
||||||
|
after_time: Option<i64>,
|
||||||
|
offset: i64,
|
||||||
|
limit: i64,
|
||||||
|
) -> Result<(Vec<crate::storage::message::MessageMeta>, i64), StorageError> {
|
||||||
|
let mut where_extra = String::new();
|
||||||
|
if before_time.is_some() {
|
||||||
|
where_extra.push_str(" AND created_at < ?");
|
||||||
|
}
|
||||||
|
if after_time.is_some() {
|
||||||
|
where_extra.push_str(" AND created_at > ?");
|
||||||
|
}
|
||||||
|
|
||||||
|
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
|
||||||
|
let select_sql = format!(
|
||||||
|
r#"
|
||||||
|
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ?{}
|
||||||
|
ORDER BY seq ASC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"#,
|
||||||
|
where_extra
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut count_query = sqlx::query(&count_sql).bind(session_id);
|
||||||
|
if let Some(bt) = before_time {
|
||||||
|
count_query = count_query.bind(bt);
|
||||||
|
}
|
||||||
|
if let Some(at) = after_time {
|
||||||
|
count_query = count_query.bind(at);
|
||||||
|
}
|
||||||
|
let count_row = count_query.fetch_one(self.pool()).await?;
|
||||||
|
let total: i64 = count_row.get("total");
|
||||||
|
|
||||||
|
let mut select_query = sqlx::query(&select_sql).bind(session_id);
|
||||||
|
if let Some(bt) = before_time {
|
||||||
|
select_query = select_query.bind(bt);
|
||||||
|
}
|
||||||
|
if let Some(at) = after_time {
|
||||||
|
select_query = select_query.bind(at);
|
||||||
|
}
|
||||||
|
let rows = select_query
|
||||||
|
.bind(limit)
|
||||||
|
.bind(offset)
|
||||||
|
.fetch_all(self.pool())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let messages: Vec<_> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| crate::storage::message::MessageMeta {
|
||||||
|
id: row.get("id"),
|
||||||
|
session_id: row.get("session_id"),
|
||||||
|
seq: row.get("seq"),
|
||||||
|
role: row.get("role"),
|
||||||
|
content: row.get("content"),
|
||||||
|
media_refs: row.get("media_refs"),
|
||||||
|
tool_call_id: row.get("tool_call_id"),
|
||||||
|
tool_name: row.get("tool_name"),
|
||||||
|
tool_calls: row.get("tool_calls"),
|
||||||
|
source: row.get("source"),
|
||||||
|
created_at: row.get("created_at"),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok((messages, total))
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
pub async fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
sqlx::query(r#"DELETE FROM messages WHERE session_id = ?"#)
|
||||||
.bind(session_id)
|
.bind(session_id)
|
||||||
|
|||||||
@ -27,8 +27,8 @@ impl Tool for ChatManagerTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"聊天管理工具。可以列出当前活跃的 session、可用的 channel、以及查看指定 session 的最近消息内容。\
|
"聊天管理工具。可以列出当前活跃的 session、可用的 channel,以及查看指定 session 的消息内容,支持时间范围筛选和分页翻页。\
|
||||||
action 可选值: list_sessions (列出最近活跃会话), list_channels (列出可用渠道), list_messages (查看最近消息)"
|
action 可选值: list_sessions (列出最近活跃会话), list_channels (列出可用渠道), list_messages (查看消息)"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
@ -38,7 +38,7 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列
|
|||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["list_sessions", "list_channels", "list_messages"],
|
"enum": ["list_sessions", "list_channels", "list_messages"],
|
||||||
"description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的最近消息"
|
"description": "操作类型: list_sessions 列出最近活跃会话, list_channels 列出可用渠道, list_messages 查看指定会话的消息"
|
||||||
},
|
},
|
||||||
"session_id": {
|
"session_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@ -46,7 +46,19 @@ action 可选值: list_sessions (列出最近活跃会话), list_channels (列
|
|||||||
},
|
},
|
||||||
"count": {
|
"count": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "获取最近消息的数量,仅在 action 为 list_messages 时有效,默认 20"
|
"description": "获取消息的数量,仅在 action 为 list_messages 时有效,默认 20,最大 100"
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "跳过前 N 条消息(用于翻页),仅在 action 为 list_messages 时有效,默认 0"
|
||||||
|
},
|
||||||
|
"before_time": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Unix 时间戳(秒),仅返回此时间之前的消息,仅在 action 为 list_messages 时有效"
|
||||||
|
},
|
||||||
|
"after_time": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Unix 时间戳(秒),仅返回此时间之后的消息,仅在 action 为 list_messages 时有效"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["action"]
|
"required": ["action"]
|
||||||
@ -131,6 +143,10 @@ impl ChatManagerTool {
|
|||||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: session_id"))?;
|
.ok_or_else(|| anyhow::anyhow!("missing required parameter: session_id"))?;
|
||||||
|
|
||||||
let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100);
|
let count = args["count"].as_i64().unwrap_or(20).clamp(1, 100);
|
||||||
|
let offset = args["offset"].as_i64().unwrap_or(0).max(0);
|
||||||
|
|
||||||
|
let before_time = args["before_time"].as_i64().map(|t| t * 1000);
|
||||||
|
let after_time = args["after_time"].as_i64().map(|t| t * 1000);
|
||||||
|
|
||||||
let session = self
|
let session = self
|
||||||
.storage
|
.storage
|
||||||
@ -138,15 +154,31 @@ impl ChatManagerTool {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("Session not found: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Session not found: {}", e))?;
|
||||||
|
|
||||||
let messages = self
|
let (messages, total) = self
|
||||||
.storage
|
.storage
|
||||||
.list_recent_messages(session_id, count)
|
.query_messages_range(session_id, before_time, after_time, offset, count)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load messages: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to load messages: {}", e))?;
|
||||||
|
|
||||||
|
let start_num = offset + 1;
|
||||||
|
let end_num = offset + messages.len() as i64;
|
||||||
|
|
||||||
|
let range_desc = if messages.is_empty() {
|
||||||
|
"无消息".to_string()
|
||||||
|
} else {
|
||||||
|
format!("第 {}-{} 条", start_num, end_num)
|
||||||
|
};
|
||||||
|
|
||||||
|
let filter_desc = match (before_time, after_time) {
|
||||||
|
(Some(_), Some(_)) => "(已按时间范围筛选)",
|
||||||
|
(Some(_), None) => "(已按截止时间筛选)",
|
||||||
|
(None, Some(_)) => "(已按起始时间筛选)",
|
||||||
|
(None, None) => "",
|
||||||
|
};
|
||||||
|
|
||||||
let mut output = format!(
|
let mut output = format!(
|
||||||
"会话: {} ({})\n--- 最近 {} 条消息 (共 {} 条) ---\n",
|
"会话: {} ({})\n--- 消息 {} / 共 {} 条 {} ---\n",
|
||||||
session_id, session.title, messages.len(), session.message_count
|
session_id, session.title, range_desc, total, filter_desc
|
||||||
);
|
);
|
||||||
|
|
||||||
if messages.is_empty() {
|
if messages.is_empty() {
|
||||||
@ -282,7 +314,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_list_messages() {
|
async fn test_list_messages_default() {
|
||||||
let (storage, _dir) = create_test_storage().await;
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
|
||||||
let now = chrono::Utc::now().timestamp_millis();
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
@ -330,6 +362,121 @@ mod tests {
|
|||||||
assert!(result.output.contains("消息内容 0"));
|
assert!(result.output.contains("消息内容 0"));
|
||||||
assert!(result.output.contains("消息内容 2"));
|
assert!(result.output.contains("消息内容 2"));
|
||||||
assert!(result.output.contains("测试会话"));
|
assert!(result.output.contains("测试会话"));
|
||||||
|
assert!(result.output.contains("共 3 条"));
|
||||||
|
// Verify ascending order: seq 1 before seq 3
|
||||||
|
let pos_0 = result.output.find("消息内容 0").unwrap();
|
||||||
|
let pos_2 = result.output.find("消息内容 2").unwrap();
|
||||||
|
assert!(pos_0 < pos_2, "Messages should be in ascending order");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_messages_with_pagination() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
let session_id = "cli_chat:sid0:dialog0";
|
||||||
|
let meta = crate::storage::session::SessionMeta {
|
||||||
|
id: session_id.to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid0".to_string(),
|
||||||
|
dialog_id: "dialog0".to_string(),
|
||||||
|
title: "分页测试".to_string(),
|
||||||
|
created_at: now,
|
||||||
|
last_active_at: now,
|
||||||
|
message_count: 5,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
|
last_compressed_message_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
for i in 0..5 {
|
||||||
|
let msg = crate::storage::message::MessageMeta {
|
||||||
|
id: format!("msg{}", i),
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
seq: i as i64 + 1,
|
||||||
|
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||||
|
content: format!("消息内容 {}", i),
|
||||||
|
media_refs: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
source: None,
|
||||||
|
created_at: now + i * 1000,
|
||||||
|
};
|
||||||
|
storage.append_message(session_id, &msg).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
|
// offset=2, count=2 => should return messages 2,3 (seq 3,4)
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "list_messages", "session_id": session_id, "offset": 2, "count": 2 }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("第 3-4 条"));
|
||||||
|
assert!(result.output.contains("消息内容 2"));
|
||||||
|
assert!(result.output.contains("消息内容 3"));
|
||||||
|
assert!(!result.output.contains("消息内容 0"));
|
||||||
|
assert!(result.output.contains("共 5 条"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_messages_with_time_range() {
|
||||||
|
let (storage, _dir) = create_test_storage().await;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
let session_id = "cli_chat:sid0:dialog0";
|
||||||
|
let meta = crate::storage::session::SessionMeta {
|
||||||
|
id: session_id.to_string(),
|
||||||
|
channel: "cli_chat".to_string(),
|
||||||
|
chat_id: "sid0".to_string(),
|
||||||
|
dialog_id: "dialog0".to_string(),
|
||||||
|
title: "时间范围测试".to_string(),
|
||||||
|
created_at: now,
|
||||||
|
last_active_at: now,
|
||||||
|
message_count: 5,
|
||||||
|
routing_info: None,
|
||||||
|
deleted_at: None,
|
||||||
|
last_consolidated_at: None,
|
||||||
|
last_compressed_message_at: None,
|
||||||
|
};
|
||||||
|
storage.upsert_session(&meta).await.unwrap();
|
||||||
|
|
||||||
|
for i in 0..5 {
|
||||||
|
let msg = crate::storage::message::MessageMeta {
|
||||||
|
id: format!("msg{}", i),
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
seq: i as i64 + 1,
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: format!("消息内容 {}", i),
|
||||||
|
media_refs: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
source: None,
|
||||||
|
created_at: now + i * 1000,
|
||||||
|
};
|
||||||
|
storage.append_message(session_id, &msg).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool = ChatManagerTool::new(storage, vec![]);
|
||||||
|
|
||||||
|
// after_time: filter to messages after msg2's timestamp
|
||||||
|
let after_ts = (now + 1500) / 1000;
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({ "action": "list_messages", "session_id": session_id, "after_time": after_ts }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(result.success);
|
||||||
|
assert!(result.output.contains("已按起始时间筛选"));
|
||||||
|
assert!(result.output.contains("消息内容 2"));
|
||||||
|
assert!(result.output.contains("消息内容 3"));
|
||||||
|
assert!(result.output.contains("消息内容 4"));
|
||||||
|
assert!(!result.output.contains("消息内容 0"));
|
||||||
|
assert!(!result.output.contains("消息内容 1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user