Compare commits
No commits in common. "abb2d596f410be9fcbe353fb5a4e60d1c0783acc" and "b5a1635a05949dc3dcf5f7bbedec5ac3c052e41a" have entirely different histories.
abb2d596f4
...
b5a1635a05
@ -2,10 +2,10 @@ use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{
|
||||
Arc,
|
||||
Arc, Mutex,
|
||||
atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
use std::time::UNIX_EPOCH;
|
||||
use std::time::{Duration, UNIX_EPOCH};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures_util::FutureExt;
|
||||
@ -18,6 +18,44 @@ use crate::bus::message::OutboundEventKind;
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::config::{LLMProviderConfig, WechatChannelConfig};
|
||||
|
||||
/// Rate-limiter state: the time of the most recent send to each chat, keyed by chat ID.
|
||||
static LAST_SEND: std::sync::LazyLock<Mutex<HashMap<String, tokio::time::Instant>>> =
|
||||
std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
|
||||
|
||||
/// Minimum interval between consecutive messages to the same WeChat user.
|
||||
const MIN_MSG_INTERVAL_MS: u64 = 1500;
|
||||
|
||||
/// Wait if needed to respect the rate limit, then record the send time.
|
||||
async fn throttle(chat_id: &str) {
|
||||
let now = tokio::time::Instant::now();
|
||||
let sleep_ms = {
|
||||
let mut map = LAST_SEND.lock().unwrap();
|
||||
let delay = map.get(chat_id).and_then(|last| {
|
||||
let elapsed = now.duration_since(*last);
|
||||
if elapsed < Duration::from_millis(MIN_MSG_INTERVAL_MS) {
|
||||
Some(Duration::from_millis(MIN_MSG_INTERVAL_MS) - elapsed)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
// Record the expected send time now (before sleep), so concurrent
|
||||
// callers also see the updated time and don't race.
|
||||
map.insert(chat_id.to_string(), now);
|
||||
delay
|
||||
};
|
||||
if let Some(ms) = sleep_ms {
|
||||
tokio::time::sleep(ms).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Update rate-limit timestamp (without waiting) — used after each chunk.
|
||||
fn touch_rate_limit(chat_id: &str) {
|
||||
LAST_SEND
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(chat_id.to_string(), tokio::time::Instant::now());
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WechatChannel {
|
||||
name: String,
|
||||
@ -303,17 +341,8 @@ impl Channel for WechatChannel {
|
||||
}
|
||||
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
// WeChat iLink Bot has a ~10-message burst limit per context_token.
|
||||
// Filter non-essential message types to conserve budget:
|
||||
// - ToolCall: internal tool invocation details, not useful to WeChat users
|
||||
// - ToolResult / ToolPending: raw tool output, not user-facing
|
||||
// - Subagent events: internal agent orchestration
|
||||
if matches!(
|
||||
msg.event_kind,
|
||||
OutboundEventKind::ToolResult
|
||||
| OutboundEventKind::ToolPending
|
||||
| OutboundEventKind::ToolCall
|
||||
) || msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||
if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending)
|
||||
|| msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
@ -322,15 +351,42 @@ impl Channel for WechatChannel {
|
||||
let mut text_sent = false;
|
||||
|
||||
if !text.is_empty() {
|
||||
self.bot.send(&msg.chat_id, &text).await.map_err(|error| {
|
||||
ChannelError::SendError(format!("WeChat text send failed: {}", error))
|
||||
})?;
|
||||
// Rate limit: ensure minimum interval between messages to the same user
|
||||
throttle(&msg.chat_id).await;
|
||||
|
||||
let chunks = split_text(&text, MAX_WECHAT_CHUNK_CHARS);
|
||||
if chunks.len() > 1 {
|
||||
tracing::info!(
|
||||
channel = %self.name,
|
||||
chat_id = %msg.chat_id,
|
||||
content_len = text.len(),
|
||||
total_chars = text.len(),
|
||||
chunk_count = chunks.len(),
|
||||
"WeChat: splitting long message into chunks"
|
||||
);
|
||||
}
|
||||
for (i, chunk) in chunks.iter().enumerate() {
|
||||
if i > 0 {
|
||||
tokio::time::sleep(Duration::from_millis(CHUNK_SEND_INTERVAL_MS)).await;
|
||||
}
|
||||
self.bot.send(&msg.chat_id, chunk).await.map_err(|error| {
|
||||
ChannelError::SendError(format!(
|
||||
"WeChat text send failed (chunk {}/{}): {}",
|
||||
i + 1,
|
||||
chunks.len(),
|
||||
error
|
||||
))
|
||||
})?;
|
||||
// Update rate-limit timestamp so the next message or chunk respects the gap
|
||||
touch_rate_limit(&msg.chat_id);
|
||||
tracing::info!(
|
||||
channel = %self.name,
|
||||
chat_id = %msg.chat_id,
|
||||
chunk = i + 1,
|
||||
total_chunks = chunks.len(),
|
||||
content_len = chunk.len(),
|
||||
"WeChat text message sent"
|
||||
);
|
||||
}
|
||||
text_sent = true;
|
||||
}
|
||||
|
||||
@ -365,6 +421,141 @@ impl Channel for WechatChannel {
|
||||
}
|
||||
}
|
||||
|
||||
/// Split text into chunks suitable for WeChat delivery.
|
||||
/// - Prefers splitting at paragraph breaks, then newlines, then sentence boundaries.
|
||||
/// - Avoids splitting in the middle of markdown tables and code blocks.
|
||||
const MAX_WECHAT_CHUNK_CHARS: usize = 2000;
|
||||
const CHUNK_SEND_INTERVAL_MS: u64 = 500;
|
||||
|
||||
fn split_text(text: &str, limit: usize) -> Vec<String> {
|
||||
if text.len() <= limit {
|
||||
return vec![text.to_string()];
|
||||
}
|
||||
let mut chunks = Vec::new();
|
||||
let mut remaining = text;
|
||||
while !remaining.is_empty() {
|
||||
if remaining.len() <= limit {
|
||||
chunks.push(remaining.to_string());
|
||||
break;
|
||||
}
|
||||
let end = remaining.floor_char_boundary(limit);
|
||||
let window = &remaining[..end];
|
||||
|
||||
// Find a safe split point, avoiding table/code-block interiors
|
||||
let cut = find_split_point(window, limit);
|
||||
chunks.push(remaining[..cut].to_string());
|
||||
remaining = remaining[cut..].trim_start();
|
||||
}
|
||||
if chunks.is_empty() {
|
||||
vec![String::new()]
|
||||
} else {
|
||||
chunks
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the best split point in `window`, avoiding markdown table rows and code fences.
|
||||
fn find_split_point(window: &str, _limit: usize) -> usize {
|
||||
let end = window.len();
|
||||
|
||||
// Build a set of line-start indices that are "unsafe" to split before
|
||||
// because they're inside a markdown table or code block.
|
||||
let unsafe_starts = find_unsafe_line_starts(window);
|
||||
|
||||
// Try split points from best to worst, skipping unsafe ones
|
||||
for (delim, len) in &[
|
||||
("\n\n", 2), // paragraph break (best)
|
||||
("\n", 1), // newline
|
||||
("。", 3), // Chinese period
|
||||
("\n", 1), // any newline (retry with relaxed threshold)
|
||||
] {
|
||||
let min_pos = if *delim == "\n" && *len == 1 && end > 0 {
|
||||
// For the relaxed newline pass, accept any position
|
||||
0
|
||||
} else {
|
||||
end * 3 / 10
|
||||
};
|
||||
|
||||
match window.rfind(delim) {
|
||||
Some(pos) if pos >= min_pos => {
|
||||
let after = pos + len;
|
||||
// Check that the line starting at `after` is not inside a protected block
|
||||
if !unsafe_starts.contains(&after) {
|
||||
return after;
|
||||
}
|
||||
// If this split point is inside a protected block, keep looking earlier
|
||||
if let Some(prev) = window[..pos].rfind(delim) {
|
||||
let prev_after = prev + len;
|
||||
if prev_after >= min_pos && !unsafe_starts.contains(&prev_after) {
|
||||
return prev_after;
|
||||
}
|
||||
}
|
||||
// If still inside protected block, try earlier .find
|
||||
if let Some(earlier) = window[..pos].rfind("\n\n") {
|
||||
let earlier_after = earlier + 2;
|
||||
if !unsafe_starts.contains(&earlier_after) {
|
||||
return earlier_after;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: just cut at the character boundary (may break a table, but better than nothing)
|
||||
end
|
||||
}
|
||||
|
||||
/// Returns the byte offsets of line starts in `window` that are unsafe to
/// split before, because cutting there would break a markdown table or a
/// fenced code block apart.
///
/// A line start is reported when the line is:
///
/// - inside a fenced code block (between two lines starting with "```"),
/// - the closing fence of a code block (cutting before it would detach the
///   fence from the code it closes), or
/// - the second or later consecutive row of a markdown table (a line that,
///   once trimmed, starts and ends with `|`).
///
/// The opening fence and the first row of a table are *not* reported:
/// cutting immediately before them keeps the whole block in the next chunk.
/// An unterminated code block stays open until the end of `window`.
///
/// Offsets are returned in ascending order.
fn find_unsafe_line_starts(window: &str) -> Vec<usize> {
    let mut unsafe_starts = Vec::new();
    let mut in_code_block = false;
    let mut in_table = false;
    let mut pos = 0;

    for line in window.split_inclusive('\n') {
        let trimmed = line.trim();

        if trimmed.starts_with("```") {
            if in_code_block {
                // Closing fence: it belongs to the block it terminates.
                unsafe_starts.push(pos);
            }
            in_code_block = !in_code_block;
            // A fence always ends any table that preceded it.
            in_table = false;
        } else if in_code_block {
            unsafe_starts.push(pos);
        } else {
            let is_table_row = trimmed.starts_with('|') && trimmed.ends_with('|');
            // Only rows that continue a table are protected; the first row
            // is a valid place to start a new chunk.
            if is_table_row && in_table {
                unsafe_starts.push(pos);
            }
            in_table = is_table_row;
        }

        pos += line.len();
    }

    unsafe_starts
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -17,9 +17,8 @@ pub use ports::{
|
||||
};
|
||||
pub use records::{
|
||||
allowed_namespace_names, get_namespace_description, is_valid_namespace,
|
||||
ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord,
|
||||
SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||||
TopicRecord,
|
||||
ALLOWED_MEMORY_NAMESPACES, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState,
|
||||
SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, TopicRecord,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -211,7 +210,6 @@ impl SessionStore {
|
||||
ensure_sessions_schema(&conn)?;
|
||||
ensure_messages_schema(&conn)?;
|
||||
ensure_scheduler_schema(&conn)?;
|
||||
ensure_memory_scope_key_migration(&conn)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
@ -1728,14 +1726,6 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Rewrites every row of `memories` whose `scope_key` differs from
/// `'default'` so that its `scope_key` becomes `'default'`.
///
/// Running it again changes nothing further, since no row matches the
/// `WHERE` clause after the first run. Any error from `execute` is
/// propagated via `?`.
fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageError> {
    let sql = "UPDATE memories SET scope_key = 'default' WHERE scope_key != 'default'";
    let _outcome = conn.execute(sql, [])?;
    Ok(())
}
|
||||
|
||||
fn has_column(
|
||||
conn: &Connection,
|
||||
table_name: &str,
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 全局统一的记忆 scope_key,所有渠道共享同一份记忆空间
|
||||
pub const GLOBAL_SCOPE_KEY: &str = "default";
|
||||
|
||||
/// 允许的记忆命名空间列表
|
||||
///
|
||||
/// 每个命名空间代表一类记忆内容,用于分类管理和检索。
|
||||
|
||||
@ -187,8 +187,12 @@ fn build_memory_upsert(
|
||||
})
|
||||
}
|
||||
|
||||
fn scope_key_from_context(_context: &ToolContext) -> Result<String, ToolResult> {
|
||||
Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string())
|
||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
||||
let channel_name = context
|
||||
.channel_name
|
||||
.as_deref()
|
||||
.ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?;
|
||||
Ok(channel_name.to_string())
|
||||
}
|
||||
|
||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||
@ -256,26 +260,22 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_manage_works_with_default_context() {
|
||||
async fn test_memory_manage_requires_context() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = MemoryManageTool::new(store);
|
||||
|
||||
// scope_key 已全局统一为 "default",不再依赖 channel_name
|
||||
let result = tool
|
||||
.execute_with_context(
|
||||
&ToolContext::default(),
|
||||
json!({
|
||||
"action": "put",
|
||||
"namespace": "user",
|
||||
"key": "language",
|
||||
"content": "Rust"
|
||||
"action": "list"
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Rust"));
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("channel_name"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@ -186,8 +186,12 @@ impl Tool for MemorySearchTool {
|
||||
}
|
||||
}
|
||||
|
||||
fn scope_key_from_context(_context: &ToolContext) -> Result<String, ToolResult> {
|
||||
Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string())
|
||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
||||
let channel_name = context
|
||||
.channel_name
|
||||
.as_deref()
|
||||
.ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?;
|
||||
Ok(channel_name.to_string())
|
||||
}
|
||||
|
||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||
@ -230,7 +234,7 @@ mod tests {
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: crate::storage::GLOBAL_SCOPE_KEY.to_string(),
|
||||
scope_key: TEST_CHANNEL.to_string(),
|
||||
namespace: "user".to_string(),
|
||||
memory_key: "language".to_string(),
|
||||
content: "User prefers Chinese responses".to_string(),
|
||||
@ -246,6 +250,10 @@ mod tests {
|
||||
let tool = MemorySearchTool::new(store);
|
||||
let context = ToolContext {
|
||||
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||
chat_id: Some("chat-1".to_string()),
|
||||
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||
message_id: Some("msg-2".to_string()),
|
||||
message_seq: Some(2),
|
||||
..ToolContext::default()
|
||||
};
|
||||
|
||||
@ -279,18 +287,18 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_search_is_read_only_and_works_with_default_context() {
|
||||
async fn test_memory_search_is_read_only_and_requires_context() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = MemorySearchTool::new(store);
|
||||
|
||||
assert!(tool.read_only());
|
||||
|
||||
// scope_key 已全局统一为 "default",不再依赖 channel_name
|
||||
let result = tool
|
||||
.execute_with_context(&ToolContext::default(), json!({ "action": "list" }))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("channel_name"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user