This commit is contained in:
ooodc 2026-04-22 06:57:22 +08:00
parent a0fe7c57bd
commit d35e89a44c
12 changed files with 1195 additions and 37 deletions

View File

@ -7,7 +7,7 @@ use crate::observability::{
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use crate::tools::{ToolContext, ToolRegistry};
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
@ -225,6 +225,7 @@ pub struct AgentLoop {
skills: Arc<SkillRuntime>,
skill_event_store: Option<Arc<SessionStore>>,
skill_event_session_id: Option<String>,
tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>,
max_iterations: usize,
}
@ -247,6 +248,7 @@ impl AgentLoop {
skills: Arc::new(SkillRuntime::default()),
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
max_iterations,
})
@ -263,6 +265,7 @@ impl AgentLoop {
skills: Arc::new(SkillRuntime::default()),
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
max_iterations,
})
@ -283,6 +286,7 @@ impl AgentLoop {
skills,
skill_event_store: None,
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
max_iterations,
})
@ -294,6 +298,11 @@ impl AgentLoop {
self
}
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
self.tool_context = context;
self
}
/// Set an observer for tracking events.
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
self.observer = Some(observer);
@ -622,7 +631,7 @@ impl AgentLoop {
}
};
match tool.execute(tool_call.arguments.clone()).await {
match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await {
Ok(result) => {
if result.success {
ToolExecutionOutcome::success(result.output)

View File

@ -191,12 +191,165 @@ pub struct OutboundMessage {
pub reply_to: Option<String>,
pub media: Vec<MediaItem>,
pub metadata: HashMap<String, String>,
pub event_kind: OutboundEventKind,
pub role: String,
pub tool_call_id: Option<String>,
pub tool_name: Option<String>,
pub tool_arguments: Option<serde_json::Value>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutboundEventKind {
AssistantResponse,
ToolCall,
ToolResult,
}
impl OutboundMessage {
pub fn is_stream_delta(&self) -> bool {
self.metadata.get("_stream_delta").is_some()
}
pub fn assistant(
channel: impl Into<String>,
chat_id: impl Into<String>,
content: impl Into<String>,
reply_to: Option<String>,
metadata: HashMap<String, String>,
) -> Self {
Self {
channel: channel.into(),
chat_id: chat_id.into(),
content: content.into(),
reply_to,
media: Vec::new(),
metadata,
event_kind: OutboundEventKind::AssistantResponse,
role: "assistant".to_string(),
tool_call_id: None,
tool_name: None,
tool_arguments: None,
}
}
pub fn tool_call(
channel: impl Into<String>,
chat_id: impl Into<String>,
message_id: impl Into<String>,
tool_name: impl Into<String>,
tool_arguments: serde_json::Value,
reply_to: Option<String>,
metadata: HashMap<String, String>,
) -> Self {
let tool_name = tool_name.into();
let content = format_tool_call_content(&tool_name, &tool_arguments);
Self {
channel: channel.into(),
chat_id: chat_id.into(),
content,
reply_to,
media: Vec::new(),
metadata,
event_kind: OutboundEventKind::ToolCall,
role: "assistant".to_string(),
tool_call_id: Some(message_id.into()),
tool_name: Some(tool_name),
tool_arguments: Some(tool_arguments),
}
}
pub fn tool_result(
channel: impl Into<String>,
chat_id: impl Into<String>,
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
reply_to: Option<String>,
metadata: HashMap<String, String>,
) -> Self {
let tool_name = tool_name.into();
let raw_content = content.into();
let content = format_tool_result_content(&tool_name, &raw_content);
Self {
channel: channel.into(),
chat_id: chat_id.into(),
content,
reply_to,
media: Vec::new(),
metadata,
event_kind: OutboundEventKind::ToolResult,
role: "tool".to_string(),
tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name),
tool_arguments: None,
}
}
pub fn from_chat_message(
channel: &str,
chat_id: &str,
reply_to: Option<String>,
metadata: &HashMap<String, String>,
message: &ChatMessage,
) -> Vec<Self> {
match message.role.as_str() {
"assistant" => {
if let Some(tool_calls) = &message.tool_calls {
tool_calls
.iter()
.map(|tool_call| {
Self::tool_call(
channel.to_string(),
chat_id.to_string(),
tool_call.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
reply_to.clone(),
metadata.clone(),
)
})
.collect()
} else {
vec![Self::assistant(
channel.to_string(),
chat_id.to_string(),
message.content.clone(),
reply_to,
metadata.clone(),
)]
}
}
"tool" => vec![Self::tool_result(
channel.to_string(),
chat_id.to_string(),
message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
message.content.clone(),
reply_to,
metadata.clone(),
)],
_ => Vec::new(),
}
}
}
fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_name,
format_json_value(tool_arguments),
)
}
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
format!("工具结果: {}\n\n{}", tool_name, content)
}
fn format_json_value(value: &serde_json::Value) -> String {
match value {
serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
}
}
// ============================================================================
@ -209,3 +362,63 @@ fn current_timestamp() -> i64 {
.unwrap()
.as_millis() as i64
}
#[cfg(test)]
mod tests {
use super::{ChatMessage, OutboundEventKind, OutboundMessage};
use crate::providers::ToolCall;
use serde_json::json;
use std::collections::HashMap;
#[test]
fn test_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![
ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
},
ToolCall {
id: "call-2".to_string(),
name: "file_read".to_string(),
arguments: json!({"path": "README.md"}),
},
],
);
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
assert_eq!(outbound.len(), 2);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
}
#[test]
fn test_from_chat_message_maps_tool_result() {
let message = ChatMessage::tool("call-9", "calculator", "2");
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
assert_eq!(outbound[0].tool_call_id.as_deref(), Some("call-9"));
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert!(outbound[0].content.contains("工具结果: calculator"));
}
}

View File

@ -40,6 +40,10 @@ fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
serde_json::from_str(raw)
}
fn format_json(value: &serde_json::Value) -> String {
serde_json::to_string_pretty(value).unwrap_or_else(|_| value.to_string())
}
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let (ws_stream, _) = connect_async(gateway_url).await?;
tracing::info!(url = %gateway_url, "Connected to gateway");
@ -63,6 +67,12 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
WsOutbound::AssistantResponse { content, .. } => {
input.write_response(&content).await?;
}
WsOutbound::ToolCall { tool_name, arguments, .. } => {
input.write_output(&format!("Tool call: {}\n{}\n", tool_name, format_json(&arguments))).await?;
}
WsOutbound::ToolResult { tool_name, content, .. } => {
input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?;
}
WsOutbound::Error { message, .. } => {
input.write_output(&format!("Error: {}", message)).await?;
}

View File

@ -81,20 +81,15 @@ impl GatewayState {
&inbound.content,
inbound.media,
).await {
Ok(response_content) => {
Ok(outbound_messages) => {
// Forward channel-specific metadata from inbound to outbound.
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
// without gateway needing channel-specific code.
let outbound = crate::bus::OutboundMessage {
channel: inbound.channel.clone(),
chat_id: inbound.chat_id.clone(),
content: response_content,
reply_to: None,
media: vec![],
metadata: inbound.forwarded_metadata,
};
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
tracing::error!(error = %e, "Failed to publish outbound");
for mut outbound in outbound_messages {
outbound.metadata.extend(inbound.forwarded_metadata.clone());
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
tracing::error!(error = %e, "Failed to publish outbound");
}
}
}
Err(e) => {

View File

@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::ChatMessage;
use crate::bus::{ChatMessage, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound;
@ -11,7 +11,8 @@ use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, SkillListTool, SkillManageTool, ToolRegistry, WebFetchTool,
HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry,
WebFetchTool,
};
/// Session 按 channel 隔离,每个 channel 一个 Session
@ -197,13 +198,30 @@ impl Session {
}
/// 创建一个临时的 AgentLoop 实例来处理消息
pub fn create_agent(&self, chat_id: &str) -> Result<AgentLoop, AgentError> {
pub fn create_agent(
&self,
chat_id: &str,
sender_id: Option<&str>,
message_id: Option<&str>,
) -> Result<AgentLoop, AgentError> {
let session_id = self.persistent_session_id(chat_id);
AgentLoop::with_tools_and_skills(
self.provider_config.clone(),
self.tools.clone(),
self.skills.clone(),
)
.map(|agent| agent.with_skill_event_store(self.store.clone(), self.persistent_session_id(chat_id)))
.map(|agent| {
agent
.with_skill_event_store(self.store.clone(), session_id.clone())
.with_tool_context(ToolContext {
channel_name: Some(self.channel_name.clone()),
sender_id: sender_id.map(str::to_string),
chat_id: Some(chat_id.to_string()),
session_id: Some(session_id),
message_id: message_id.map(str::to_string),
message_seq: None,
})
})
}
}
@ -223,12 +241,13 @@ struct SessionManagerInner {
session_ttl: Duration,
}
fn default_tools(skills: Arc<SkillRuntime>) -> ToolRegistry {
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(MemoryManageTool::new(store));
registry.register(SkillListTool::new(skills.clone()));
registry.register(SkillManageTool::new(skills));
registry.register(BashTool::new());
@ -290,7 +309,7 @@ impl SessionManager {
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config,
tools: Arc::new(default_tools(skills.clone())),
tools: Arc::new(default_tools(skills.clone(), store.clone())),
skills,
store,
})
@ -414,11 +433,11 @@ impl SessionManager {
pub async fn handle_message(
&self,
channel_name: &str,
_sender_id: &str,
sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<crate::bus::MediaItem>,
) -> Result<String, AgentError> {
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
tracing::debug!(
@ -453,7 +472,13 @@ impl SessionManager {
session_guard.ensure_chat_loaded(chat_id)?;
if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? {
return Ok(command_response);
return Ok(vec![OutboundMessage::assistant(
channel_name.to_string(),
chat_id.to_string(),
command_response,
None,
HashMap::new(),
)]);
}
// 添加用户消息到历史
@ -463,6 +488,7 @@ impl SessionManager {
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
}
let user_message = session_guard.create_user_message(content, media_refs);
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(chat_id, user_message)?;
// 获取完整历史
@ -476,24 +502,36 @@ impl SessionManager {
session_guard.record_skill_offer(chat_id)?;
// 创建 agent 并处理
let agent = session_guard.create_agent(chat_id)?;
let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
let result = agent.process(history).await?;
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
result.final_response
result
.emitted_messages
.iter()
.flat_map(|message| {
OutboundMessage::from_chat_message(
channel_name,
chat_id,
None,
&HashMap::new(),
message,
)
})
.collect::<Vec<_>>()
};
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
response_len = response.content.len(),
"Agent response received"
outbound_count = response.len(),
"Agent response sequence received"
);
Ok(response.content)
Ok(response)
}
/// 清除指定 session 的所有历史
@ -541,7 +579,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone()));
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),

View File

@ -4,6 +4,7 @@ use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command}};
@ -140,6 +141,52 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
}
}
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
}
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
if let Some(tool_calls) = &message.tool_calls {
tool_calls
.iter()
.map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_call.name,
format_tool_arguments(&tool_call.arguments),
),
role: message.role.clone(),
})
.collect()
} else {
vec![WsOutbound::AssistantResponse {
id: message.id.clone(),
content: message.content.clone(),
role: message.role.clone(),
}]
}
}
"tool" => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
tool_name: message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
content: format!(
"工具结果: {}\n\n{}",
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
message.content,
),
role: message.role.clone(),
}],
_ => Vec::new(),
}
}
async fn handle_inbound(
state: &Arc<GatewayState>,
session: &Arc<Mutex<Session>>,
@ -166,6 +213,7 @@ async fn handle_inbound(
}
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
@ -183,17 +231,17 @@ async fn handle_inbound(
session_guard.record_skill_offer(&chat_id)?;
let agent = session_guard.create_agent(&chat_id)?;
let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
match agent.process(history).await {
Ok(result) => {
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
let _ = session_guard
.send(WsOutbound::AssistantResponse {
id: result.final_response.id,
content: result.final_response.content,
role: result.final_response.role,
})
.await;
for outbound in result
.emitted_messages
.iter()
.flat_map(ws_outbound_from_chat_message)
{
let _ = session_guard.send(outbound).await;
}
}
Err(error) => {
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");

View File

@ -71,6 +71,23 @@ pub enum WsInbound {
pub enum WsOutbound {
#[serde(rename = "assistant_response")]
AssistantResponse { id: String, content: String, role: String },
#[serde(rename = "tool_call")]
ToolCall {
id: String,
tool_call_id: String,
tool_name: String,
arguments: serde_json::Value,
content: String,
role: String,
},
#[serde(rename = "tool_result")]
ToolResult {
id: String,
tool_call_id: String,
tool_name: String,
content: String,
role: String,
},
#[serde(rename = "error")]
Error { code: String, message: String },
#[serde(rename = "session_established")]

View File

@ -42,6 +42,39 @@ pub struct SessionRecord {
pub reset_cutoff_seq: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
pub id: String,
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
pub created_at: i64,
pub updated_at: i64,
}
#[derive(Debug, Clone)]
pub struct MemoryUpsert {
pub scope_kind: String,
pub scope_key: String,
pub namespace: String,
pub memory_key: String,
pub content: String,
pub source_type: String,
pub source_session_id: Option<String>,
pub source_message_id: Option<String>,
pub source_message_seq: Option<i64>,
pub source_channel_name: Option<String>,
pub source_chat_id: Option<String>,
}
#[derive(Clone)]
pub struct SessionStore {
conn: Arc<Mutex<Connection>>,
@ -122,6 +155,56 @@ impl SessionStore {
ON skill_events(session_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_skill_events_type_created
ON skill_events(event_type, created_at DESC);
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
scope_kind TEXT NOT NULL,
scope_key TEXT NOT NULL,
namespace TEXT NOT NULL,
memory_key TEXT NOT NULL,
content TEXT NOT NULL,
source_type TEXT NOT NULL,
source_session_id TEXT,
source_message_id TEXT,
source_message_seq INTEGER,
source_channel_name TEXT,
source_chat_id TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
UNIQUE(scope_kind, scope_key, namespace, memory_key)
);
CREATE INDEX IF NOT EXISTS idx_memories_scope_updated
ON memories(scope_kind, scope_key, updated_at DESC);
CREATE INDEX IF NOT EXISTS idx_memories_scope_namespace_updated
ON memories(scope_kind, scope_key, namespace, updated_at DESC);
CREATE INDEX IF NOT EXISTS idx_memories_source_session
ON memories(source_session_id, updated_at DESC);
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
namespace,
memory_key,
content,
content='memories',
content_rowid='rowid'
);
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, namespace, memory_key, content)
VALUES (new.rowid, new.namespace, new.memory_key, new.content);
END;
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content)
VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content);
END;
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content)
VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content);
INSERT INTO memories_fts(rowid, namespace, memory_key, content)
VALUES (new.rowid, new.namespace, new.memory_key, new.content);
END;
",
)?;
@ -417,6 +500,246 @@ impl SessionStore {
Ok(events)
}
pub fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
let tx = conn.unchecked_transaction()?;
let existing: Option<(String, i64)> = tx
.query_row(
"
SELECT id, created_at
FROM memories
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
",
params![
input.scope_kind,
input.scope_key,
input.namespace,
input.memory_key,
],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.optional()?;
let (id, created_at) = existing
.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
tx.execute(
"
INSERT INTO memories (
id, scope_kind, scope_key, namespace, memory_key, content,
source_type, source_session_id, source_message_id, source_message_seq,
source_channel_name, source_chat_id, created_at, updated_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
ON CONFLICT(scope_kind, scope_key, namespace, memory_key) DO UPDATE SET
content = excluded.content,
source_type = excluded.source_type,
source_session_id = excluded.source_session_id,
source_message_id = excluded.source_message_id,
source_message_seq = excluded.source_message_seq,
source_channel_name = excluded.source_channel_name,
source_chat_id = excluded.source_chat_id,
updated_at = excluded.updated_at
",
params![
id,
input.scope_kind,
input.scope_key,
input.namespace,
input.memory_key,
input.content,
input.source_type,
input.source_session_id,
input.source_message_id,
input.source_message_seq,
input.source_channel_name,
input.source_chat_id,
created_at,
now,
],
)?;
tx.commit()?;
drop(conn);
self.get_memory(
&input.scope_kind,
&input.scope_key,
&input.namespace,
&input.memory_key,
)?
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn get_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<Option<MemoryRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let mut stmt = conn.prepare(
"
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
source_type, source_session_id, source_message_id, source_message_seq,
source_channel_name, source_chat_id, created_at, updated_at
FROM memories
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
",
)?;
stmt.query_row(
params![scope_kind, scope_key, namespace, memory_key],
map_memory_record,
)
.optional()
.map_err(StorageError::from)
}
pub fn list_memories(
&self,
scope_kind: &str,
scope_key: &str,
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let limit = limit.max(1) as i64;
let mut memories = Vec::new();
if let Some(namespace) = namespace {
let mut stmt = conn.prepare(
"
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
source_type, source_session_id, source_message_id, source_message_seq,
source_channel_name, source_chat_id, created_at, updated_at
FROM memories
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3
ORDER BY updated_at DESC, created_at DESC
LIMIT ?4
",
)?;
let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?;
for row in rows {
memories.push(row?);
}
} else {
let mut stmt = conn.prepare(
"
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
source_type, source_session_id, source_message_id, source_message_seq,
source_channel_name, source_chat_id, created_at, updated_at
FROM memories
WHERE scope_kind = ?1 AND scope_key = ?2
ORDER BY updated_at DESC, created_at DESC
LIMIT ?3
",
)?;
let rows = stmt.query_map(params![scope_kind, scope_key, limit], map_memory_record)?;
for row in rows {
memories.push(row?);
}
}
Ok(memories)
}
pub fn update_memory(
&self,
input: &MemoryUpsert,
) -> Result<Option<MemoryRecord>, StorageError> {
if self
.get_memory(
&input.scope_kind,
&input.scope_key,
&input.namespace,
&input.memory_key,
)?
.is_none()
{
return Ok(None);
}
self.put_memory(input).map(Some)
}
pub fn delete_memory(
&self,
scope_kind: &str,
scope_key: &str,
namespace: &str,
memory_key: &str,
) -> Result<bool, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let changed = conn.execute(
"
DELETE FROM memories
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
",
params![scope_kind, scope_key, namespace, memory_key],
)?;
Ok(changed > 0)
}
pub fn search_memories(
&self,
scope_kind: &str,
scope_key: &str,
query: &str,
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let limit = limit.max(1) as i64;
let query = quote_fts_query(query);
let mut memories = Vec::new();
if let Some(namespace) = namespace {
let mut stmt = conn.prepare(
"
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
FROM memories_fts f
JOIN memories m ON m.rowid = f.rowid
WHERE memories_fts MATCH ?1
AND m.scope_kind = ?2
AND m.scope_key = ?3
AND m.namespace = ?4
ORDER BY bm25(memories_fts), m.updated_at DESC
LIMIT ?5
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
for row in rows {
memories.push(row?);
}
} else {
let mut stmt = conn.prepare(
"
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
FROM memories_fts f
JOIN memories m ON m.rowid = f.rowid
WHERE memories_fts MATCH ?1
AND m.scope_kind = ?2
AND m.scope_key = ?3
ORDER BY bm25(memories_fts), m.updated_at DESC
LIMIT ?4
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
for row in rows {
memories.push(row?);
}
}
Ok(memories)
}
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
@ -479,6 +802,25 @@ fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEven
})
}
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
Ok(MemoryRecord {
id: row.get(0)?,
scope_kind: row.get(1)?,
scope_key: row.get(2)?,
namespace: row.get(3)?,
memory_key: row.get(4)?,
content: row.get(5)?,
source_type: row.get(6)?,
source_session_id: row.get(7)?,
source_message_id: row.get(8)?,
source_message_seq: row.get(9)?,
source_channel_name: row.get(10)?,
source_chat_id: row.get(11)?,
created_at: row.get(12)?,
updated_at: row.get(13)?,
})
}
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
if !has_column(conn, "sessions", "reset_cutoff_seq")? {
conn.execute(
@ -580,6 +922,10 @@ fn current_timestamp() -> i64 {
.as_millis() as i64
}
fn quote_fts_query(query: &str) -> String {
format!("\"{}\"", query.replace('"', "\"\""))
}
#[cfg(test)]
mod tests {
use super::*;
@ -797,4 +1143,101 @@ mod tests {
assert_eq!(session_events[0].skill_name.as_deref(), Some("code-review"));
assert_eq!(session_events[0].payload["source"], "project");
}
#[test]
fn test_memory_roundtrip_with_source_fields() {
let store = SessionStore::in_memory().unwrap();
let saved = store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "profile".to_string(),
memory_key: "language".to_string(),
content: "Rust".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-1".to_string()),
source_message_id: Some("msg-1".to_string()),
source_message_seq: Some(7),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-1".to_string()),
})
.unwrap();
assert_eq!(saved.content, "Rust");
assert_eq!(saved.source_type, "message");
assert_eq!(saved.source_session_id.as_deref(), Some("feishu:chat-1"));
assert_eq!(saved.source_message_id.as_deref(), Some("msg-1"));
assert_eq!(saved.source_message_seq, Some(7));
let fetched = store
.get_memory("user", "feishu:user-1", "profile", "language")
.unwrap()
.unwrap();
assert_eq!(fetched.id, saved.id);
assert_eq!(fetched.source_chat_id.as_deref(), Some("chat-1"));
}
#[test]
fn test_memory_fts_tracks_upsert_and_delete() {
let store = SessionStore::in_memory().unwrap();
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "editor".to_string(),
content: "Prefers rust-analyzer and cargo test output".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-2".to_string()),
source_message_id: Some("msg-2".to_string()),
source_message_seq: Some(3),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-2".to_string()),
})
.unwrap();
let hits = store
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
.unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].memory_key, "editor");
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "editor".to_string(),
content: "Prefers clippy diagnostics".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-3".to_string()),
source_message_id: Some("msg-3".to_string()),
source_message_seq: Some(4),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-3".to_string()),
})
.unwrap();
let old_hits = store
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
.unwrap();
assert!(old_hits.is_empty());
let new_hits = store
.search_memories("user", "feishu:user-1", "clippy", None, 10)
.unwrap();
assert_eq!(new_hits.len(), 1);
let deleted = store
.delete_memory("user", "feishu:user-1", "preferences", "editor")
.unwrap();
assert!(deleted);
let hits_after_delete = store
.search_memories("user", "feishu:user-1", "clippy", None, 10)
.unwrap();
assert!(hits_after_delete.is_empty());
}
}

313
src/tools/memory_manage.rs Normal file
View File

@ -0,0 +1,313 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use crate::storage::{MemoryRecord, MemoryUpsert, SessionStore};
use crate::tools::traits::{Tool, ToolContext, ToolResult};
pub struct MemoryManageTool {
store: Arc<SessionStore>,
}
impl MemoryManageTool {
pub fn new(store: Arc<SessionStore>) -> Self {
Self { store }
}
}
#[async_trait]
impl Tool for MemoryManageTool {
fn name(&self) -> &str {
"memory_manage"
}
fn description(&self) -> &str {
"Manage user memories stored in SQLite. Supports actions: list, get, put, update, delete. Memories are scoped to the current channel and sender, and record the originating session/message when available."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["list", "get", "put", "update", "delete"],
"description": "Management action to perform"
},
"namespace": {
"type": "string",
"description": "Memory namespace, such as profile, preferences, or tasks"
},
"key": {
"type": "string",
"description": "Memory key within the namespace"
},
"content": {
"type": "string",
"description": "Memory content for put/update"
},
"limit": {
"type": "integer",
"description": "Maximum number of memories to list",
"minimum": 1,
"default": 20
}
},
"required": ["action"]
})
}
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
Ok(error_result("memory_manage requires tool context"))
}
async fn execute_with_context(
&self,
context: &ToolContext,
args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
let action = match args.get("action").and_then(|value| value.as_str()) {
Some(action) => action,
None => return Ok(error_result("Missing required parameter: action")),
};
let scope_key = match scope_key_from_context(context) {
Ok(scope_key) => scope_key,
Err(result) => return Ok(result),
};
let namespace = args.get("namespace").and_then(|value| value.as_str());
let key = args.get("key").and_then(|value| value.as_str());
let payload = match action {
"list" => {
let limit = args
.get("limit")
.and_then(|value| value.as_u64())
.unwrap_or(20) as usize;
let memories = self
.store
.list_memories("user", &scope_key, namespace, limit)?;
json!({
"count": memories.len(),
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
})
}
"get" => {
let namespace = match namespace {
Some(namespace) => namespace,
None => return Ok(error_result("Missing required parameter: namespace")),
};
let key = match key {
Some(key) => key,
None => return Ok(error_result("Missing required parameter: key")),
};
match self.store.get_memory("user", &scope_key, namespace, key)? {
Some(memory) => memory_to_json(memory),
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
}
}
"put" => {
let input = match build_memory_upsert(context, &scope_key, &args, true) {
Ok(input) => input,
Err(result) => return Ok(result),
};
memory_to_json(self.store.put_memory(&input)?)
}
"update" => {
let input = match build_memory_upsert(context, &scope_key, &args, false) {
Ok(input) => input,
Err(result) => return Ok(result),
};
match self.store.update_memory(&input)? {
Some(memory) => memory_to_json(memory),
None => {
return Ok(error_result(&format!(
"memory '{}.{}' not found",
input.namespace, input.memory_key
)))
}
}
}
"delete" => {
let namespace = match namespace {
Some(namespace) => namespace,
None => return Ok(error_result("Missing required parameter: namespace")),
};
let key = match key {
Some(key) => key,
None => return Ok(error_result("Missing required parameter: key")),
};
let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?;
if !deleted {
return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key)));
}
json!({
"status": "deleted",
"namespace": namespace,
"key": key,
})
}
_ => return Ok(error_result("Unsupported action")),
};
Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&payload)?,
error: None,
})
}
}
fn build_memory_upsert(
context: &ToolContext,
scope_key: &str,
args: &serde_json::Value,
allow_put: bool,
) -> Result<MemoryUpsert, ToolResult> {
let namespace = match args.get("namespace").and_then(|value| value.as_str()) {
Some(namespace) => namespace,
None => return Err(error_result("Missing required parameter: namespace")),
};
let key = match args.get("key").and_then(|value| value.as_str()) {
Some(key) => key,
None => return Err(error_result("Missing required parameter: key")),
};
let content = match args.get("content").and_then(|value| value.as_str()) {
Some(content) => content,
None => return Err(error_result("Missing required parameter: content")),
};
let source_type = if context.message_id.is_some() {
"message"
} else if allow_put {
"manual"
} else {
"session"
};
Ok(MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: namespace.to_string(),
memory_key: key.to_string(),
content: content.to_string(),
source_type: source_type.to_string(),
source_session_id: context.session_id.clone(),
source_message_id: context.message_id.clone(),
source_message_seq: context.message_seq,
source_channel_name: context.channel_name.clone(),
source_chat_id: context.chat_id.clone(),
})
}
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
let channel_name = context
.channel_name
.as_deref()
.ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?;
let sender_id = context
.sender_id
.as_deref()
.ok_or_else(|| error_result("memory_manage requires sender_id in tool context"))?;
Ok(format!("{}:{}", channel_name, sender_id))
}
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
json!({
"id": memory.id,
"scope_kind": memory.scope_kind,
"scope_key": memory.scope_key,
"namespace": memory.namespace,
"key": memory.memory_key,
"content": memory.content,
"source_type": memory.source_type,
"source_session_id": memory.source_session_id,
"source_message_id": memory.source_message_id,
"source_message_seq": memory.source_message_seq,
"source_channel_name": memory.source_channel_name,
"source_chat_id": memory.source_chat_id,
"created_at": memory.created_at,
"updated_at": memory.updated_at,
})
}
fn error_result(message: &str) -> ToolResult {
ToolResult {
success: false,
output: String::new(),
error: Some(message.to_string()),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_memory_manage_put_and_get() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemoryManageTool::new(store);
let context = ToolContext {
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("chat-1".to_string()),
session_id: Some("feishu:chat-1".to_string()),
message_id: Some("msg-1".to_string()),
message_seq: Some(1),
};
let put = tool
.execute_with_context(
&context,
json!({
"action": "put",
"namespace": "profile",
"key": "language",
"content": "Rust"
}),
)
.await
.unwrap();
assert!(put.success);
let get = tool
.execute_with_context(
&context,
json!({
"action": "get",
"namespace": "profile",
"key": "language"
}),
)
.await
.unwrap();
assert!(get.success);
assert!(get.output.contains("Rust"));
assert!(get.output.contains("msg-1"));
}
#[tokio::test]
async fn test_memory_manage_requires_context() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemoryManageTool::new(store);
let result = tool
.execute_with_context(
&ToolContext::default(),
json!({
"action": "list"
}),
)
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("channel_name"));
}
}

View File

@ -4,6 +4,7 @@ pub mod file_edit;
pub mod file_read;
pub mod file_write;
pub mod http_request;
pub mod memory_manage;
pub mod registry;
pub mod schema;
pub mod skill_manage;
@ -16,8 +17,9 @@ pub use file_edit::FileEditTool;
pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
pub use http_request::HttpRequestTool;
pub use memory_manage::MemoryManageTool;
pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use skill_manage::{SkillListTool, SkillManageTool};
pub use traits::{Tool, ToolResult};
pub use traits::{Tool, ToolContext, ToolResult};
pub use web_fetch::WebFetchTool;

View File

@ -7,6 +7,16 @@ pub struct ToolResult {
pub error: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct ToolContext {
pub channel_name: Option<String>,
pub sender_id: Option<String>,
pub chat_id: Option<String>,
pub session_id: Option<String>,
pub message_id: Option<String>,
pub message_seq: Option<i64>,
}
#[async_trait]
pub trait Tool: Send + Sync + 'static {
fn name(&self) -> &str;
@ -14,6 +24,14 @@ pub trait Tool: Send + Sync + 'static {
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
async fn execute_with_context(
&self,
_context: &ToolContext,
args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
self.execute(args).await
}
/// Whether this tool is side-effect free and safe to parallelize.
fn read_only(&self) -> bool {
false

View File

@ -117,3 +117,55 @@ fn test_clear_history_with_session_id_serialization() {
assert!(json.contains(r#""type":"clear_history""#));
assert!(json.contains(r#""session_id":"session-1""#));
}
#[test]
fn test_tool_call_outbound_serialization() {
let msg = WsOutbound::ToolCall {
id: "msg-1".to_string(),
tool_call_id: "call-1".to_string(),
tool_name: "calculator".to_string(),
arguments: serde_json::json!({"expression": "1 + 1"}),
content: "调用工具: calculator".to_string(),
role: "assistant".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"tool_call""#));
assert!(json.contains(r#""tool_name":"calculator""#));
assert!(json.contains(r#""expression":"1 + 1""#));
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
match decoded {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
}
other => panic!("unexpected decoded variant: {:?}", other),
}
}
#[test]
fn test_tool_result_outbound_serialization() {
let msg = WsOutbound::ToolResult {
id: "msg-2".to_string(),
tool_call_id: "call-1".to_string(),
tool_name: "calculator".to_string(),
content: "工具结果: calculator\n\n2".to_string(),
role: "tool".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"tool_result""#));
assert!(json.contains(r#""tool_name":"calculator""#));
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
match decoded {
WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert!(content.contains('2'));
}
other => panic!("unexpected decoded variant: {:?}", other),
}
}