use std::sync::Arc; use async_trait::async_trait; use serde_json::json; use crate::storage::{MemoryRecord, MemoryRepository}; use crate::tools::traits::{Tool, ToolContext, ToolResult}; use crate::tools::extract_u64; pub struct MemorySearchTool { memories: Arc, } impl MemorySearchTool { pub fn new(memories: Arc) -> Self { Self { memories } } } #[async_trait] impl Tool for MemorySearchTool { fn name(&self) -> &str { "memory_search" } fn description(&self) -> &str { "Search and read long-term user memories from the configured memory repository. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "action": { "type": "string", "enum": ["search", "get", "list"], "description": "Retrieval action. Use 'search' for multi-keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories." }, "namespace": { "type": "string", "description": "Optional namespace filter, such as profile, preferences, tasks, or decisions. Required for get." }, "queries": { "type": "array", "items": { "type": "string" }, "description": "Keyword queries for memory search. Provide multiple concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Search matches any of the provided entries. Required for search.", "minItems": 1 }, "key": { "type": "string", "description": "Exact memory key within the namespace. Required for get." }, "limit": { "type": "integer", "description": "Maximum number of memories to return", "minimum": 1, "default": 10 } }, "required": ["action"] }) } async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { Ok(error_result("memory_search requires tool context")) } async fn execute_with_context( &self, context: &ToolContext, args: serde_json::Value, ) -> anyhow::Result { let action = match args.get("action").and_then(|value| value.as_str()) { Some(action) => action, None => return Ok(error_result("Missing required parameter: action")), }; let scope_key = match scope_key_from_context(context) { Ok(scope_key) => scope_key, Err(result) => return Ok(result), }; let namespace = args.get("namespace").and_then(|value| value.as_str()); let key = args.get("key").and_then(|value| value.as_str()); let payload = match action { "list" => { let limit = extract_u64(&args, "limit").unwrap_or(10) as usize; let memories = self .memories .list_memories("user", &scope_key, namespace, limit)?; json!({ "count": memories.len(), "memories": memories.into_iter().map(memory_to_json).collect::>() }) } "search" => { let queries = match args.get("queries") { Some(value) => { // 支持两种格式:实际数组 或 字符串化的数组 if let Some(arr) = value.as_array() { arr .iter() .filter_map(|v| v.as_str()) .map(str::trim) .filter(|v| !v.is_empty()) .map(ToOwned::to_owned) .collect::>() } else if let Some(s) = value.as_str() { // 尝试解析字符串化的 JSON 数组 match serde_json::from_str::>(s) { Ok(arr) => arr .iter() .filter_map(|v| v.as_str()) .map(str::trim) .filter(|v| !v.is_empty()) .map(ToOwned::to_owned) .collect::>(), Err(_) => { // 如果不是 JSON 数组,尝试按逗号分割 s.split(',') .map(str::trim) .filter(|v| !v.is_empty()) .map(ToOwned::to_owned) .collect::>() } } } else { vec![] } } None => vec![] }; if queries.is_empty() { return Ok(error_result("Missing required parameter: queries")); } let limit = extract_u64(&args, "limit").unwrap_or(10) as usize; let memories = self .memories .search_memories_any("user", &scope_key, &queries, namespace, limit)?; json!({ "queries": queries, "count": memories.len(), "memories": memories.into_iter().map(memory_to_json).collect::>() }) } "get" => { let namespace = match namespace { Some(namespace) => namespace, None => return Ok(error_result("Missing required parameter: namespace")), }; let key = match key { Some(key) => key, None => return Ok(error_result("Missing required parameter: key")), }; match self .memories .get_memory("user", &scope_key, namespace, key)? { Some(memory) => memory_to_json(memory), None => { return Ok(error_result(&format!( "memory '{}.{}' not found", namespace, key ))); } } } _ => return Ok(error_result("Unsupported action")), }; Ok(ToolResult { success: true, output: serde_json::to_string_pretty(&payload)?, error: None, }) } fn read_only(&self) -> bool { true } } fn scope_key_from_context(context: &ToolContext) -> Result { let channel_name = context .channel_name .as_deref() .ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?; Ok(channel_name.to_string()) } fn memory_to_json(memory: MemoryRecord) -> serde_json::Value { json!({ "id": memory.id, "scope_kind": memory.scope_kind, "scope_key": memory.scope_key, "namespace": memory.namespace, "key": memory.memory_key, "content": memory.content, "source_type": memory.source_type, "source_session_id": memory.source_session_id, "source_message_id": memory.source_message_id, "source_message_seq": memory.source_message_seq, "source_channel_name": memory.source_channel_name, "source_chat_id": memory.source_chat_id, "created_at": memory.created_at, "updated_at": memory.updated_at, }) } fn error_result(message: &str) -> ToolResult { ToolResult { success: false, output: String::new(), error: Some(message.to_string()), } } #[cfg(test)] mod tests { use super::*; use crate::storage::SessionStore; const TEST_CHANNEL: &str = "test-channel"; #[tokio::test] async fn test_memory_search_search_and_get() { let store = Arc::new(SessionStore::in_memory().unwrap()); store .put_memory(&crate::storage::MemoryUpsert { scope_kind: "user".to_string(), scope_key: TEST_CHANNEL.to_string(), namespace: "preferences".to_string(), memory_key: "language".to_string(), content: "User prefers Chinese responses".to_string(), source_type: "message".to_string(), source_session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), source_message_id: Some("msg-1".to_string()), source_message_seq: Some(1), source_channel_name: Some(TEST_CHANNEL.to_string()), source_chat_id: Some("chat-1".to_string()), }) .unwrap(); let tool = MemorySearchTool::new(store); let context = ToolContext { channel_name: Some(TEST_CHANNEL.to_string()), chat_id: Some("chat-1".to_string()), session_id: Some(format!("{}:chat-1", TEST_CHANNEL)), message_id: Some("msg-2".to_string()), message_seq: Some(2), ..ToolContext::default() }; let search = tool .execute_with_context( &context, json!({ "action": "search", "queries": ["Chinese", "language"], "limit": 5 }), ) .await .unwrap(); assert!(search.success); assert!(search.output.contains("language")); let get = tool .execute_with_context( &context, json!({ "action": "get", "namespace": "preferences", "key": "language" }), ) .await .unwrap(); assert!(get.success); assert!(get.output.contains("Chinese")); } #[tokio::test] async fn test_memory_search_is_read_only_and_requires_context() { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemorySearchTool::new(store); assert!(tool.read_only()); let result = tool .execute_with_context(&ToolContext::default(), json!({ "action": "list" })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("channel_name")); } #[tokio::test] async fn test_memory_search_search_requires_queries() { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemorySearchTool::new(store); let context = ToolContext { channel_name: Some(TEST_CHANNEL.to_string()), ..ToolContext::default() }; let result = tool .execute_with_context(&context, json!({ "action": "search", "queries": [] })) .await .unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("queries")); } }