diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index f3241eb..c5584d2 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -21,7 +21,7 @@ use std::time::Instant; const MAX_TOOL_RESULT_CHARS: usize = 16_000; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; -const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用长期记忆工具。读取记忆时,优先使用 memory_search:当你需要用户长期偏好、稳定事实、历史决策、持续任务上下文时,先 search;已知 namespace/key 时可用 get;需要浏览最近记忆时可用 list。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。"; +const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用长期记忆工具。读取记忆时,优先使用 memory_search:当你需要用户长期偏好、稳定事实、历史决策、持续任务上下文时,先 search;已知 namespace/key 时可用 get;需要浏览最近记忆时可用 list。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。"; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; diff --git a/src/config/mod.rs b/src/config/mod.rs index 58104b0..74484a2 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -134,6 +134,8 @@ pub struct GatewayConfig { pub host: String, #[serde(default = "default_gateway_port")] pub port: u16, + #[serde(default)] + pub show_tool_results: bool, #[serde(default, rename = "session_ttl_hours")] pub session_ttl_hours: Option, #[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")] @@ -167,6 +169,7 @@ impl Default for GatewayConfig { Self { host: default_gateway_host(), port: default_gateway_port(), + show_tool_results: false, session_ttl_hours: None, agent_prompt_reinject_every: default_agent_prompt_reinject_every(), } @@ -395,6 +398,7 @@ mod tests { let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.gateway.host, "0.0.0.0"); assert_eq!(config.gateway.port, 19876); + assert!(!config.gateway.show_tool_results); assert_eq!(config.gateway.agent_prompt_reinject_every, 120); } @@ -428,6 +432,43 @@ mod tests { .unwrap(); let config = Config::load(file.path().to_str().unwrap()).unwrap(); + assert!(!config.gateway.show_tool_results); assert_eq!(config.gateway.agent_prompt_reinject_every, 100); } + + #[test] + fn test_gateway_config_can_enable_tool_results() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + }, + "gateway": { + "show_tool_results": true + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + assert!(config.gateway.show_tool_results); + } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 14b9ba3..76e3f77 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -30,12 +30,14 @@ impl GatewayState { // Session TTL from config (default 4 hours) let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every; + let show_tool_results = config.gateway.show_tool_results; let skills = Arc::new(SkillRuntime::from_config(config.skills.clone())); let session_manager = SessionManager::new( session_ttl_hours, agent_prompt_reinject_every, + show_tool_results, provider_config, skills, )?; diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 44c6102..28e61ca 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -357,6 +357,7 @@ pub struct SessionManager { skills: Arc, store: Arc, agent_prompt_reinject_every: u64, + show_tool_results: bool, } struct SessionManagerInner { @@ -416,6 +417,7 @@ impl SessionManager { pub fn new( session_ttl_hours: u64, agent_prompt_reinject_every: u64, + show_tool_results: bool, provider_config: LLMProviderConfig, skills: Arc, ) -> Result { @@ -439,6 +441,7 @@ impl SessionManager { skills, store, agent_prompt_reinject_every, + show_tool_results, }) } @@ -644,7 +647,10 @@ impl SessionManager { result .emitted_messages .iter() - .filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none()) + .filter(|message| { + (!message.is_assistant_tool_call_message() || live_emitter.is_none()) + && should_display_message_to_user(self.show_tool_results, message) + }) .flat_map(|message| { OutboundMessage::from_chat_message( channel_name, @@ -678,6 +684,18 @@ impl SessionManager { } } +fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool { + if message.role != "tool" { + return true; + } + + show_tool_results + || matches!( + message.tool_state.as_ref().unwrap_or(&crate::bus::message::ToolMessageState::Completed), + crate::bus::message::ToolMessageState::PendingUserAction + ) +} + #[cfg(test)] mod tests { use super::*; @@ -700,6 +718,21 @@ mod tests { } } + #[test] + fn test_should_display_message_to_user_hides_completed_tool_results_by_default() { + let completed = ChatMessage::tool("call-1", "calculator", "2"); + let pending = ChatMessage::tool_with_state( + "call-2", + "bash", + "waiting", + crate::bus::message::ToolMessageState::PendingUserAction, + ); + + assert!(!should_display_message_to_user(false, &completed)); + assert!(should_display_message_to_user(false, &pending)); + assert!(should_display_message_to_user(true, &completed)); + } + #[test] fn test_parse_in_chat_command_aliases() { assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation)); diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 63b4e19..b7da7e9 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -202,6 +202,18 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { } } +fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool { + if message.role != "tool" { + return true; + } + + show_tool_results + || matches!( + message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed), + ToolMessageState::PendingUserAction + ) +} + async fn handle_inbound( state: &Arc, session: &Arc>, @@ -260,7 +272,10 @@ async fn handle_inbound( for outbound in result .emitted_messages .iter() - .filter(|message| !message.is_assistant_tool_call_message()) + .filter(|message| { + !message.is_assistant_tool_call_message() + && should_display_message_to_user(state.config.gateway.show_tool_results, message) + }) .flat_map(ws_outbound_from_chat_message) { let _ = session_guard.send(outbound).await; @@ -405,7 +420,7 @@ async fn handle_inbound( #[cfg(test)] mod tests { - use super::ws_outbound_from_chat_message; + use super::{should_display_message_to_user, ws_outbound_from_chat_message}; use crate::bus::ChatMessage; use crate::bus::message::ToolMessageState; use crate::providers::ToolCall; @@ -461,4 +476,19 @@ mod tests { assert_eq!(outbound.len(), 1); assert!(matches!(outbound[0], WsOutbound::ToolPending { .. })); } + + #[test] + fn test_should_display_message_to_user_hides_completed_tool_results_by_default() { + let completed = ChatMessage::tool("call-1", "calculator", "2"); + let pending = ChatMessage::tool_with_state( + "call-2", + "bash", + "waiting", + ToolMessageState::PendingUserAction, + ); + + assert!(!should_display_message_to_user(false, &completed)); + assert!(should_display_message_to_user(false, &pending)); + assert!(should_display_message_to_user(true, &completed)); + } } diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs index 2e97d3d..223af36 100644 --- a/src/tools/memory_manage.rs +++ b/src/tools/memory_manage.rs @@ -23,7 +23,7 @@ impl Tool for MemoryManageTool { } fn description(&self) -> &str { - "Create, update, or delete long-term user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Prefer memory_search when you only need to retrieve memory. Use memory_manage mainly when you need to write or modify memory records. Memories are scoped to the current channel and sender, and record the originating session/message when available." + "Create, update, or delete long-term user memories stored in SQLite. Supports actions: put, update, delete. Use memory_search for all retrieval, including search, get, and list. Memories are scoped to the current channel and sender, and record the originating session/message when available." } fn parameters_schema(&self) -> serde_json::Value { @@ -32,17 +32,13 @@ impl Tool for MemoryManageTool { "properties": { "action": { "type": "string", - "enum": ["list", "search", "get", "put", "update", "delete"], - "description": "Management action to perform. Prefer memory_search for retrieval-only access. Use 'put' to create or overwrite, 'update' to modify an existing record, and 'delete' to remove one." + "enum": ["put", "update", "delete"], + "description": "Management action to perform. Use 'put' to create or overwrite, 'update' to modify an existing record, and 'delete' to remove one. Use memory_search for retrieval." }, "namespace": { "type": "string", "description": "Optional memory namespace filter, such as profile, preferences, or tasks" }, - "query": { - "type": "string", - "description": "Keyword query for full-text memory search across namespace, memory_key, and content. Prefer concise bilingual keywords when possible, for example Chinese plus English aliases and likely snake_case key terms." - }, "key": { "type": "string", "description": "Exact memory key within the namespace" @@ -50,12 +46,6 @@ impl Tool for MemoryManageTool { "content": { "type": "string", "description": "Memory content for put/update" - }, - "limit": { - "type": "integer", - "description": "Maximum number of memories to return", - "minimum": 1, - "default": 20 } }, "required": ["action"] @@ -82,56 +72,9 @@ impl Tool for MemoryManageTool { }; let namespace = args.get("namespace").and_then(|value| value.as_str()); - let query = args.get("query").and_then(|value| value.as_str()); let key = args.get("key").and_then(|value| value.as_str()); let payload = match action { - "list" => { - let limit = args - .get("limit") - .and_then(|value| value.as_u64()) - .unwrap_or(20) as usize; - let memories = self - .store - .list_memories("user", &scope_key, namespace, limit)?; - json!({ - "count": memories.len(), - "memories": memories.into_iter().map(memory_to_json).collect::>() - }) - } - "search" => { - let query = match query { - Some(query) if !query.trim().is_empty() => query, - _ => return Ok(error_result("Missing required parameter: query")), - }; - let limit = args - .get("limit") - .and_then(|value| value.as_u64()) - .unwrap_or(20) as usize; - let memories = self - .store - .search_memories("user", &scope_key, query, namespace, limit)?; - json!({ - "query": query, - "count": memories.len(), - "memories": memories.into_iter().map(memory_to_json).collect::>() - }) - } - "get" => { - let namespace = match namespace { - Some(namespace) => namespace, - None => return Ok(error_result("Missing required parameter: namespace")), - }; - let key = match key { - Some(key) => key, - None => return Ok(error_result("Missing required parameter: key")), - }; - - match self.store.get_memory("user", &scope_key, namespace, key)? { - Some(memory) => memory_to_json(memory), - None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))), - } - } "put" => { let input = match build_memory_upsert(context, &scope_key, &args, true) { Ok(input) => input, @@ -273,7 +216,7 @@ mod tests { use super::*; #[tokio::test] - async fn test_memory_manage_put_and_get() { + async fn test_memory_manage_put_returns_saved_memory() { let store = Arc::new(SessionStore::in_memory().unwrap()); let tool = MemoryManageTool::new(store); let context = ToolContext { @@ -298,64 +241,8 @@ mod tests { .await .unwrap(); assert!(put.success); - - let get = tool - .execute_with_context( - &context, - json!({ - "action": "get", - "namespace": "profile", - "key": "language" - }), - ) - .await - .unwrap(); - assert!(get.success); - assert!(get.output.contains("Rust")); - assert!(get.output.contains("msg-1")); - } - - #[tokio::test] - async fn test_memory_manage_search() { - let store = Arc::new(SessionStore::in_memory().unwrap()); - let tool = MemoryManageTool::new(store); - let context = ToolContext { - channel_name: Some("feishu".to_string()), - sender_id: Some("user-1".to_string()), - chat_id: Some("chat-1".to_string()), - session_id: Some("feishu:chat-1".to_string()), - message_id: Some("msg-1".to_string()), - message_seq: Some(1), - }; - - let put = tool - .execute_with_context( - &context, - json!({ - "action": "put", - "namespace": "profile", - "key": "editor", - "content": "Prefers rust-analyzer over clippy hints" - }), - ) - .await - .unwrap(); - assert!(put.success); - - let search = tool - .execute_with_context( - &context, - json!({ - "action": "search", - "query": "rust-analyzer", - "limit": 5 - }), - ) - .await - .unwrap(); - assert!(search.success); - assert!(search.output.contains("rust-analyzer")); - assert!(search.output.contains("editor")); + assert!(put.output.contains("Rust")); + assert!(put.output.contains("msg-1")); } #[tokio::test] @@ -376,4 +263,30 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("channel_name")); } + + #[tokio::test] + async fn test_memory_manage_rejects_read_actions() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = MemoryManageTool::new(store); + let context = ToolContext { + channel_name: Some("feishu".to_string()), + sender_id: Some("user-1".to_string()), + ..ToolContext::default() + }; + + let result = tool + .execute_with_context( + &context, + json!({ + "action": "get", + "namespace": "profile", + "key": "language" + }), + ) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("Unsupported action")); + } } \ No newline at end of file diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs index a549c14..4f500df 100644 --- a/src/tools/memory_search.rs +++ b/src/tools/memory_search.rs @@ -23,7 +23,7 @@ impl Tool for MemorySearchTool { } fn description(&self) -> &str { - "Search and read long-term user memories stored in SQLite. Use this tool when you need prior preferences, stable facts, historical decisions, or ongoing task context. This tool is read-only and supports three actions: search for keyword lookup, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage when you only need to retrieve memory." + "Search and read long-term user memories stored in SQLite. Use this tool when you need prior preferences, stable facts, historical decisions, or ongoing task context. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage when you only need to retrieve memory." } fn parameters_schema(&self) -> serde_json::Value { @@ -33,15 +33,19 @@ impl Tool for MemorySearchTool { "action": { "type": "string", "enum": ["search", "get", "list"], - "description": "Retrieval action. Use 'search' for keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories." + "description": "Retrieval action. Use 'search' for multi-keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories." }, "namespace": { "type": "string", "description": "Optional namespace filter, such as profile, preferences, tasks, or decisions. Required for get." }, - "query": { - "type": "string", - "description": "Keyword query for memory search. Prefer concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Required for search." + "queries": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Keyword queries for memory search. Provide multiple concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Search matches any of the provided entries. Required for search.", + "minItems": 1 }, "key": { "type": "string", @@ -78,7 +82,6 @@ impl Tool for MemorySearchTool { }; let namespace = args.get("namespace").and_then(|value| value.as_str()); - let query = args.get("query").and_then(|value| value.as_str()); let key = args.get("key").and_then(|value| value.as_str()); let payload = match action { @@ -94,19 +97,28 @@ impl Tool for MemorySearchTool { }) } "search" => { - let query = match query { - Some(query) if !query.trim().is_empty() => query, - _ => return Ok(error_result("Missing required parameter: query")), + let queries = match args.get("queries").and_then(|value| value.as_array()) { + Some(queries) => queries + .iter() + .filter_map(|value| value.as_str()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) + .collect::>(), + None => return Ok(error_result("Missing required parameter: queries")), }; + if queries.is_empty() { + return Ok(error_result("Missing required parameter: queries")); + } let limit = args .get("limit") .and_then(|value| value.as_u64()) .unwrap_or(10) as usize; let memories = self .store - .search_memories("user", &scope_key, query, namespace, limit)?; + .search_memories_any("user", &scope_key, &queries, namespace, limit)?; json!({ - "query": query, + "queries": queries, "count": memories.len(), "memories": memories.into_iter().map(memory_to_json).collect::>() }) @@ -218,7 +230,7 @@ mod tests { &context, json!({ "action": "search", - "query": "Chinese language", + "queries": ["Chinese", "language"], "limit": 5 }), ) @@ -256,4 +268,22 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("channel_name")); } + + #[tokio::test] + async fn test_memory_search_search_requires_queries() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = MemorySearchTool::new(store); + let context = ToolContext { + channel_name: Some("feishu".to_string()), + sender_id: Some("user-1".to_string()), + ..ToolContext::default() + }; + + let result = tool + .execute_with_context(&context, json!({ "action": "search", "queries": [] })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("queries")); + } } \ No newline at end of file