Compare commits
2 Commits
b5a1635a05
...
abb2d596f4
| Author | SHA1 | Date | |
|---|---|---|---|
| abb2d596f4 | |||
| e36f66e23b |
@ -2,10 +2,10 @@ use std::collections::HashMap;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
Arc, Mutex,
|
Arc,
|
||||||
atomic::{AtomicBool, Ordering},
|
atomic::{AtomicBool, Ordering},
|
||||||
};
|
};
|
||||||
use std::time::{Duration, UNIX_EPOCH};
|
use std::time::UNIX_EPOCH;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures_util::FutureExt;
|
use futures_util::FutureExt;
|
||||||
@ -18,44 +18,6 @@ use crate::bus::message::OutboundEventKind;
|
|||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::config::{LLMProviderConfig, WechatChannelConfig};
|
use crate::config::{LLMProviderConfig, WechatChannelConfig};
|
||||||
|
|
||||||
/// Rate limiter: ensures minimum interval between messages to the same chat.
|
|
||||||
static LAST_SEND: std::sync::LazyLock<Mutex<HashMap<String, tokio::time::Instant>>> =
|
|
||||||
std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
|
|
||||||
|
|
||||||
/// Minimum interval between consecutive messages to the same WeChat user.
|
|
||||||
const MIN_MSG_INTERVAL_MS: u64 = 1500;
|
|
||||||
|
|
||||||
/// Wait if needed to respect the rate limit, then record the send time.
|
|
||||||
async fn throttle(chat_id: &str) {
|
|
||||||
let now = tokio::time::Instant::now();
|
|
||||||
let sleep_ms = {
|
|
||||||
let mut map = LAST_SEND.lock().unwrap();
|
|
||||||
let delay = map.get(chat_id).and_then(|last| {
|
|
||||||
let elapsed = now.duration_since(*last);
|
|
||||||
if elapsed < Duration::from_millis(MIN_MSG_INTERVAL_MS) {
|
|
||||||
Some(Duration::from_millis(MIN_MSG_INTERVAL_MS) - elapsed)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
});
|
|
||||||
// Record the expected send time now (before sleep), so concurrent
|
|
||||||
// callers also see the updated time and don't race.
|
|
||||||
map.insert(chat_id.to_string(), now);
|
|
||||||
delay
|
|
||||||
};
|
|
||||||
if let Some(ms) = sleep_ms {
|
|
||||||
tokio::time::sleep(ms).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update rate-limit timestamp (without waiting) — used after each chunk.
|
|
||||||
fn touch_rate_limit(chat_id: &str) {
|
|
||||||
LAST_SEND
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(chat_id.to_string(), tokio::time::Instant::now());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct WechatChannel {
|
pub struct WechatChannel {
|
||||||
name: String,
|
name: String,
|
||||||
@ -341,8 +303,17 @@ impl Channel for WechatChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending)
|
// WeChat iLink Bot has a ~10-message burst limit per context_token.
|
||||||
|| msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
// Filter non-essential message types to conserve budget:
|
||||||
|
// - ToolCall: internal tool invocation details, not useful to WeChat users
|
||||||
|
// - ToolResult / ToolPending: raw tool output, not user-facing
|
||||||
|
// - Subagent events: internal agent orchestration
|
||||||
|
if matches!(
|
||||||
|
msg.event_kind,
|
||||||
|
OutboundEventKind::ToolResult
|
||||||
|
| OutboundEventKind::ToolPending
|
||||||
|
| OutboundEventKind::ToolCall
|
||||||
|
) || msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||||
{
|
{
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
@ -351,42 +322,15 @@ impl Channel for WechatChannel {
|
|||||||
let mut text_sent = false;
|
let mut text_sent = false;
|
||||||
|
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
// Rate limit: ensure minimum interval between messages to the same user
|
self.bot.send(&msg.chat_id, &text).await.map_err(|error| {
|
||||||
throttle(&msg.chat_id).await;
|
ChannelError::SendError(format!("WeChat text send failed: {}", error))
|
||||||
|
})?;
|
||||||
let chunks = split_text(&text, MAX_WECHAT_CHUNK_CHARS);
|
tracing::info!(
|
||||||
if chunks.len() > 1 {
|
channel = %self.name,
|
||||||
tracing::info!(
|
chat_id = %msg.chat_id,
|
||||||
channel = %self.name,
|
content_len = text.len(),
|
||||||
chat_id = %msg.chat_id,
|
"WeChat text message sent"
|
||||||
total_chars = text.len(),
|
);
|
||||||
chunk_count = chunks.len(),
|
|
||||||
"WeChat: splitting long message into chunks"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
for (i, chunk) in chunks.iter().enumerate() {
|
|
||||||
if i > 0 {
|
|
||||||
tokio::time::sleep(Duration::from_millis(CHUNK_SEND_INTERVAL_MS)).await;
|
|
||||||
}
|
|
||||||
self.bot.send(&msg.chat_id, chunk).await.map_err(|error| {
|
|
||||||
ChannelError::SendError(format!(
|
|
||||||
"WeChat text send failed (chunk {}/{}): {}",
|
|
||||||
i + 1,
|
|
||||||
chunks.len(),
|
|
||||||
error
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
// Update rate-limit timestamp so the next message or chunk respects the gap
|
|
||||||
touch_rate_limit(&msg.chat_id);
|
|
||||||
tracing::info!(
|
|
||||||
channel = %self.name,
|
|
||||||
chat_id = %msg.chat_id,
|
|
||||||
chunk = i + 1,
|
|
||||||
total_chunks = chunks.len(),
|
|
||||||
content_len = chunk.len(),
|
|
||||||
"WeChat text message sent"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
text_sent = true;
|
text_sent = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -421,141 +365,6 @@ impl Channel for WechatChannel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Split text into chunks suitable for WeChat delivery.
|
|
||||||
/// - Prefers splitting at paragraph breaks, then newlines, then sentence boundaries.
|
|
||||||
/// - Avoids splitting in the middle of markdown tables and code blocks.
|
|
||||||
const MAX_WECHAT_CHUNK_CHARS: usize = 2000;
|
|
||||||
const CHUNK_SEND_INTERVAL_MS: u64 = 500;
|
|
||||||
|
|
||||||
fn split_text(text: &str, limit: usize) -> Vec<String> {
|
|
||||||
if text.len() <= limit {
|
|
||||||
return vec![text.to_string()];
|
|
||||||
}
|
|
||||||
let mut chunks = Vec::new();
|
|
||||||
let mut remaining = text;
|
|
||||||
while !remaining.is_empty() {
|
|
||||||
if remaining.len() <= limit {
|
|
||||||
chunks.push(remaining.to_string());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let end = remaining.floor_char_boundary(limit);
|
|
||||||
let window = &remaining[..end];
|
|
||||||
|
|
||||||
// Find a safe split point, avoiding table/code-block interiors
|
|
||||||
let cut = find_split_point(window, limit);
|
|
||||||
chunks.push(remaining[..cut].to_string());
|
|
||||||
remaining = remaining[cut..].trim_start();
|
|
||||||
}
|
|
||||||
if chunks.is_empty() {
|
|
||||||
vec![String::new()]
|
|
||||||
} else {
|
|
||||||
chunks
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find the best split point in `window`, avoiding markdown table rows and code fences.
|
|
||||||
fn find_split_point(window: &str, _limit: usize) -> usize {
|
|
||||||
let end = window.len();
|
|
||||||
|
|
||||||
// Build a set of line-start indices that are "unsafe" to split before
|
|
||||||
// because they're inside a markdown table or code block.
|
|
||||||
let unsafe_starts = find_unsafe_line_starts(window);
|
|
||||||
|
|
||||||
// Try split points from best to worst, skipping unsafe ones
|
|
||||||
for (delim, len) in &[
|
|
||||||
("\n\n", 2), // paragraph break (best)
|
|
||||||
("\n", 1), // newline
|
|
||||||
("。", 3), // Chinese period
|
|
||||||
("\n", 1), // any newline (retry with relaxed threshold)
|
|
||||||
] {
|
|
||||||
let min_pos = if *delim == "\n" && *len == 1 && end > 0 {
|
|
||||||
// For the relaxed newline pass, accept any position
|
|
||||||
0
|
|
||||||
} else {
|
|
||||||
end * 3 / 10
|
|
||||||
};
|
|
||||||
|
|
||||||
match window.rfind(delim) {
|
|
||||||
Some(pos) if pos >= min_pos => {
|
|
||||||
let after = pos + len;
|
|
||||||
// Check that the line starting at `after` is not inside a protected block
|
|
||||||
if !unsafe_starts.contains(&after) {
|
|
||||||
return after;
|
|
||||||
}
|
|
||||||
// If this split point is inside a protected block, keep looking earlier
|
|
||||||
if let Some(prev) = window[..pos].rfind(delim) {
|
|
||||||
let prev_after = prev + len;
|
|
||||||
if prev_after >= min_pos && !unsafe_starts.contains(&prev_after) {
|
|
||||||
return prev_after;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If still inside protected block, try earlier .find
|
|
||||||
if let Some(earlier) = window[..pos].rfind("\n\n") {
|
|
||||||
let earlier_after = earlier + 2;
|
|
||||||
if !unsafe_starts.contains(&earlier_after) {
|
|
||||||
return earlier_after;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Last resort: just cut at the character boundary (may break a table, but better than nothing)
|
|
||||||
end
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns byte offsets of line starts that are "unsafe" to split before,
|
|
||||||
/// because they fall inside a markdown table or code block.
|
|
||||||
fn find_unsafe_line_starts(window: &str) -> Vec<usize> {
|
|
||||||
let mut unsafe_starts = Vec::new();
|
|
||||||
let mut in_code_block = false;
|
|
||||||
let mut in_table = false;
|
|
||||||
let mut pos = 0;
|
|
||||||
|
|
||||||
for line in window.split_inclusive('\n') {
|
|
||||||
let trimmed = line.trim();
|
|
||||||
let is_empty = trimmed.is_empty();
|
|
||||||
|
|
||||||
// Track code blocks
|
|
||||||
if trimmed.starts_with("```") {
|
|
||||||
if in_code_block {
|
|
||||||
in_code_block = false;
|
|
||||||
// The closing fence itself is safe after
|
|
||||||
pos += line.len();
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
in_code_block = true;
|
|
||||||
pos += line.len();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if in_code_block {
|
|
||||||
unsafe_starts.push(pos);
|
|
||||||
pos += line.len();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track markdown table rows
|
|
||||||
let is_table_row = trimmed.starts_with('|') && trimmed.ends_with('|');
|
|
||||||
|
|
||||||
if is_table_row {
|
|
||||||
in_table = true;
|
|
||||||
unsafe_starts.push(pos);
|
|
||||||
} else if in_table && !is_empty {
|
|
||||||
// Non-empty non-table line after table: table ended on previous line
|
|
||||||
in_table = false;
|
|
||||||
} else if is_empty {
|
|
||||||
in_table = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
pos += line.len();
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe_starts
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@ -17,8 +17,9 @@ pub use ports::{
|
|||||||
};
|
};
|
||||||
pub use records::{
|
pub use records::{
|
||||||
allowed_namespace_names, get_namespace_description, is_valid_namespace,
|
allowed_namespace_names, get_namespace_description, is_valid_namespace,
|
||||||
ALLOWED_MEMORY_NAMESPACES, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState,
|
ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord,
|
||||||
SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, TopicRecord,
|
SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||||||
|
TopicRecord,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -210,6 +211,7 @@ impl SessionStore {
|
|||||||
ensure_sessions_schema(&conn)?;
|
ensure_sessions_schema(&conn)?;
|
||||||
ensure_messages_schema(&conn)?;
|
ensure_messages_schema(&conn)?;
|
||||||
ensure_scheduler_schema(&conn)?;
|
ensure_scheduler_schema(&conn)?;
|
||||||
|
ensure_memory_scope_key_migration(&conn)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Arc::new(Mutex::new(conn)),
|
conn: Arc::new(Mutex::new(conn)),
|
||||||
@ -1726,6 +1728,14 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageError> {
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE memories SET scope_key = 'default' WHERE scope_key != 'default'",
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn has_column(
|
fn has_column(
|
||||||
conn: &Connection,
|
conn: &Connection,
|
||||||
table_name: &str,
|
table_name: &str,
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// 全局统一的记忆 scope_key,所有渠道共享同一份记忆空间
|
||||||
|
pub const GLOBAL_SCOPE_KEY: &str = "default";
|
||||||
|
|
||||||
/// 允许的记忆命名空间列表
|
/// 允许的记忆命名空间列表
|
||||||
///
|
///
|
||||||
/// 每个命名空间代表一类记忆内容,用于分类管理和检索。
|
/// 每个命名空间代表一类记忆内容,用于分类管理和检索。
|
||||||
|
|||||||
@ -187,12 +187,8 @@ fn build_memory_upsert(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
fn scope_key_from_context(_context: &ToolContext) -> Result<String, ToolResult> {
|
||||||
let channel_name = context
|
Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string())
|
||||||
.channel_name
|
|
||||||
.as_deref()
|
|
||||||
.ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?;
|
|
||||||
Ok(channel_name.to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||||
@ -260,22 +256,26 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_memory_manage_requires_context() {
|
async fn test_memory_manage_works_with_default_context() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemoryManageTool::new(store);
|
let tool = MemoryManageTool::new(store);
|
||||||
|
|
||||||
|
// scope_key 已全局统一为 "default",不再依赖 channel_name
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute_with_context(
|
.execute_with_context(
|
||||||
&ToolContext::default(),
|
&ToolContext::default(),
|
||||||
json!({
|
json!({
|
||||||
"action": "list"
|
"action": "put",
|
||||||
|
"namespace": "user",
|
||||||
|
"key": "language",
|
||||||
|
"content": "Rust"
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(result.success);
|
||||||
assert!(result.error.unwrap().contains("channel_name"));
|
assert!(result.output.contains("Rust"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@ -186,12 +186,8 @@ impl Tool for MemorySearchTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
fn scope_key_from_context(_context: &ToolContext) -> Result<String, ToolResult> {
|
||||||
let channel_name = context
|
Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string())
|
||||||
.channel_name
|
|
||||||
.as_deref()
|
|
||||||
.ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?;
|
|
||||||
Ok(channel_name.to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||||
@ -234,7 +230,7 @@ mod tests {
|
|||||||
store
|
store
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
scope_kind: "user".to_string(),
|
scope_kind: "user".to_string(),
|
||||||
scope_key: TEST_CHANNEL.to_string(),
|
scope_key: crate::storage::GLOBAL_SCOPE_KEY.to_string(),
|
||||||
namespace: "user".to_string(),
|
namespace: "user".to_string(),
|
||||||
memory_key: "language".to_string(),
|
memory_key: "language".to_string(),
|
||||||
content: "User prefers Chinese responses".to_string(),
|
content: "User prefers Chinese responses".to_string(),
|
||||||
@ -250,10 +246,6 @@ mod tests {
|
|||||||
let tool = MemorySearchTool::new(store);
|
let tool = MemorySearchTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
channel_name: Some(TEST_CHANNEL.to_string()),
|
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||||
chat_id: Some("chat-1".to_string()),
|
|
||||||
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
|
||||||
message_id: Some("msg-2".to_string()),
|
|
||||||
message_seq: Some(2),
|
|
||||||
..ToolContext::default()
|
..ToolContext::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -287,18 +279,18 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_memory_search_is_read_only_and_requires_context() {
|
async fn test_memory_search_is_read_only_and_works_with_default_context() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemorySearchTool::new(store);
|
let tool = MemorySearchTool::new(store);
|
||||||
|
|
||||||
assert!(tool.read_only());
|
assert!(tool.read_only());
|
||||||
|
|
||||||
|
// scope_key 已全局统一为 "default",不再依赖 channel_name
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute_with_context(&ToolContext::default(), json!({ "action": "list" }))
|
.execute_with_context(&ToolContext::default(), json!({ "action": "list" }))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(!result.success);
|
assert!(result.success);
|
||||||
assert!(result.error.unwrap().contains("channel_name"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user