feat(gateway): 添加 show_tool_results 配置以控制工具结果显示
feat(memory): 更新 MemoryManageTool 和 MemorySearchTool 描述,优化参数处理
This commit is contained in:
parent
0331774466
commit
65abf017a1
@ -21,7 +21,7 @@ use std::time::Instant;
|
|||||||
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
|
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
|
||||||
/// Minimum characters to keep when truncating
|
/// Minimum characters to keep when truncating
|
||||||
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
||||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用长期记忆工具。读取记忆时,优先使用 memory_search:当你需要用户长期偏好、稳定事实、历史决策、持续任务上下文时,先 search;已知 namespace/key 时可用 get;需要浏览最近记忆时可用 list。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。";
|
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用长期记忆工具。读取记忆时,优先使用 memory_search:当你需要用户长期偏好、稳定事实、历史决策、持续任务上下文时,先 search;已知 namespace/key 时可用 get;需要浏览最近记忆时可用 list。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。";
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||||
|
|
||||||
|
|||||||
@ -134,6 +134,8 @@ pub struct GatewayConfig {
|
|||||||
pub host: String,
|
pub host: String,
|
||||||
#[serde(default = "default_gateway_port")]
|
#[serde(default = "default_gateway_port")]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
|
#[serde(default)]
|
||||||
|
pub show_tool_results: bool,
|
||||||
#[serde(default, rename = "session_ttl_hours")]
|
#[serde(default, rename = "session_ttl_hours")]
|
||||||
pub session_ttl_hours: Option<u64>,
|
pub session_ttl_hours: Option<u64>,
|
||||||
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
||||||
@ -167,6 +169,7 @@ impl Default for GatewayConfig {
|
|||||||
Self {
|
Self {
|
||||||
host: default_gateway_host(),
|
host: default_gateway_host(),
|
||||||
port: default_gateway_port(),
|
port: default_gateway_port(),
|
||||||
|
show_tool_results: false,
|
||||||
session_ttl_hours: None,
|
session_ttl_hours: None,
|
||||||
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
||||||
}
|
}
|
||||||
@ -395,6 +398,7 @@ mod tests {
|
|||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
assert_eq!(config.gateway.host, "0.0.0.0");
|
assert_eq!(config.gateway.host, "0.0.0.0");
|
||||||
assert_eq!(config.gateway.port, 19876);
|
assert_eq!(config.gateway.port, 19876);
|
||||||
|
assert!(!config.gateway.show_tool_results);
|
||||||
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
|
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -428,6 +432,43 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
assert!(!config.gateway.show_tool_results);
|
||||||
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
|
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_gateway_config_can_enable_tool_results() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gateway": {
|
||||||
|
"show_tool_results": true
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
assert!(config.gateway.show_tool_results);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,12 +30,14 @@ impl GatewayState {
|
|||||||
// Session TTL from config (default 4 hours)
|
// Session TTL from config (default 4 hours)
|
||||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||||
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
||||||
|
let show_tool_results = config.gateway.show_tool_results;
|
||||||
|
|
||||||
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
|
show_tool_results,
|
||||||
provider_config,
|
provider_config,
|
||||||
skills,
|
skills,
|
||||||
)?;
|
)?;
|
||||||
|
|||||||
@ -357,6 +357,7 @@ pub struct SessionManager {
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
|
show_tool_results: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SessionManagerInner {
|
struct SessionManagerInner {
|
||||||
@ -416,6 +417,7 @@ impl SessionManager {
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
session_ttl_hours: u64,
|
session_ttl_hours: u64,
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
|
show_tool_results: bool,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
@ -439,6 +441,7 @@ impl SessionManager {
|
|||||||
skills,
|
skills,
|
||||||
store,
|
store,
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
|
show_tool_results,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -644,7 +647,10 @@ impl SessionManager {
|
|||||||
result
|
result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
|
.filter(|message| {
|
||||||
|
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
|
||||||
|
&& should_display_message_to_user(self.show_tool_results, message)
|
||||||
|
})
|
||||||
.flat_map(|message| {
|
.flat_map(|message| {
|
||||||
OutboundMessage::from_chat_message(
|
OutboundMessage::from_chat_message(
|
||||||
channel_name,
|
channel_name,
|
||||||
@ -678,6 +684,18 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
||||||
|
if message.role != "tool" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
show_tool_results
|
||||||
|
|| matches!(
|
||||||
|
message.tool_state.as_ref().unwrap_or(&crate::bus::message::ToolMessageState::Completed),
|
||||||
|
crate::bus::message::ToolMessageState::PendingUserAction
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -700,6 +718,21 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
||||||
|
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
|
let pending = ChatMessage::tool_with_state(
|
||||||
|
"call-2",
|
||||||
|
"bash",
|
||||||
|
"waiting",
|
||||||
|
crate::bus::message::ToolMessageState::PendingUserAction,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(!should_display_message_to_user(false, &completed));
|
||||||
|
assert!(should_display_message_to_user(false, &pending));
|
||||||
|
assert!(should_display_message_to_user(true, &completed));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_in_chat_command_aliases() {
|
fn test_parse_in_chat_command_aliases() {
|
||||||
assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation));
|
assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation));
|
||||||
|
|||||||
@ -202,6 +202,18 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
||||||
|
if message.role != "tool" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
show_tool_results
|
||||||
|
|| matches!(
|
||||||
|
message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed),
|
||||||
|
ToolMessageState::PendingUserAction
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_inbound(
|
async fn handle_inbound(
|
||||||
state: &Arc<GatewayState>,
|
state: &Arc<GatewayState>,
|
||||||
session: &Arc<Mutex<Session>>,
|
session: &Arc<Mutex<Session>>,
|
||||||
@ -260,7 +272,10 @@ async fn handle_inbound(
|
|||||||
for outbound in result
|
for outbound in result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|message| !message.is_assistant_tool_call_message())
|
.filter(|message| {
|
||||||
|
!message.is_assistant_tool_call_message()
|
||||||
|
&& should_display_message_to_user(state.config.gateway.show_tool_results, message)
|
||||||
|
})
|
||||||
.flat_map(ws_outbound_from_chat_message)
|
.flat_map(ws_outbound_from_chat_message)
|
||||||
{
|
{
|
||||||
let _ = session_guard.send(outbound).await;
|
let _ = session_guard.send(outbound).await;
|
||||||
@ -405,7 +420,7 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::ws_outbound_from_chat_message;
|
use super::{should_display_message_to_user, ws_outbound_from_chat_message};
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::bus::message::ToolMessageState;
|
use crate::bus::message::ToolMessageState;
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
@ -461,4 +476,19 @@ mod tests {
|
|||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
||||||
|
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
|
let pending = ChatMessage::tool_with_state(
|
||||||
|
"call-2",
|
||||||
|
"bash",
|
||||||
|
"waiting",
|
||||||
|
ToolMessageState::PendingUserAction,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(!should_display_message_to_user(false, &completed));
|
||||||
|
assert!(should_display_message_to_user(false, &pending));
|
||||||
|
assert!(should_display_message_to_user(true, &completed));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Create, update, or delete long-term user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Prefer memory_search when you only need to retrieve memory. Use memory_manage mainly when you need to write or modify memory records. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
"Create, update, or delete long-term user memories stored in SQLite. Supports actions: put, update, delete. Use memory_search for all retrieval, including search, get, and list. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
@ -32,17 +32,13 @@ impl Tool for MemoryManageTool {
|
|||||||
"properties": {
|
"properties": {
|
||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["list", "search", "get", "put", "update", "delete"],
|
"enum": ["put", "update", "delete"],
|
||||||
"description": "Management action to perform. Prefer memory_search for retrieval-only access. Use 'put' to create or overwrite, 'update' to modify an existing record, and 'delete' to remove one."
|
"description": "Management action to perform. Use 'put' to create or overwrite, 'update' to modify an existing record, and 'delete' to remove one. Use memory_search for retrieval."
|
||||||
},
|
},
|
||||||
"namespace": {
|
"namespace": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Optional memory namespace filter, such as profile, preferences, or tasks"
|
"description": "Optional memory namespace filter, such as profile, preferences, or tasks"
|
||||||
},
|
},
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Keyword query for full-text memory search across namespace, memory_key, and content. Prefer concise bilingual keywords when possible, for example Chinese plus English aliases and likely snake_case key terms."
|
|
||||||
},
|
|
||||||
"key": {
|
"key": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Exact memory key within the namespace"
|
"description": "Exact memory key within the namespace"
|
||||||
@ -50,12 +46,6 @@ impl Tool for MemoryManageTool {
|
|||||||
"content": {
|
"content": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Memory content for put/update"
|
"description": "Memory content for put/update"
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Maximum number of memories to return",
|
|
||||||
"minimum": 1,
|
|
||||||
"default": 20
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["action"]
|
"required": ["action"]
|
||||||
@ -82,56 +72,9 @@ impl Tool for MemoryManageTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
||||||
let query = args.get("query").and_then(|value| value.as_str());
|
|
||||||
let key = args.get("key").and_then(|value| value.as_str());
|
let key = args.get("key").and_then(|value| value.as_str());
|
||||||
|
|
||||||
let payload = match action {
|
let payload = match action {
|
||||||
"list" => {
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|value| value.as_u64())
|
|
||||||
.unwrap_or(20) as usize;
|
|
||||||
let memories = self
|
|
||||||
.store
|
|
||||||
.list_memories("user", &scope_key, namespace, limit)?;
|
|
||||||
json!({
|
|
||||||
"count": memories.len(),
|
|
||||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
"search" => {
|
|
||||||
let query = match query {
|
|
||||||
Some(query) if !query.trim().is_empty() => query,
|
|
||||||
_ => return Ok(error_result("Missing required parameter: query")),
|
|
||||||
};
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|value| value.as_u64())
|
|
||||||
.unwrap_or(20) as usize;
|
|
||||||
let memories = self
|
|
||||||
.store
|
|
||||||
.search_memories("user", &scope_key, query, namespace, limit)?;
|
|
||||||
json!({
|
|
||||||
"query": query,
|
|
||||||
"count": memories.len(),
|
|
||||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
"get" => {
|
|
||||||
let namespace = match namespace {
|
|
||||||
Some(namespace) => namespace,
|
|
||||||
None => return Ok(error_result("Missing required parameter: namespace")),
|
|
||||||
};
|
|
||||||
let key = match key {
|
|
||||||
Some(key) => key,
|
|
||||||
None => return Ok(error_result("Missing required parameter: key")),
|
|
||||||
};
|
|
||||||
|
|
||||||
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
|
||||||
Some(memory) => memory_to_json(memory),
|
|
||||||
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"put" => {
|
"put" => {
|
||||||
let input = match build_memory_upsert(context, &scope_key, &args, true) {
|
let input = match build_memory_upsert(context, &scope_key, &args, true) {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
@ -273,7 +216,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_memory_manage_put_and_get() {
|
async fn test_memory_manage_put_returns_saved_memory() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemoryManageTool::new(store);
|
let tool = MemoryManageTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
@ -298,64 +241,8 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(put.success);
|
assert!(put.success);
|
||||||
|
assert!(put.output.contains("Rust"));
|
||||||
let get = tool
|
assert!(put.output.contains("msg-1"));
|
||||||
.execute_with_context(
|
|
||||||
&context,
|
|
||||||
json!({
|
|
||||||
"action": "get",
|
|
||||||
"namespace": "profile",
|
|
||||||
"key": "language"
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(get.success);
|
|
||||||
assert!(get.output.contains("Rust"));
|
|
||||||
assert!(get.output.contains("msg-1"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_memory_manage_search() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let tool = MemoryManageTool::new(store);
|
|
||||||
let context = ToolContext {
|
|
||||||
channel_name: Some("feishu".to_string()),
|
|
||||||
sender_id: Some("user-1".to_string()),
|
|
||||||
chat_id: Some("chat-1".to_string()),
|
|
||||||
session_id: Some("feishu:chat-1".to_string()),
|
|
||||||
message_id: Some("msg-1".to_string()),
|
|
||||||
message_seq: Some(1),
|
|
||||||
};
|
|
||||||
|
|
||||||
let put = tool
|
|
||||||
.execute_with_context(
|
|
||||||
&context,
|
|
||||||
json!({
|
|
||||||
"action": "put",
|
|
||||||
"namespace": "profile",
|
|
||||||
"key": "editor",
|
|
||||||
"content": "Prefers rust-analyzer over clippy hints"
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(put.success);
|
|
||||||
|
|
||||||
let search = tool
|
|
||||||
.execute_with_context(
|
|
||||||
&context,
|
|
||||||
json!({
|
|
||||||
"action": "search",
|
|
||||||
"query": "rust-analyzer",
|
|
||||||
"limit": 5
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(search.success);
|
|
||||||
assert!(search.output.contains("rust-analyzer"));
|
|
||||||
assert!(search.output.contains("editor"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -376,4 +263,30 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("channel_name"));
|
assert!(result.error.unwrap().contains("channel_name"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_memory_manage_rejects_read_actions() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = MemoryManageTool::new(store);
|
||||||
|
let context = ToolContext {
|
||||||
|
channel_name: Some("feishu".to_string()),
|
||||||
|
sender_id: Some("user-1".to_string()),
|
||||||
|
..ToolContext::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute_with_context(
|
||||||
|
&context,
|
||||||
|
json!({
|
||||||
|
"action": "get",
|
||||||
|
"namespace": "profile",
|
||||||
|
"key": "language"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("Unsupported action"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@ -23,7 +23,7 @@ impl Tool for MemorySearchTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Search and read long-term user memories stored in SQLite. Use this tool when you need prior preferences, stable facts, historical decisions, or ongoing task context. This tool is read-only and supports three actions: search for keyword lookup, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage when you only need to retrieve memory."
|
"Search and read long-term user memories stored in SQLite. Use this tool when you need prior preferences, stable facts, historical decisions, or ongoing task context. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage when you only need to retrieve memory."
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
@ -33,15 +33,19 @@ impl Tool for MemorySearchTool {
|
|||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["search", "get", "list"],
|
"enum": ["search", "get", "list"],
|
||||||
"description": "Retrieval action. Use 'search' for keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories."
|
"description": "Retrieval action. Use 'search' for multi-keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories."
|
||||||
},
|
},
|
||||||
"namespace": {
|
"namespace": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Optional namespace filter, such as profile, preferences, tasks, or decisions. Required for get."
|
"description": "Optional namespace filter, such as profile, preferences, tasks, or decisions. Required for get."
|
||||||
},
|
},
|
||||||
"query": {
|
"queries": {
|
||||||
"type": "string",
|
"type": "array",
|
||||||
"description": "Keyword query for memory search. Prefer concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Required for search."
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "Keyword queries for memory search. Provide multiple concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Search matches any of the provided entries. Required for search.",
|
||||||
|
"minItems": 1
|
||||||
},
|
},
|
||||||
"key": {
|
"key": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@ -78,7 +82,6 @@ impl Tool for MemorySearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
||||||
let query = args.get("query").and_then(|value| value.as_str());
|
|
||||||
let key = args.get("key").and_then(|value| value.as_str());
|
let key = args.get("key").and_then(|value| value.as_str());
|
||||||
|
|
||||||
let payload = match action {
|
let payload = match action {
|
||||||
@ -94,19 +97,28 @@ impl Tool for MemorySearchTool {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
"search" => {
|
"search" => {
|
||||||
let query = match query {
|
let queries = match args.get("queries").and_then(|value| value.as_array()) {
|
||||||
Some(query) if !query.trim().is_empty() => query,
|
Some(queries) => queries
|
||||||
_ => return Ok(error_result("Missing required parameter: query")),
|
.iter()
|
||||||
|
.filter_map(|value| value.as_str())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => return Ok(error_result("Missing required parameter: queries")),
|
||||||
};
|
};
|
||||||
|
if queries.is_empty() {
|
||||||
|
return Ok(error_result("Missing required parameter: queries"));
|
||||||
|
}
|
||||||
let limit = args
|
let limit = args
|
||||||
.get("limit")
|
.get("limit")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(|value| value.as_u64())
|
||||||
.unwrap_or(10) as usize;
|
.unwrap_or(10) as usize;
|
||||||
let memories = self
|
let memories = self
|
||||||
.store
|
.store
|
||||||
.search_memories("user", &scope_key, query, namespace, limit)?;
|
.search_memories_any("user", &scope_key, &queries, namespace, limit)?;
|
||||||
json!({
|
json!({
|
||||||
"query": query,
|
"queries": queries,
|
||||||
"count": memories.len(),
|
"count": memories.len(),
|
||||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
@ -218,7 +230,7 @@ mod tests {
|
|||||||
&context,
|
&context,
|
||||||
json!({
|
json!({
|
||||||
"action": "search",
|
"action": "search",
|
||||||
"query": "Chinese language",
|
"queries": ["Chinese", "language"],
|
||||||
"limit": 5
|
"limit": 5
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
@ -256,4 +268,22 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("channel_name"));
|
assert!(result.error.unwrap().contains("channel_name"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_memory_search_search_requires_queries() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = MemorySearchTool::new(store);
|
||||||
|
let context = ToolContext {
|
||||||
|
channel_name: Some("feishu".to_string()),
|
||||||
|
sender_id: Some("user-1".to_string()),
|
||||||
|
..ToolContext::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute_with_context(&context, json!({ "action": "search", "queries": [] }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("queries"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user