feat: 统一记忆 scope_key 为 "default",简化上下文依赖
This commit is contained in:
parent
e36f66e23b
commit
abb2d596f4
@ -17,8 +17,9 @@ pub use ports::{
|
||||
};
|
||||
pub use records::{
|
||||
allowed_namespace_names, get_namespace_description, is_valid_namespace,
|
||||
ALLOWED_MEMORY_NAMESPACES, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState,
|
||||
SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, TopicRecord,
|
||||
ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord,
|
||||
SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||||
TopicRecord,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -210,6 +211,7 @@ impl SessionStore {
|
||||
ensure_sessions_schema(&conn)?;
|
||||
ensure_messages_schema(&conn)?;
|
||||
ensure_scheduler_schema(&conn)?;
|
||||
ensure_memory_scope_key_migration(&conn)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
@ -1726,6 +1728,14 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageError> {
|
||||
conn.execute(
|
||||
"UPDATE memories SET scope_key = 'default' WHERE scope_key != 'default'",
|
||||
[],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn has_column(
|
||||
conn: &Connection,
|
||||
table_name: &str,
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 全局统一的记忆 scope_key,所有渠道共享同一份记忆空间
|
||||
pub const GLOBAL_SCOPE_KEY: &str = "default";
|
||||
|
||||
/// 允许的记忆命名空间列表
|
||||
///
|
||||
/// 每个命名空间代表一类记忆内容,用于分类管理和检索。
|
||||
|
||||
@ -187,12 +187,8 @@ fn build_memory_upsert(
|
||||
})
|
||||
}
|
||||
|
||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
||||
let channel_name = context
|
||||
.channel_name
|
||||
.as_deref()
|
||||
.ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?;
|
||||
Ok(channel_name.to_string())
|
||||
fn scope_key_from_context(_context: &ToolContext) -> Result<String, ToolResult> {
|
||||
Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string())
|
||||
}
|
||||
|
||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||
@ -260,22 +256,26 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_manage_requires_context() {
|
||||
async fn test_memory_manage_works_with_default_context() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = MemoryManageTool::new(store);
|
||||
|
||||
// scope_key 已全局统一为 "default",不再依赖 channel_name
|
||||
let result = tool
|
||||
.execute_with_context(
|
||||
&ToolContext::default(),
|
||||
json!({
|
||||
"action": "list"
|
||||
"action": "put",
|
||||
"namespace": "user",
|
||||
"key": "language",
|
||||
"content": "Rust"
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("channel_name"));
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Rust"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@ -186,12 +186,8 @@ impl Tool for MemorySearchTool {
|
||||
}
|
||||
}
|
||||
|
||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
||||
let channel_name = context
|
||||
.channel_name
|
||||
.as_deref()
|
||||
.ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?;
|
||||
Ok(channel_name.to_string())
|
||||
fn scope_key_from_context(_context: &ToolContext) -> Result<String, ToolResult> {
|
||||
Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string())
|
||||
}
|
||||
|
||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||
@ -234,7 +230,7 @@ mod tests {
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: TEST_CHANNEL.to_string(),
|
||||
scope_key: crate::storage::GLOBAL_SCOPE_KEY.to_string(),
|
||||
namespace: "user".to_string(),
|
||||
memory_key: "language".to_string(),
|
||||
content: "User prefers Chinese responses".to_string(),
|
||||
@ -250,10 +246,6 @@ mod tests {
|
||||
let tool = MemorySearchTool::new(store);
|
||||
let context = ToolContext {
|
||||
channel_name: Some(TEST_CHANNEL.to_string()),
|
||||
chat_id: Some("chat-1".to_string()),
|
||||
session_id: Some(format!("{}:chat-1", TEST_CHANNEL)),
|
||||
message_id: Some("msg-2".to_string()),
|
||||
message_seq: Some(2),
|
||||
..ToolContext::default()
|
||||
};
|
||||
|
||||
@ -287,18 +279,18 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_search_is_read_only_and_requires_context() {
|
||||
async fn test_memory_search_is_read_only_and_works_with_default_context() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = MemorySearchTool::new(store);
|
||||
|
||||
assert!(tool.read_only());
|
||||
|
||||
// scope_key 已全局统一为 "default",不再依赖 channel_name
|
||||
let result = tool
|
||||
.execute_with_context(&ToolContext::default(), json!({ "action": "list" }))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("channel_name"));
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user