From 71a8033d15d6f2f3a8f0601f31e5d5d44c7ed28d Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Wed, 22 Apr 2026 14:52:16 +0800 Subject: [PATCH] =?UTF-8?q?feat(memory):=20=E6=B7=BB=E5=8A=A0=20MemorySear?= =?UTF-8?q?chTool=20=E4=BB=A5=E6=94=AF=E6=8C=81=E9=95=BF=E6=9C=9F=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=AE=B0=E5=BF=86=E7=9A=84=E6=90=9C=E7=B4=A2=E5=92=8C?= =?UTF-8?q?=E8=AF=BB=E5=8F=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent_loop.rs | 2 +- src/gateway/session.rs | 3 +- src/tools/memory_manage.rs | 4 +- src/tools/memory_search.rs | 259 +++++++++++++++++++++++++++++++++++++ src/tools/mod.rs | 2 + 5 files changed, 266 insertions(+), 4 deletions(-) create mode 100644 src/tools/memory_search.rs diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 1c443ea..8399ffb 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -21,7 +21,7 @@ use std::time::Instant; const MAX_TOOL_RESULT_CHARS: usize = 16_000; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; -const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions。若需要写入,优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。"; +const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用长期记忆工具。读取记忆时,优先使用 memory_search:当你需要用户长期偏好、稳定事实、历史决策、持续任务上下文时,先 search;已知 namespace/key 时可用 get;需要浏览最近记忆时可用 list。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。"; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; diff --git a/src/gateway/session.rs b/src/gateway/session.rs index bc25218..79a4c4d 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -14,7 +14,7 @@ use crate::skills::SkillRuntime; use crate::storage::{MemoryRecord, SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, - HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry, + HttpRequestTool, MemoryManageTool, MemorySearchTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry, WebFetchTool, }; @@ -376,6 +376,7 @@ fn default_tools(skills: Arc, store: Arc) -> ToolReg registry.register(FileReadTool::new()); registry.register(FileWriteTool::new()); registry.register(FileEditTool::new()); + registry.register(MemorySearchTool::new(store.clone())); registry.register(MemoryManageTool::new(store)); registry.register(SkillListTool::new(skills.clone())); registry.register(SkillManageTool::new(skills)); diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs index a436061..2e97d3d 100644 --- a/src/tools/memory_manage.rs +++ b/src/tools/memory_manage.rs @@ -23,7 +23,7 @@ impl Tool for MemoryManageTool { } fn description(&self) -> &str { - "Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information. Search matches namespace, memory_key, and content. When searching, prefer bilingual queries that include both Chinese and English aliases, and include likely snake_case key terms when known. Memories are scoped to the current channel and sender, and record the originating session/message when available." + "Create, update, or delete long-term user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Prefer memory_search when you only need to retrieve memory. Use memory_manage mainly when you need to write or modify memory records. Memories are scoped to the current channel and sender, and record the originating session/message when available." } fn parameters_schema(&self) -> serde_json::Value { @@ -33,7 +33,7 @@ impl Tool for MemoryManageTool { "action": { "type": "string", "enum": ["list", "search", "get", "put", "update", "delete"], - "description": "Management action to perform. Prefer 'search' for keyword lookup across stored memories, 'get' for an exact namespace/key lookup, and 'list' for browsing recent memories." + "description": "Management action to perform. Prefer memory_search for retrieval-only access. Use 'put' to create or overwrite, 'update' to modify an existing record, and 'delete' to remove one." }, "namespace": { "type": "string", diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs new file mode 100644 index 0000000..a549c14 --- /dev/null +++ b/src/tools/memory_search.rs @@ -0,0 +1,259 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use serde_json::json; + +use crate::storage::{MemoryRecord, SessionStore}; +use crate::tools::traits::{Tool, ToolContext, ToolResult}; + +pub struct MemorySearchTool { + store: Arc, +} + +impl MemorySearchTool { + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +#[async_trait] +impl Tool for MemorySearchTool { + fn name(&self) -> &str { + "memory_search" + } + + fn description(&self) -> &str { + "Search and read long-term user memories stored in SQLite. Use this tool when you need prior preferences, stable facts, historical decisions, or ongoing task context. This tool is read-only and supports three actions: search for keyword lookup, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage when you only need to retrieve memory." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["search", "get", "list"], + "description": "Retrieval action. Use 'search' for keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories." + }, + "namespace": { + "type": "string", + "description": "Optional namespace filter, such as profile, preferences, tasks, or decisions. Required for get." + }, + "query": { + "type": "string", + "description": "Keyword query for memory search. Prefer concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Required for search." + }, + "key": { + "type": "string", + "description": "Exact memory key within the namespace. Required for get." + }, + "limit": { + "type": "integer", + "description": "Maximum number of memories to return", + "minimum": 1, + "default": 10 + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(error_result("memory_search requires tool context")) + } + + async fn execute_with_context( + &self, + context: &ToolContext, + args: serde_json::Value, + ) -> anyhow::Result { + let action = match args.get("action").and_then(|value| value.as_str()) { + Some(action) => action, + None => return Ok(error_result("Missing required parameter: action")), + }; + + let scope_key = match scope_key_from_context(context) { + Ok(scope_key) => scope_key, + Err(result) => return Ok(result), + }; + + let namespace = args.get("namespace").and_then(|value| value.as_str()); + let query = args.get("query").and_then(|value| value.as_str()); + let key = args.get("key").and_then(|value| value.as_str()); + + let payload = match action { + "list" => { + let limit = args + .get("limit") + .and_then(|value| value.as_u64()) + .unwrap_or(10) as usize; + let memories = self.store.list_memories("user", &scope_key, namespace, limit)?; + json!({ + "count": memories.len(), + "memories": memories.into_iter().map(memory_to_json).collect::>() + }) + } + "search" => { + let query = match query { + Some(query) if !query.trim().is_empty() => query, + _ => return Ok(error_result("Missing required parameter: query")), + }; + let limit = args + .get("limit") + .and_then(|value| value.as_u64()) + .unwrap_or(10) as usize; + let memories = self + .store + .search_memories("user", &scope_key, query, namespace, limit)?; + json!({ + "query": query, + "count": memories.len(), + "memories": memories.into_iter().map(memory_to_json).collect::>() + }) + } + "get" => { + let namespace = match namespace { + Some(namespace) => namespace, + None => return Ok(error_result("Missing required parameter: namespace")), + }; + let key = match key { + Some(key) => key, + None => return Ok(error_result("Missing required parameter: key")), + }; + + match self.store.get_memory("user", &scope_key, namespace, key)? { + Some(memory) => memory_to_json(memory), + None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))), + } + } + _ => return Ok(error_result("Unsupported action")), + }; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&payload)?, + error: None, + }) + } + + fn read_only(&self) -> bool { + true + } +} + +fn scope_key_from_context(context: &ToolContext) -> Result { + let channel_name = context + .channel_name + .as_deref() + .ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?; + let sender_id = context + .sender_id + .as_deref() + .ok_or_else(|| error_result("memory_search requires sender_id in tool context"))?; + Ok(format!("{}:{}", channel_name, sender_id)) +} + +fn memory_to_json(memory: MemoryRecord) -> serde_json::Value { + json!({ + "id": memory.id, + "scope_kind": memory.scope_kind, + "scope_key": memory.scope_key, + "namespace": memory.namespace, + "key": memory.memory_key, + "content": memory.content, + "source_type": memory.source_type, + "source_session_id": memory.source_session_id, + "source_message_id": memory.source_message_id, + "source_message_seq": memory.source_message_seq, + "source_channel_name": memory.source_channel_name, + "source_chat_id": memory.source_chat_id, + "created_at": memory.created_at, + "updated_at": memory.updated_at, + }) +} + +fn error_result(message: &str) -> ToolResult { + ToolResult { + success: false, + output: String::new(), + error: Some(message.to_string()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_memory_search_search_and_get() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + store + .put_memory(&crate::storage::MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "preferences".to_string(), + memory_key: "language".to_string(), + content: "User prefers Chinese responses".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-1".to_string()), + source_message_id: Some("msg-1".to_string()), + source_message_seq: Some(1), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-1".to_string()), + }) + .unwrap(); + + let tool = MemorySearchTool::new(store); + let context = ToolContext { + channel_name: Some("feishu".to_string()), + sender_id: Some("user-1".to_string()), + chat_id: Some("chat-1".to_string()), + session_id: Some("feishu:chat-1".to_string()), + message_id: Some("msg-2".to_string()), + message_seq: Some(2), + }; + + let search = tool + .execute_with_context( + &context, + json!({ + "action": "search", + "query": "Chinese language", + "limit": 5 + }), + ) + .await + .unwrap(); + assert!(search.success); + assert!(search.output.contains("language")); + + let get = tool + .execute_with_context( + &context, + json!({ + "action": "get", + "namespace": "preferences", + "key": "language" + }), + ) + .await + .unwrap(); + assert!(get.success); + assert!(get.output.contains("Chinese")); + } + + #[tokio::test] + async fn test_memory_search_is_read_only_and_requires_context() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = MemorySearchTool::new(store); + + assert!(tool.read_only()); + + let result = tool + .execute_with_context(&ToolContext::default(), json!({ "action": "list" })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("channel_name")); + } +} \ No newline at end of file diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 624da9d..384f447 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -5,6 +5,7 @@ pub mod file_read; pub mod file_write; pub mod http_request; pub mod memory_manage; +pub mod memory_search; pub mod registry; pub mod schema; pub mod skill_manage; @@ -18,6 +19,7 @@ pub use file_read::FileReadTool; pub use file_write::FileWriteTool; pub use http_request::HttpRequestTool; pub use memory_manage::MemoryManageTool; +pub use memory_search::MemorySearchTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use skill_manage::{SkillListTool, SkillManageTool};