From d35e89a44c17c1ae21488ad30733ddb6df101569 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Wed, 22 Apr 2026 06:57:22 +0800 Subject: [PATCH] # --- src/agent/agent_loop.rs | 13 +- src/bus/message.rs | 213 +++++++++++++++++ src/client/mod.rs | 10 + src/gateway/mod.rs | 17 +- src/gateway/session.rs | 68 ++++-- src/gateway/ws.rs | 64 ++++- src/protocol.rs | 17 ++ src/storage/mod.rs | 443 +++++++++++++++++++++++++++++++++++ src/tools/memory_manage.rs | 313 +++++++++++++++++++++++++ src/tools/mod.rs | 4 +- src/tools/traits.rs | 18 ++ tests/test_request_format.rs | 52 ++++ 12 files changed, 1195 insertions(+), 37 deletions(-) create mode 100644 src/tools/memory_manage.rs diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index b249714..c790c5d 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -7,7 +7,7 @@ use crate::observability::{ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; use crate::skills::SkillRuntime; use crate::storage::SessionStore; -use crate::tools::ToolRegistry; +use crate::tools::{ToolContext, ToolRegistry}; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; @@ -225,6 +225,7 @@ pub struct AgentLoop { skills: Arc, skill_event_store: Option>, skill_event_session_id: Option, + tool_context: ToolContext, observer: Option>, max_iterations: usize, } @@ -247,6 +248,7 @@ impl AgentLoop { skills: Arc::new(SkillRuntime::default()), skill_event_store: None, skill_event_session_id: None, + tool_context: ToolContext::default(), observer: None, max_iterations, }) @@ -263,6 +265,7 @@ impl AgentLoop { skills: Arc::new(SkillRuntime::default()), skill_event_store: None, skill_event_session_id: None, + tool_context: ToolContext::default(), observer: None, max_iterations, }) @@ -283,6 +286,7 @@ impl AgentLoop { skills, skill_event_store: None, skill_event_session_id: None, + tool_context: ToolContext::default(), observer: None, max_iterations, }) @@ -294,6 +298,11 @@ impl AgentLoop { self } + pub fn with_tool_context(mut self, context: ToolContext) -> Self { + self.tool_context = context; + self + } + /// Set an observer for tracking events. pub fn with_observer(mut self, observer: Arc) -> Self { self.observer = Some(observer); @@ -622,7 +631,7 @@ impl AgentLoop { } }; - match tool.execute(tool_call.arguments.clone()).await { + match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await { Ok(result) => { if result.success { ToolExecutionOutcome::success(result.output) diff --git a/src/bus/message.rs b/src/bus/message.rs index 1b2386c..959e384 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -191,12 +191,165 @@ pub struct OutboundMessage { pub reply_to: Option, pub media: Vec, pub metadata: HashMap, + pub event_kind: OutboundEventKind, + pub role: String, + pub tool_call_id: Option, + pub tool_name: Option, + pub tool_arguments: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum OutboundEventKind { + AssistantResponse, + ToolCall, + ToolResult, } impl OutboundMessage { pub fn is_stream_delta(&self) -> bool { self.metadata.get("_stream_delta").is_some() } + + pub fn assistant( + channel: impl Into, + chat_id: impl Into, + content: impl Into, + reply_to: Option, + metadata: HashMap, + ) -> Self { + Self { + channel: channel.into(), + chat_id: chat_id.into(), + content: content.into(), + reply_to, + media: Vec::new(), + metadata, + event_kind: OutboundEventKind::AssistantResponse, + role: "assistant".to_string(), + tool_call_id: None, + tool_name: None, + tool_arguments: None, + } + } + + pub fn tool_call( + channel: impl Into, + chat_id: impl Into, + message_id: impl Into, + tool_name: impl Into, + tool_arguments: serde_json::Value, + reply_to: Option, + metadata: HashMap, + ) -> Self { + let tool_name = tool_name.into(); + let content = format_tool_call_content(&tool_name, &tool_arguments); + Self { + channel: channel.into(), + chat_id: chat_id.into(), + content, + reply_to, + media: Vec::new(), + metadata, + event_kind: OutboundEventKind::ToolCall, + role: "assistant".to_string(), + tool_call_id: Some(message_id.into()), + tool_name: Some(tool_name), + tool_arguments: Some(tool_arguments), + } + } + + pub fn tool_result( + channel: impl Into, + chat_id: impl Into, + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + reply_to: Option, + metadata: HashMap, + ) -> Self { + let tool_name = tool_name.into(); + let raw_content = content.into(); + let content = format_tool_result_content(&tool_name, &raw_content); + Self { + channel: channel.into(), + chat_id: chat_id.into(), + content, + reply_to, + media: Vec::new(), + metadata, + event_kind: OutboundEventKind::ToolResult, + role: "tool".to_string(), + tool_call_id: Some(tool_call_id.into()), + tool_name: Some(tool_name), + tool_arguments: None, + } + } + + pub fn from_chat_message( + channel: &str, + chat_id: &str, + reply_to: Option, + metadata: &HashMap, + message: &ChatMessage, + ) -> Vec { + match message.role.as_str() { + "assistant" => { + if let Some(tool_calls) = &message.tool_calls { + tool_calls + .iter() + .map(|tool_call| { + Self::tool_call( + channel.to_string(), + chat_id.to_string(), + tool_call.id.clone(), + tool_call.name.clone(), + tool_call.arguments.clone(), + reply_to.clone(), + metadata.clone(), + ) + }) + .collect() + } else { + vec![Self::assistant( + channel.to_string(), + chat_id.to_string(), + message.content.clone(), + reply_to, + metadata.clone(), + )] + } + } + "tool" => vec![Self::tool_result( + channel.to_string(), + chat_id.to_string(), + message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()), + message.tool_name.clone().unwrap_or_else(|| "tool".to_string()), + message.content.clone(), + reply_to, + metadata.clone(), + )], + _ => Vec::new(), + } + } +} + +fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String { + format!( + "调用工具: {}\n\n输入参数:\n{}", + tool_name, + format_json_value(tool_arguments), + ) +} + +fn format_tool_result_content(tool_name: &str, content: &str) -> String { + format!("工具结果: {}\n\n{}", tool_name, content) +} + +fn format_json_value(value: &serde_json::Value) -> String { + match value { + serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(), + other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), + } } // ============================================================================ @@ -209,3 +362,63 @@ fn current_timestamp() -> i64 { .unwrap() .as_millis() as i64 } + +#[cfg(test)] +mod tests { + use super::{ChatMessage, OutboundEventKind, OutboundMessage}; + use crate::providers::ToolCall; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn test_from_chat_message_expands_tool_calls() { + let message = ChatMessage::assistant_with_tool_calls( + "", + vec![ + ToolCall { + id: "call-1".to_string(), + name: "calculator".to_string(), + arguments: json!({"expression": "1 + 1"}), + }, + ToolCall { + id: "call-2".to_string(), + name: "file_read".to_string(), + arguments: json!({"path": "README.md"}), + }, + ], + ); + + let outbound = OutboundMessage::from_chat_message( + "feishu", + "chat-1", + None, + &HashMap::new(), + &message, + ); + + assert_eq!(outbound.len(), 2); + assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall); + assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator")); + assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1"); + assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read")); + } + + #[test] + fn test_from_chat_message_maps_tool_result() { + let message = ChatMessage::tool("call-9", "calculator", "2"); + + let outbound = OutboundMessage::from_chat_message( + "feishu", + "chat-1", + None, + &HashMap::new(), + &message, + ); + + assert_eq!(outbound.len(), 1); + assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult); + assert_eq!(outbound[0].tool_call_id.as_deref(), Some("call-9")); + assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator")); + assert!(outbound[0].content.contains("工具结果: calculator")); + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index a3178f9..78c9ea8 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -40,6 +40,10 @@ fn parse_message(raw: &str) -> Result { serde_json::from_str(raw) } +fn format_json(value: &serde_json::Value) -> String { + serde_json::to_string_pretty(value).unwrap_or_else(|_| value.to_string()) +} + pub async fn run(gateway_url: &str) -> Result<(), Box> { let (ws_stream, _) = connect_async(gateway_url).await?; tracing::info!(url = %gateway_url, "Connected to gateway"); @@ -63,6 +67,12 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { WsOutbound::AssistantResponse { content, .. } => { input.write_response(&content).await?; } + WsOutbound::ToolCall { tool_name, arguments, .. } => { + input.write_output(&format!("Tool call: {}\n{}\n", tool_name, format_json(&arguments))).await?; + } + WsOutbound::ToolResult { tool_name, content, .. } => { + input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?; + } WsOutbound::Error { message, .. } => { input.write_output(&format!("Error: {}", message)).await?; } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 6a2a83f..5ab681d 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -81,20 +81,15 @@ impl GatewayState { &inbound.content, inbound.media, ).await { - Ok(response_content) => { + Ok(outbound_messages) => { // Forward channel-specific metadata from inbound to outbound. // This allows channels to propagate context (e.g. feishu message_id for reaction cleanup) // without gateway needing channel-specific code. - let outbound = crate::bus::OutboundMessage { - channel: inbound.channel.clone(), - chat_id: inbound.chat_id.clone(), - content: response_content, - reply_to: None, - media: vec![], - metadata: inbound.forwarded_metadata, - }; - if let Err(e) = bus_for_inbound.publish_outbound(outbound).await { - tracing::error!(error = %e, "Failed to publish outbound"); + for mut outbound in outbound_messages { + outbound.metadata.extend(inbound.forwarded_metadata.clone()); + if let Err(e) = bus_for_inbound.publish_outbound(outbound).await { + tracing::error!(error = %e, "Failed to publish outbound"); + } } } Err(e) => { diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 0053e83..c4ad5b1 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; -use crate::bus::ChatMessage; +use crate::bus::{ChatMessage, OutboundMessage}; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::protocol::WsOutbound; @@ -11,7 +11,8 @@ use crate::skills::SkillRuntime; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, - HttpRequestTool, SkillListTool, SkillManageTool, ToolRegistry, WebFetchTool, + HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry, + WebFetchTool, }; /// Session 按 channel 隔离,每个 channel 一个 Session @@ -197,13 +198,30 @@ impl Session { } /// 创建一个临时的 AgentLoop 实例来处理消息 - pub fn create_agent(&self, chat_id: &str) -> Result { + pub fn create_agent( + &self, + chat_id: &str, + sender_id: Option<&str>, + message_id: Option<&str>, + ) -> Result { + let session_id = self.persistent_session_id(chat_id); AgentLoop::with_tools_and_skills( self.provider_config.clone(), self.tools.clone(), self.skills.clone(), ) - .map(|agent| agent.with_skill_event_store(self.store.clone(), self.persistent_session_id(chat_id))) + .map(|agent| { + agent + .with_skill_event_store(self.store.clone(), session_id.clone()) + .with_tool_context(ToolContext { + channel_name: Some(self.channel_name.clone()), + sender_id: sender_id.map(str::to_string), + chat_id: Some(chat_id.to_string()), + session_id: Some(session_id), + message_id: message_id.map(str::to_string), + message_seq: None, + }) + }) } } @@ -223,12 +241,13 @@ struct SessionManagerInner { session_ttl: Duration, } -fn default_tools(skills: Arc) -> ToolRegistry { +fn default_tools(skills: Arc, store: Arc) -> ToolRegistry { let mut registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); registry.register(FileReadTool::new()); registry.register(FileWriteTool::new()); registry.register(FileEditTool::new()); + registry.register(MemoryManageTool::new(store)); registry.register(SkillListTool::new(skills.clone())); registry.register(SkillManageTool::new(skills)); registry.register(BashTool::new()); @@ -290,7 +309,7 @@ impl SessionManager { session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), provider_config, - tools: Arc::new(default_tools(skills.clone())), + tools: Arc::new(default_tools(skills.clone(), store.clone())), skills, store, }) @@ -414,11 +433,11 @@ impl SessionManager { pub async fn handle_message( &self, channel_name: &str, - _sender_id: &str, + sender_id: &str, chat_id: &str, content: &str, media: Vec, - ) -> Result { + ) -> Result, AgentError> { #[cfg(debug_assertions)] { tracing::debug!( @@ -453,7 +472,13 @@ impl SessionManager { session_guard.ensure_chat_loaded(chat_id)?; if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? { - return Ok(command_response); + return Ok(vec![OutboundMessage::assistant( + channel_name.to_string(), + chat_id.to_string(), + command_response, + None, + HashMap::new(), + )]); } // 添加用户消息到历史 @@ -463,6 +488,7 @@ impl SessionManager { tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media"); } let user_message = session_guard.create_user_message(content, media_refs); + let user_message_id = user_message.id.clone(); session_guard.append_persisted_message(chat_id, user_message)?; // 获取完整历史 @@ -476,24 +502,36 @@ impl SessionManager { session_guard.record_skill_offer(chat_id)?; // 创建 agent 并处理 - let agent = session_guard.create_agent(chat_id)?; + let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?; let result = agent.process(history).await?; // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; - result.final_response + result + .emitted_messages + .iter() + .flat_map(|message| { + OutboundMessage::from_chat_message( + channel_name, + chat_id, + None, + &HashMap::new(), + message, + ) + }) + .collect::>() }; #[cfg(debug_assertions)] tracing::debug!( channel = %channel_name, chat_id = %chat_id, - response_len = response.content.len(), - "Agent response received" + outbound_count = response.len(), + "Agent response sequence received" ); - Ok(response.content) + Ok(response) } /// 清除指定 session 的所有历史 @@ -541,7 +579,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools(skills.clone())); + let tools = Arc::new(default_tools(skills.clone(), store.clone())); let mut session = Session::new( "feishu".to_string(), test_provider_config(), diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 41702d8..9d37879 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -4,6 +4,7 @@ use axum::extract::State; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; +use crate::bus::ChatMessage; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use super::{GatewayState, session::{Session, handle_in_chat_command}}; @@ -140,6 +141,52 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { } } +fn format_tool_arguments(arguments: &serde_json::Value) -> String { + serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string()) +} + +fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + if let Some(tool_calls) = &message.tool_calls { + tool_calls + .iter() + .map(|tool_call| WsOutbound::ToolCall { + id: message.id.clone(), + tool_call_id: tool_call.id.clone(), + tool_name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + content: format!( + "调用工具: {}\n\n输入参数:\n{}", + tool_call.name, + format_tool_arguments(&tool_call.arguments), + ), + role: message.role.clone(), + }) + .collect() + } else { + vec![WsOutbound::AssistantResponse { + id: message.id.clone(), + content: message.content.clone(), + role: message.role.clone(), + }] + } + } + "tool" => vec![WsOutbound::ToolResult { + id: message.id.clone(), + tool_call_id: message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()), + tool_name: message.tool_name.clone().unwrap_or_else(|| "tool".to_string()), + content: format!( + "工具结果: {}\n\n{}", + message.tool_name.clone().unwrap_or_else(|| "tool".to_string()), + message.content, + ), + role: message.role.clone(), + }], + _ => Vec::new(), + } +} + async fn handle_inbound( state: &Arc, session: &Arc>, @@ -166,6 +213,7 @@ async fn handle_inbound( } let user_message = session_guard.create_user_message(&content, Vec::new()); + let user_message_id = user_message.id.clone(); session_guard.append_persisted_message(&chat_id, user_message)?; let raw_history = session_guard.get_or_create_history(&chat_id).clone(); @@ -183,17 +231,17 @@ async fn handle_inbound( session_guard.record_skill_offer(&chat_id)?; - let agent = session_guard.create_agent(&chat_id)?; + let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?; match agent.process(history).await { Ok(result) => { session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?; - let _ = session_guard - .send(WsOutbound::AssistantResponse { - id: result.final_response.id, - content: result.final_response.content, - role: result.final_response.role, - }) - .await; + for outbound in result + .emitted_messages + .iter() + .flat_map(ws_outbound_from_chat_message) + { + let _ = session_guard.send(outbound).await; + } } Err(error) => { tracing::error!(chat_id = %chat_id, error = %error, "Agent process error"); diff --git a/src/protocol.rs b/src/protocol.rs index b301c0e..17e2bf9 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -71,6 +71,23 @@ pub enum WsInbound { pub enum WsOutbound { #[serde(rename = "assistant_response")] AssistantResponse { id: String, content: String, role: String }, + #[serde(rename = "tool_call")] + ToolCall { + id: String, + tool_call_id: String, + tool_name: String, + arguments: serde_json::Value, + content: String, + role: String, + }, + #[serde(rename = "tool_result")] + ToolResult { + id: String, + tool_call_id: String, + tool_name: String, + content: String, + role: String, + }, #[serde(rename = "error")] Error { code: String, message: String }, #[serde(rename = "session_established")] diff --git a/src/storage/mod.rs b/src/storage/mod.rs index d5ec478..32b9802 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -42,6 +42,39 @@ pub struct SessionRecord { pub reset_cutoff_seq: i64, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryRecord { + pub id: String, + pub scope_kind: String, + pub scope_key: String, + pub namespace: String, + pub memory_key: String, + pub content: String, + pub source_type: String, + pub source_session_id: Option, + pub source_message_id: Option, + pub source_message_seq: Option, + pub source_channel_name: Option, + pub source_chat_id: Option, + pub created_at: i64, + pub updated_at: i64, +} + +#[derive(Debug, Clone)] +pub struct MemoryUpsert { + pub scope_kind: String, + pub scope_key: String, + pub namespace: String, + pub memory_key: String, + pub content: String, + pub source_type: String, + pub source_session_id: Option, + pub source_message_id: Option, + pub source_message_seq: Option, + pub source_channel_name: Option, + pub source_chat_id: Option, +} + #[derive(Clone)] pub struct SessionStore { conn: Arc>, @@ -122,6 +155,56 @@ impl SessionStore { ON skill_events(session_id, created_at DESC); CREATE INDEX IF NOT EXISTS idx_skill_events_type_created ON skill_events(event_type, created_at DESC); + + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + scope_kind TEXT NOT NULL, + scope_key TEXT NOT NULL, + namespace TEXT NOT NULL, + memory_key TEXT NOT NULL, + content TEXT NOT NULL, + source_type TEXT NOT NULL, + source_session_id TEXT, + source_message_id TEXT, + source_message_seq INTEGER, + source_channel_name TEXT, + source_chat_id TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + UNIQUE(scope_kind, scope_key, namespace, memory_key) + ); + + CREATE INDEX IF NOT EXISTS idx_memories_scope_updated + ON memories(scope_kind, scope_key, updated_at DESC); + CREATE INDEX IF NOT EXISTS idx_memories_scope_namespace_updated + ON memories(scope_kind, scope_key, namespace, updated_at DESC); + CREATE INDEX IF NOT EXISTS idx_memories_source_session + ON memories(source_session_id, updated_at DESC); + + CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5( + namespace, + memory_key, + content, + content='memories', + content_rowid='rowid' + ); + + CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN + INSERT INTO memories_fts(rowid, namespace, memory_key, content) + VALUES (new.rowid, new.namespace, new.memory_key, new.content); + END; + + CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content) + VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content); + END; + + CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content) + VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content); + INSERT INTO memories_fts(rowid, namespace, memory_key, content) + VALUES (new.rowid, new.namespace, new.memory_key, new.content); + END; ", )?; @@ -417,6 +500,246 @@ impl SessionStore { Ok(events) } + pub fn put_memory(&self, input: &MemoryUpsert) -> Result { + let now = current_timestamp(); + let conn = self.conn.lock().expect("session db mutex poisoned"); + let tx = conn.unchecked_transaction()?; + + let existing: Option<(String, i64)> = tx + .query_row( + " + SELECT id, created_at + FROM memories + WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4 + ", + params![ + input.scope_kind, + input.scope_key, + input.namespace, + input.memory_key, + ], + |row| Ok((row.get(0)?, row.get(1)?)), + ) + .optional()?; + + let (id, created_at) = existing + .unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now)); + + tx.execute( + " + INSERT INTO memories ( + id, scope_kind, scope_key, namespace, memory_key, content, + source_type, source_session_id, source_message_id, source_message_seq, + source_channel_name, source_chat_id, created_at, updated_at + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14) + ON CONFLICT(scope_kind, scope_key, namespace, memory_key) DO UPDATE SET + content = excluded.content, + source_type = excluded.source_type, + source_session_id = excluded.source_session_id, + source_message_id = excluded.source_message_id, + source_message_seq = excluded.source_message_seq, + source_channel_name = excluded.source_channel_name, + source_chat_id = excluded.source_chat_id, + updated_at = excluded.updated_at + ", + params![ + id, + input.scope_kind, + input.scope_key, + input.namespace, + input.memory_key, + input.content, + input.source_type, + input.source_session_id, + input.source_message_id, + input.source_message_seq, + input.source_channel_name, + input.source_chat_id, + created_at, + now, + ], + )?; + + tx.commit()?; + drop(conn); + + self.get_memory( + &input.scope_kind, + &input.scope_key, + &input.namespace, + &input.memory_key, + )? + .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) + } + + pub fn get_memory( + &self, + scope_kind: &str, + scope_key: &str, + namespace: &str, + memory_key: &str, + ) -> Result, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let mut stmt = conn.prepare( + " + SELECT id, scope_kind, scope_key, namespace, memory_key, content, + source_type, source_session_id, source_message_id, source_message_seq, + source_channel_name, source_chat_id, created_at, updated_at + FROM memories + WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4 + ", + )?; + + stmt.query_row( + params![scope_kind, scope_key, namespace, memory_key], + map_memory_record, + ) + .optional() + .map_err(StorageError::from) + } + + pub fn list_memories( + &self, + scope_kind: &str, + scope_key: &str, + namespace: Option<&str>, + limit: usize, + ) -> Result, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let limit = limit.max(1) as i64; + let mut memories = Vec::new(); + + if let Some(namespace) = namespace { + let mut stmt = conn.prepare( + " + SELECT id, scope_kind, scope_key, namespace, memory_key, content, + source_type, source_session_id, source_message_id, source_message_seq, + source_channel_name, source_chat_id, created_at, updated_at + FROM memories + WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 + ORDER BY updated_at DESC, created_at DESC + LIMIT ?4 + ", + )?; + let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?; + for row in rows { + memories.push(row?); + } + } else { + let mut stmt = conn.prepare( + " + SELECT id, scope_kind, scope_key, namespace, memory_key, content, + source_type, source_session_id, source_message_id, source_message_seq, + source_channel_name, source_chat_id, created_at, updated_at + FROM memories + WHERE scope_kind = ?1 AND scope_key = ?2 + ORDER BY updated_at DESC, created_at DESC + LIMIT ?3 + ", + )?; + let rows = stmt.query_map(params![scope_kind, scope_key, limit], map_memory_record)?; + for row in rows { + memories.push(row?); + } + } + + Ok(memories) + } + + pub fn update_memory( + &self, + input: &MemoryUpsert, + ) -> Result, StorageError> { + if self + .get_memory( + &input.scope_kind, + &input.scope_key, + &input.namespace, + &input.memory_key, + )? + .is_none() + { + return Ok(None); + } + + self.put_memory(input).map(Some) + } + + pub fn delete_memory( + &self, + scope_kind: &str, + scope_key: &str, + namespace: &str, + memory_key: &str, + ) -> Result { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let changed = conn.execute( + " + DELETE FROM memories + WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4 + ", + params![scope_kind, scope_key, namespace, memory_key], + )?; + Ok(changed > 0) + } + + pub fn search_memories( + &self, + scope_kind: &str, + scope_key: &str, + query: &str, + namespace: Option<&str>, + limit: usize, + ) -> Result, StorageError> { + let conn = self.conn.lock().expect("session db mutex poisoned"); + let limit = limit.max(1) as i64; + let query = quote_fts_query(query); + let mut memories = Vec::new(); + + if let Some(namespace) = namespace { + let mut stmt = conn.prepare( + " + SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content, + m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq, + m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at + FROM memories_fts f + JOIN memories m ON m.rowid = f.rowid + WHERE memories_fts MATCH ?1 + AND m.scope_kind = ?2 + AND m.scope_key = ?3 + AND m.namespace = ?4 + ORDER BY bm25(memories_fts), m.updated_at DESC + LIMIT ?5 + ", + )?; + let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?; + for row in rows { + memories.push(row?); + } + } else { + let mut stmt = conn.prepare( + " + SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content, + m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq, + m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at + FROM memories_fts f + JOIN memories m ON m.rowid = f.rowid + WHERE memories_fts MATCH ?1 + AND m.scope_kind = ?2 + AND m.scope_key = ?3 + ORDER BY bm25(memories_fts), m.updated_at DESC + LIMIT ?4 + ", + )?; + let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?; + for row in rows { + memories.push(row?); + } + } + + Ok(memories) + } + pub fn load_messages(&self, session_id: &str) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let cutoff_seq = active_reset_cutoff(&conn, session_id)?; @@ -479,6 +802,25 @@ fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result) -> rusqlite::Result { + Ok(MemoryRecord { + id: row.get(0)?, + scope_kind: row.get(1)?, + scope_key: row.get(2)?, + namespace: row.get(3)?, + memory_key: row.get(4)?, + content: row.get(5)?, + source_type: row.get(6)?, + source_session_id: row.get(7)?, + source_message_id: row.get(8)?, + source_message_seq: row.get(9)?, + source_channel_name: row.get(10)?, + source_chat_id: row.get(11)?, + created_at: row.get(12)?, + updated_at: row.get(13)?, + }) +} + fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> { if !has_column(conn, "sessions", "reset_cutoff_seq")? { conn.execute( @@ -580,6 +922,10 @@ fn current_timestamp() -> i64 { .as_millis() as i64 } +fn quote_fts_query(query: &str) -> String { + format!("\"{}\"", query.replace('"', "\"\"")) +} + #[cfg(test)] mod tests { use super::*; @@ -797,4 +1143,101 @@ mod tests { assert_eq!(session_events[0].skill_name.as_deref(), Some("code-review")); assert_eq!(session_events[0].payload["source"], "project"); } + + #[test] + fn test_memory_roundtrip_with_source_fields() { + let store = SessionStore::in_memory().unwrap(); + + let saved = store + .put_memory(&MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "profile".to_string(), + memory_key: "language".to_string(), + content: "Rust".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-1".to_string()), + source_message_id: Some("msg-1".to_string()), + source_message_seq: Some(7), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-1".to_string()), + }) + .unwrap(); + + assert_eq!(saved.content, "Rust"); + assert_eq!(saved.source_type, "message"); + assert_eq!(saved.source_session_id.as_deref(), Some("feishu:chat-1")); + assert_eq!(saved.source_message_id.as_deref(), Some("msg-1")); + assert_eq!(saved.source_message_seq, Some(7)); + + let fetched = store + .get_memory("user", "feishu:user-1", "profile", "language") + .unwrap() + .unwrap(); + assert_eq!(fetched.id, saved.id); + assert_eq!(fetched.source_chat_id.as_deref(), Some("chat-1")); + } + + #[test] + fn test_memory_fts_tracks_upsert_and_delete() { + let store = SessionStore::in_memory().unwrap(); + + store + .put_memory(&MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "preferences".to_string(), + memory_key: "editor".to_string(), + content: "Prefers rust-analyzer and cargo test output".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-2".to_string()), + source_message_id: Some("msg-2".to_string()), + source_message_seq: Some(3), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-2".to_string()), + }) + .unwrap(); + + let hits = store + .search_memories("user", "feishu:user-1", "rust-analyzer", None, 10) + .unwrap(); + assert_eq!(hits.len(), 1); + assert_eq!(hits[0].memory_key, "editor"); + + store + .put_memory(&MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: "feishu:user-1".to_string(), + namespace: "preferences".to_string(), + memory_key: "editor".to_string(), + content: "Prefers clippy diagnostics".to_string(), + source_type: "message".to_string(), + source_session_id: Some("feishu:chat-3".to_string()), + source_message_id: Some("msg-3".to_string()), + source_message_seq: Some(4), + source_channel_name: Some("feishu".to_string()), + source_chat_id: Some("chat-3".to_string()), + }) + .unwrap(); + + let old_hits = store + .search_memories("user", "feishu:user-1", "rust-analyzer", None, 10) + .unwrap(); + assert!(old_hits.is_empty()); + + let new_hits = store + .search_memories("user", "feishu:user-1", "clippy", None, 10) + .unwrap(); + assert_eq!(new_hits.len(), 1); + + let deleted = store + .delete_memory("user", "feishu:user-1", "preferences", "editor") + .unwrap(); + assert!(deleted); + + let hits_after_delete = store + .search_memories("user", "feishu:user-1", "clippy", None, 10) + .unwrap(); + assert!(hits_after_delete.is_empty()); + } } \ No newline at end of file diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs new file mode 100644 index 0000000..87aac50 --- /dev/null +++ b/src/tools/memory_manage.rs @@ -0,0 +1,313 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use serde_json::json; + +use crate::storage::{MemoryRecord, MemoryUpsert, SessionStore}; +use crate::tools::traits::{Tool, ToolContext, ToolResult}; + +pub struct MemoryManageTool { + store: Arc, +} + +impl MemoryManageTool { + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +#[async_trait] +impl Tool for MemoryManageTool { + fn name(&self) -> &str { + "memory_manage" + } + + fn description(&self) -> &str { + "Manage user memories stored in SQLite. Supports actions: list, get, put, update, delete. Memories are scoped to the current channel and sender, and record the originating session/message when available." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["list", "get", "put", "update", "delete"], + "description": "Management action to perform" + }, + "namespace": { + "type": "string", + "description": "Memory namespace, such as profile, preferences, or tasks" + }, + "key": { + "type": "string", + "description": "Memory key within the namespace" + }, + "content": { + "type": "string", + "description": "Memory content for put/update" + }, + "limit": { + "type": "integer", + "description": "Maximum number of memories to list", + "minimum": 1, + "default": 20 + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + Ok(error_result("memory_manage requires tool context")) + } + + async fn execute_with_context( + &self, + context: &ToolContext, + args: serde_json::Value, + ) -> anyhow::Result { + let action = match args.get("action").and_then(|value| value.as_str()) { + Some(action) => action, + None => return Ok(error_result("Missing required parameter: action")), + }; + + let scope_key = match scope_key_from_context(context) { + Ok(scope_key) => scope_key, + Err(result) => return Ok(result), + }; + + let namespace = args.get("namespace").and_then(|value| value.as_str()); + let key = args.get("key").and_then(|value| value.as_str()); + + let payload = match action { + "list" => { + let limit = args + .get("limit") + .and_then(|value| value.as_u64()) + .unwrap_or(20) as usize; + let memories = self + .store + .list_memories("user", &scope_key, namespace, limit)?; + json!({ + "count": memories.len(), + "memories": memories.into_iter().map(memory_to_json).collect::>() + }) + } + "get" => { + let namespace = match namespace { + Some(namespace) => namespace, + None => return Ok(error_result("Missing required parameter: namespace")), + }; + let key = match key { + Some(key) => key, + None => return Ok(error_result("Missing required parameter: key")), + }; + + match self.store.get_memory("user", &scope_key, namespace, key)? { + Some(memory) => memory_to_json(memory), + None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))), + } + } + "put" => { + let input = match build_memory_upsert(context, &scope_key, &args, true) { + Ok(input) => input, + Err(result) => return Ok(result), + }; + memory_to_json(self.store.put_memory(&input)?) + } + "update" => { + let input = match build_memory_upsert(context, &scope_key, &args, false) { + Ok(input) => input, + Err(result) => return Ok(result), + }; + + match self.store.update_memory(&input)? { + Some(memory) => memory_to_json(memory), + None => { + return Ok(error_result(&format!( + "memory '{}.{}' not found", + input.namespace, input.memory_key + ))) + } + } + } + "delete" => { + let namespace = match namespace { + Some(namespace) => namespace, + None => return Ok(error_result("Missing required parameter: namespace")), + }; + let key = match key { + Some(key) => key, + None => return Ok(error_result("Missing required parameter: key")), + }; + + let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?; + if !deleted { + return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))); + } + + json!({ + "status": "deleted", + "namespace": namespace, + "key": key, + }) + } + _ => return Ok(error_result("Unsupported action")), + }; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&payload)?, + error: None, + }) + } +} + +fn build_memory_upsert( + context: &ToolContext, + scope_key: &str, + args: &serde_json::Value, + allow_put: bool, +) -> Result { + let namespace = match args.get("namespace").and_then(|value| value.as_str()) { + Some(namespace) => namespace, + None => return Err(error_result("Missing required parameter: namespace")), + }; + let key = match args.get("key").and_then(|value| value.as_str()) { + Some(key) => key, + None => return Err(error_result("Missing required parameter: key")), + }; + let content = match args.get("content").and_then(|value| value.as_str()) { + Some(content) => content, + None => return Err(error_result("Missing required parameter: content")), + }; + + let source_type = if context.message_id.is_some() { + "message" + } else if allow_put { + "manual" + } else { + "session" + }; + + Ok(MemoryUpsert { + scope_kind: "user".to_string(), + scope_key: scope_key.to_string(), + namespace: namespace.to_string(), + memory_key: key.to_string(), + content: content.to_string(), + source_type: source_type.to_string(), + source_session_id: context.session_id.clone(), + source_message_id: context.message_id.clone(), + source_message_seq: context.message_seq, + source_channel_name: context.channel_name.clone(), + source_chat_id: context.chat_id.clone(), + }) +} + +fn scope_key_from_context(context: &ToolContext) -> Result { + let channel_name = context + .channel_name + .as_deref() + .ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?; + let sender_id = context + .sender_id + .as_deref() + .ok_or_else(|| error_result("memory_manage requires sender_id in tool context"))?; + Ok(format!("{}:{}", channel_name, sender_id)) +} + +fn memory_to_json(memory: MemoryRecord) -> serde_json::Value { + json!({ + "id": memory.id, + "scope_kind": memory.scope_kind, + "scope_key": memory.scope_key, + "namespace": memory.namespace, + "key": memory.memory_key, + "content": memory.content, + "source_type": memory.source_type, + "source_session_id": memory.source_session_id, + "source_message_id": memory.source_message_id, + "source_message_seq": memory.source_message_seq, + "source_channel_name": memory.source_channel_name, + "source_chat_id": memory.source_chat_id, + "created_at": memory.created_at, + "updated_at": memory.updated_at, + }) +} + +fn error_result(message: &str) -> ToolResult { + ToolResult { + success: false, + output: String::new(), + error: Some(message.to_string()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_memory_manage_put_and_get() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = MemoryManageTool::new(store); + let context = ToolContext { + channel_name: Some("feishu".to_string()), + sender_id: Some("user-1".to_string()), + chat_id: Some("chat-1".to_string()), + session_id: Some("feishu:chat-1".to_string()), + message_id: Some("msg-1".to_string()), + message_seq: Some(1), + }; + + let put = tool + .execute_with_context( + &context, + json!({ + "action": "put", + "namespace": "profile", + "key": "language", + "content": "Rust" + }), + ) + .await + .unwrap(); + assert!(put.success); + + let get = tool + .execute_with_context( + &context, + json!({ + "action": "get", + "namespace": "profile", + "key": "language" + }), + ) + .await + .unwrap(); + assert!(get.success); + assert!(get.output.contains("Rust")); + assert!(get.output.contains("msg-1")); + } + + #[tokio::test] + async fn test_memory_manage_requires_context() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = MemoryManageTool::new(store); + + let result = tool + .execute_with_context( + &ToolContext::default(), + json!({ + "action": "list" + }), + ) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("channel_name")); + } +} \ No newline at end of file diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 55deb52..624da9d 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -4,6 +4,7 @@ pub mod file_edit; pub mod file_read; pub mod file_write; pub mod http_request; +pub mod memory_manage; pub mod registry; pub mod schema; pub mod skill_manage; @@ -16,8 +17,9 @@ pub use file_edit::FileEditTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; pub use http_request::HttpRequestTool; +pub use memory_manage::MemoryManageTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; pub use skill_manage::{SkillListTool, SkillManageTool}; -pub use traits::{Tool, ToolResult}; +pub use traits::{Tool, ToolContext, ToolResult}; pub use web_fetch::WebFetchTool; diff --git a/src/tools/traits.rs b/src/tools/traits.rs index f3ffdc4..b00ca14 100644 --- a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -7,6 +7,16 @@ pub struct ToolResult { pub error: Option, } +#[derive(Debug, Clone, Default)] +pub struct ToolContext { + pub channel_name: Option, + pub sender_id: Option, + pub chat_id: Option, + pub session_id: Option, + pub message_id: Option, + pub message_seq: Option, +} + #[async_trait] pub trait Tool: Send + Sync + 'static { fn name(&self) -> &str; @@ -14,6 +24,14 @@ pub trait Tool: Send + Sync + 'static { fn parameters_schema(&self) -> serde_json::Value; async fn execute(&self, args: serde_json::Value) -> anyhow::Result; + async fn execute_with_context( + &self, + _context: &ToolContext, + args: serde_json::Value, + ) -> anyhow::Result { + self.execute(args).await + } + /// Whether this tool is side-effect free and safe to parallelize. fn read_only(&self) -> bool { false diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index d73ce37..eb6ea16 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -117,3 +117,55 @@ fn test_clear_history_with_session_id_serialization() { assert!(json.contains(r#""type":"clear_history""#)); assert!(json.contains(r#""session_id":"session-1""#)); } + +#[test] +fn test_tool_call_outbound_serialization() { + let msg = WsOutbound::ToolCall { + id: "msg-1".to_string(), + tool_call_id: "call-1".to_string(), + tool_name: "calculator".to_string(), + arguments: serde_json::json!({"expression": "1 + 1"}), + content: "调用工具: calculator".to_string(), + role: "assistant".to_string(), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"tool_call""#)); + assert!(json.contains(r#""tool_name":"calculator""#)); + assert!(json.contains(r#""expression":"1 + 1""#)); + + let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); + match decoded { + WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => { + assert_eq!(tool_call_id, "call-1"); + assert_eq!(tool_name, "calculator"); + assert_eq!(arguments["expression"], "1 + 1"); + } + other => panic!("unexpected decoded variant: {:?}", other), + } +} + +#[test] +fn test_tool_result_outbound_serialization() { + let msg = WsOutbound::ToolResult { + id: "msg-2".to_string(), + tool_call_id: "call-1".to_string(), + tool_name: "calculator".to_string(), + content: "工具结果: calculator\n\n2".to_string(), + role: "tool".to_string(), + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""type":"tool_result""#)); + assert!(json.contains(r#""tool_name":"calculator""#)); + + let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); + match decoded { + WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => { + assert_eq!(tool_call_id, "call-1"); + assert_eq!(tool_name, "calculator"); + assert!(content.contains('2')); + } + other => panic!("unexpected decoded variant: {:?}", other), + } +}