diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 956a9f9..1326709 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -17,8 +17,9 @@ pub use ports::{ }; pub use records::{ allowed_namespace_names, get_namespace_description, is_valid_namespace, - ALLOWED_MEMORY_NAMESPACES, MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, - SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, TopicRecord, + ALLOWED_MEMORY_NAMESPACES, GLOBAL_SCOPE_KEY, MemoryRecord, MemoryUpsert, SchedulerJobRecord, + SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionRecord, SkillEventRecord, + TopicRecord, }; #[derive(Clone)] @@ -210,6 +211,7 @@ impl SessionStore { ensure_sessions_schema(&conn)?; ensure_messages_schema(&conn)?; ensure_scheduler_schema(&conn)?; + ensure_memory_scope_key_migration(&conn)?; Ok(Self { conn: Arc::new(Mutex::new(conn)), @@ -1726,6 +1728,14 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> { Ok(()) } +fn ensure_memory_scope_key_migration(conn: &Connection) -> Result<(), StorageError> { + conn.execute( + "UPDATE memories SET scope_key = 'default' WHERE scope_key != 'default'", + [], + )?; + Ok(()) +} + fn has_column( conn: &Connection, table_name: &str, diff --git a/src/storage/records.rs b/src/storage/records.rs index 115c991..7bc1c0e 100644 --- a/src/storage/records.rs +++ b/src/storage/records.rs @@ -1,5 +1,8 @@ use serde::{Deserialize, Serialize}; +/// 全局统一的记忆 scope_key,所有渠道共享同一份记忆空间 +pub const GLOBAL_SCOPE_KEY: &str = "default"; + /// 允许的记忆命名空间列表 /// /// 每个命名空间代表一类记忆内容,用于分类管理和检索。 diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs index 8cae0f1..a66e818 100644 --- a/src/tools/memory_manage.rs +++ b/src/tools/memory_manage.rs @@ -187,12 +187,8 @@ fn build_memory_upsert( }) } -fn scope_key_from_context(context: &ToolContext) -> Result { - let channel_name = context - .channel_name - .as_deref() - .ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?; - Ok(channel_name.to_string()) +fn scope_key_from_context(_context: &ToolContext) -> Result { + Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string()) } fn memory_to_json(memory: MemoryRecord) -> serde_json::Value { @@ -260,22 +256,26 @@ mod tests { } #[tokio::test] - async fn test_memory_manage_requires_context() { + async fn test_memory_manage_works_with_default_context() { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemoryManageTool::new(store); + // scope_key 已全局统一为 "default",不再依赖 channel_name let result = tool .execute_with_context( &ToolContext::default(), json!({ - "action": "list" + "action": "put", + "namespace": "user", + "key": "language", + "content": "Rust" }), ) .await .unwrap(); - assert!(!result.success); - assert!(result.error.unwrap().contains("channel_name")); + assert!(result.success); + assert!(result.output.contains("Rust")); } #[tokio::test] diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs index f0e5863..ff166c0 100644 --- a/src/tools/memory_search.rs +++ b/src/tools/memory_search.rs @@ -186,12 +186,8 @@ impl Tool for MemorySearchTool { } } -fn scope_key_from_context(context: &ToolContext) -> Result { - let channel_name = context - .channel_name - .as_deref() - .ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?; - Ok(channel_name.to_string()) +fn scope_key_from_context(_context: &ToolContext) -> Result { + Ok(crate::storage::GLOBAL_SCOPE_KEY.to_string()) } fn memory_to_json(memory: MemoryRecord) -> serde_json::Value { @@ -234,7 +230,7 @@ mod tests { store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), - scope_key: TEST_CHANNEL.to_string(), + scope_key: crate::storage::GLOBAL_SCOPE_KEY.to_string(), namespace: "user".to_string(), memory_key: "language".to_string(), content: "User prefers Chinese responses".to_string(), @@ -250,10 +246,6 @@ mod tests { let tool = MemorySearchTool::new(store); let context = ToolContext { channel_name: Some(TEST_CHANNEL.to_string()), - chat_id: Some("chat-1".to_string()), - session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), - message_id: Some("msg-2".to_string()), - message_seq: Some(2), ..ToolContext::default() }; @@ -287,18 +279,18 @@ mod tests { } #[tokio::test] - async fn test_memory_search_is_read_only_and_requires_context() { + async fn test_memory_search_is_read_only_and_works_with_default_context() { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemorySearchTool::new(store); assert!(tool.read_only()); + // scope_key 已全局统一为 "default",不再依赖 channel_name let result = tool .execute_with_context(&ToolContext::default(), json!({ "action": "list" })) .await .unwrap(); - assert!(!result.success); - assert!(result.error.unwrap().contains("channel_name")); + assert!(result.success); } #[tokio::test]