#
This commit is contained in:
parent
a0fe7c57bd
commit
d35e89a44c
@ -7,7 +7,7 @@ use crate::observability::{
|
||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::tools::ToolRegistry;
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
use std::collections::VecDeque;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::io::Read;
|
||||
@ -225,6 +225,7 @@ pub struct AgentLoop {
|
||||
skills: Arc<SkillRuntime>,
|
||||
skill_event_store: Option<Arc<SessionStore>>,
|
||||
skill_event_session_id: Option<String>,
|
||||
tool_context: ToolContext,
|
||||
observer: Option<Arc<dyn Observer>>,
|
||||
max_iterations: usize,
|
||||
}
|
||||
@ -247,6 +248,7 @@ impl AgentLoop {
|
||||
skills: Arc::new(SkillRuntime::default()),
|
||||
skill_event_store: None,
|
||||
skill_event_session_id: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
max_iterations,
|
||||
})
|
||||
@ -263,6 +265,7 @@ impl AgentLoop {
|
||||
skills: Arc::new(SkillRuntime::default()),
|
||||
skill_event_store: None,
|
||||
skill_event_session_id: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
max_iterations,
|
||||
})
|
||||
@ -283,6 +286,7 @@ impl AgentLoop {
|
||||
skills,
|
||||
skill_event_store: None,
|
||||
skill_event_session_id: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
max_iterations,
|
||||
})
|
||||
@ -294,6 +298,11 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tool_context(mut self, context: ToolContext) -> Self {
|
||||
self.tool_context = context;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an observer for tracking events.
|
||||
pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
||||
self.observer = Some(observer);
|
||||
@ -622,7 +631,7 @@ impl AgentLoop {
|
||||
}
|
||||
};
|
||||
|
||||
match tool.execute(tool_call.arguments.clone()).await {
|
||||
match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await {
|
||||
Ok(result) => {
|
||||
if result.success {
|
||||
ToolExecutionOutcome::success(result.output)
|
||||
|
||||
@ -191,12 +191,165 @@ pub struct OutboundMessage {
|
||||
pub reply_to: Option<String>,
|
||||
pub media: Vec<MediaItem>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
pub event_kind: OutboundEventKind,
|
||||
pub role: String,
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_name: Option<String>,
|
||||
pub tool_arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum OutboundEventKind {
|
||||
AssistantResponse,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
}
|
||||
|
||||
impl OutboundMessage {
|
||||
pub fn is_stream_delta(&self) -> bool {
|
||||
self.metadata.get("_stream_delta").is_some()
|
||||
}
|
||||
|
||||
pub fn assistant(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
reply_to: Option<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
content: content.into(),
|
||||
reply_to,
|
||||
media: Vec::new(),
|
||||
metadata,
|
||||
event_kind: OutboundEventKind::AssistantResponse,
|
||||
role: "assistant".to_string(),
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_arguments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_call(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
message_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
tool_arguments: serde_json::Value,
|
||||
reply_to: Option<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
let tool_name = tool_name.into();
|
||||
let content = format_tool_call_content(&tool_name, &tool_arguments);
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
content,
|
||||
reply_to,
|
||||
media: Vec::new(),
|
||||
metadata,
|
||||
event_kind: OutboundEventKind::ToolCall,
|
||||
role: "assistant".to_string(),
|
||||
tool_call_id: Some(message_id.into()),
|
||||
tool_name: Some(tool_name),
|
||||
tool_arguments: Some(tool_arguments),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_result(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
reply_to: Option<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
let tool_name = tool_name.into();
|
||||
let raw_content = content.into();
|
||||
let content = format_tool_result_content(&tool_name, &raw_content);
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
content,
|
||||
reply_to,
|
||||
media: Vec::new(),
|
||||
metadata,
|
||||
event_kind: OutboundEventKind::ToolResult,
|
||||
role: "tool".to_string(),
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
tool_name: Some(tool_name),
|
||||
tool_arguments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_chat_message(
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
reply_to: Option<String>,
|
||||
metadata: &HashMap<String, String>,
|
||||
message: &ChatMessage,
|
||||
) -> Vec<Self> {
|
||||
match message.role.as_str() {
|
||||
"assistant" => {
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| {
|
||||
Self::tool_call(
|
||||
channel.to_string(),
|
||||
chat_id.to_string(),
|
||||
tool_call.id.clone(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
reply_to.clone(),
|
||||
metadata.clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![Self::assistant(
|
||||
channel.to_string(),
|
||||
chat_id.to_string(),
|
||||
message.content.clone(),
|
||||
reply_to,
|
||||
metadata.clone(),
|
||||
)]
|
||||
}
|
||||
}
|
||||
"tool" => vec![Self::tool_result(
|
||||
channel.to_string(),
|
||||
chat_id.to_string(),
|
||||
message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
|
||||
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
|
||||
message.content.clone(),
|
||||
reply_to,
|
||||
metadata.clone(),
|
||||
)],
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
||||
format!(
|
||||
"调用工具: {}\n\n输入参数:\n{}",
|
||||
tool_name,
|
||||
format_json_value(tool_arguments),
|
||||
)
|
||||
}
|
||||
|
||||
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
|
||||
format!("工具结果: {}\n\n{}", tool_name, content)
|
||||
}
|
||||
|
||||
fn format_json_value(value: &serde_json::Value) -> String {
|
||||
match value {
|
||||
serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
|
||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@ -209,3 +362,63 @@ fn current_timestamp() -> i64 {
|
||||
.unwrap()
|
||||
.as_millis() as i64
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ChatMessage, OutboundEventKind, OutboundMessage};
|
||||
use crate::providers::ToolCall;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_from_chat_message_expands_tool_calls() {
|
||||
let message = ChatMessage::assistant_with_tool_calls(
|
||||
"",
|
||||
vec![
|
||||
ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "calculator".to_string(),
|
||||
arguments: json!({"expression": "1 + 1"}),
|
||||
},
|
||||
ToolCall {
|
||||
id: "call-2".to_string(),
|
||||
name: "file_read".to_string(),
|
||||
arguments: json!({"path": "README.md"}),
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
let outbound = OutboundMessage::from_chat_message(
|
||||
"feishu",
|
||||
"chat-1",
|
||||
None,
|
||||
&HashMap::new(),
|
||||
&message,
|
||||
);
|
||||
|
||||
assert_eq!(outbound.len(), 2);
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
||||
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
||||
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
|
||||
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_chat_message_maps_tool_result() {
|
||||
let message = ChatMessage::tool("call-9", "calculator", "2");
|
||||
|
||||
let outbound = OutboundMessage::from_chat_message(
|
||||
"feishu",
|
||||
"chat-1",
|
||||
None,
|
||||
&HashMap::new(),
|
||||
&message,
|
||||
);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
||||
assert_eq!(outbound[0].tool_call_id.as_deref(), Some("call-9"));
|
||||
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
||||
assert!(outbound[0].content.contains("工具结果: calculator"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,6 +40,10 @@ fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
||||
serde_json::from_str(raw)
|
||||
}
|
||||
|
||||
fn format_json(value: &serde_json::Value) -> String {
|
||||
serde_json::to_string_pretty(value).unwrap_or_else(|_| value.to_string())
|
||||
}
|
||||
|
||||
pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (ws_stream, _) = connect_async(gateway_url).await?;
|
||||
tracing::info!(url = %gateway_url, "Connected to gateway");
|
||||
@ -63,6 +67,12 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
WsOutbound::AssistantResponse { content, .. } => {
|
||||
input.write_response(&content).await?;
|
||||
}
|
||||
WsOutbound::ToolCall { tool_name, arguments, .. } => {
|
||||
input.write_output(&format!("Tool call: {}\n{}\n", tool_name, format_json(&arguments))).await?;
|
||||
}
|
||||
WsOutbound::ToolResult { tool_name, content, .. } => {
|
||||
input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?;
|
||||
}
|
||||
WsOutbound::Error { message, .. } => {
|
||||
input.write_output(&format!("Error: {}", message)).await?;
|
||||
}
|
||||
|
||||
@ -81,22 +81,17 @@ impl GatewayState {
|
||||
&inbound.content,
|
||||
inbound.media,
|
||||
).await {
|
||||
Ok(response_content) => {
|
||||
Ok(outbound_messages) => {
|
||||
// Forward channel-specific metadata from inbound to outbound.
|
||||
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
|
||||
// without gateway needing channel-specific code.
|
||||
let outbound = crate::bus::OutboundMessage {
|
||||
channel: inbound.channel.clone(),
|
||||
chat_id: inbound.chat_id.clone(),
|
||||
content: response_content,
|
||||
reply_to: None,
|
||||
media: vec![],
|
||||
metadata: inbound.forwarded_metadata,
|
||||
};
|
||||
for mut outbound in outbound_messages {
|
||||
outbound.metadata.extend(inbound.forwarded_metadata.clone());
|
||||
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
||||
tracing::error!(error = %e, "Failed to publish outbound");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to handle message");
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::{ChatMessage, OutboundMessage};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||
use crate::protocol::WsOutbound;
|
||||
@ -11,7 +11,8 @@ use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
HttpRequestTool, SkillListTool, SkillManageTool, ToolRegistry, WebFetchTool,
|
||||
HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry,
|
||||
WebFetchTool,
|
||||
};
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
@ -197,13 +198,30 @@ impl Session {
|
||||
}
|
||||
|
||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||
pub fn create_agent(&self, chat_id: &str) -> Result<AgentLoop, AgentError> {
|
||||
pub fn create_agent(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
sender_id: Option<&str>,
|
||||
message_id: Option<&str>,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
AgentLoop::with_tools_and_skills(
|
||||
self.provider_config.clone(),
|
||||
self.tools.clone(),
|
||||
self.skills.clone(),
|
||||
)
|
||||
.map(|agent| agent.with_skill_event_store(self.store.clone(), self.persistent_session_id(chat_id)))
|
||||
.map(|agent| {
|
||||
agent
|
||||
.with_skill_event_store(self.store.clone(), session_id.clone())
|
||||
.with_tool_context(ToolContext {
|
||||
channel_name: Some(self.channel_name.clone()),
|
||||
sender_id: sender_id.map(str::to_string),
|
||||
chat_id: Some(chat_id.to_string()),
|
||||
session_id: Some(session_id),
|
||||
message_id: message_id.map(str::to_string),
|
||||
message_seq: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -223,12 +241,13 @@ struct SessionManagerInner {
|
||||
session_ttl: Duration,
|
||||
}
|
||||
|
||||
fn default_tools(skills: Arc<SkillRuntime>) -> ToolRegistry {
|
||||
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(FileReadTool::new());
|
||||
registry.register(FileWriteTool::new());
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(MemoryManageTool::new(store));
|
||||
registry.register(SkillListTool::new(skills.clone()));
|
||||
registry.register(SkillManageTool::new(skills));
|
||||
registry.register(BashTool::new());
|
||||
@ -290,7 +309,7 @@ impl SessionManager {
|
||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||||
})),
|
||||
provider_config,
|
||||
tools: Arc::new(default_tools(skills.clone())),
|
||||
tools: Arc::new(default_tools(skills.clone(), store.clone())),
|
||||
skills,
|
||||
store,
|
||||
})
|
||||
@ -414,11 +433,11 @@ impl SessionManager {
|
||||
pub async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
_sender_id: &str,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
media: Vec<crate::bus::MediaItem>,
|
||||
) -> Result<String, AgentError> {
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
@ -453,7 +472,13 @@ impl SessionManager {
|
||||
session_guard.ensure_chat_loaded(chat_id)?;
|
||||
|
||||
if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? {
|
||||
return Ok(command_response);
|
||||
return Ok(vec![OutboundMessage::assistant(
|
||||
channel_name.to_string(),
|
||||
chat_id.to_string(),
|
||||
command_response,
|
||||
None,
|
||||
HashMap::new(),
|
||||
)]);
|
||||
}
|
||||
|
||||
// 添加用户消息到历史
|
||||
@ -463,6 +488,7 @@ impl SessionManager {
|
||||
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||||
}
|
||||
let user_message = session_guard.create_user_message(content, media_refs);
|
||||
let user_message_id = user_message.id.clone();
|
||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||
|
||||
// 获取完整历史
|
||||
@ -476,24 +502,36 @@ impl SessionManager {
|
||||
session_guard.record_skill_offer(chat_id)?;
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = session_guard.create_agent(chat_id)?;
|
||||
let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
|
||||
result.final_response
|
||||
result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.flat_map(|message| {
|
||||
OutboundMessage::from_chat_message(
|
||||
channel_name,
|
||||
chat_id,
|
||||
None,
|
||||
&HashMap::new(),
|
||||
message,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
response_len = response.content.len(),
|
||||
"Agent response received"
|
||||
outbound_count = response.len(),
|
||||
"Agent response sequence received"
|
||||
);
|
||||
|
||||
Ok(response.content)
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// 清除指定 session 的所有历史
|
||||
@ -541,7 +579,7 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone()));
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
|
||||
@ -4,6 +4,7 @@ use axum::extract::State;
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
||||
|
||||
@ -140,6 +141,52 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
|
||||
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
|
||||
}
|
||||
|
||||
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
match message.role.as_str() {
|
||||
"assistant" => {
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| WsOutbound::ToolCall {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
content: format!(
|
||||
"调用工具: {}\n\n输入参数:\n{}",
|
||||
tool_call.name,
|
||||
format_tool_arguments(&tool_call.arguments),
|
||||
),
|
||||
role: message.role.clone(),
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
id: message.id.clone(),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
}]
|
||||
}
|
||||
}
|
||||
"tool" => vec![WsOutbound::ToolResult {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
|
||||
tool_name: message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
|
||||
content: format!(
|
||||
"工具结果: {}\n\n{}",
|
||||
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
|
||||
message.content,
|
||||
),
|
||||
role: message.role.clone(),
|
||||
}],
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_inbound(
|
||||
state: &Arc<GatewayState>,
|
||||
session: &Arc<Mutex<Session>>,
|
||||
@ -166,6 +213,7 @@ async fn handle_inbound(
|
||||
}
|
||||
|
||||
let user_message = session_guard.create_user_message(&content, Vec::new());
|
||||
let user_message_id = user_message.id.clone();
|
||||
session_guard.append_persisted_message(&chat_id, user_message)?;
|
||||
|
||||
let raw_history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
@ -183,17 +231,17 @@ async fn handle_inbound(
|
||||
|
||||
session_guard.record_skill_offer(&chat_id)?;
|
||||
|
||||
let agent = session_guard.create_agent(&chat_id)?;
|
||||
let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
|
||||
match agent.process(history).await {
|
||||
Ok(result) => {
|
||||
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||
let _ = session_guard
|
||||
.send(WsOutbound::AssistantResponse {
|
||||
id: result.final_response.id,
|
||||
content: result.final_response.content,
|
||||
role: result.final_response.role,
|
||||
})
|
||||
.await;
|
||||
for outbound in result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.flat_map(ws_outbound_from_chat_message)
|
||||
{
|
||||
let _ = session_guard.send(outbound).await;
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!(chat_id = %chat_id, error = %error, "Agent process error");
|
||||
|
||||
@ -71,6 +71,23 @@ pub enum WsInbound {
|
||||
pub enum WsOutbound {
|
||||
#[serde(rename = "assistant_response")]
|
||||
AssistantResponse { id: String, content: String, role: String },
|
||||
#[serde(rename = "tool_call")]
|
||||
ToolCall {
|
||||
id: String,
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
arguments: serde_json::Value,
|
||||
content: String,
|
||||
role: String,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
id: String,
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
content: String,
|
||||
role: String,
|
||||
},
|
||||
#[serde(rename = "error")]
|
||||
Error { code: String, message: String },
|
||||
#[serde(rename = "session_established")]
|
||||
|
||||
@ -42,6 +42,39 @@ pub struct SessionRecord {
|
||||
pub reset_cutoff_seq: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryRecord {
|
||||
pub id: String,
|
||||
pub scope_kind: String,
|
||||
pub scope_key: String,
|
||||
pub namespace: String,
|
||||
pub memory_key: String,
|
||||
pub content: String,
|
||||
pub source_type: String,
|
||||
pub source_session_id: Option<String>,
|
||||
pub source_message_id: Option<String>,
|
||||
pub source_message_seq: Option<i64>,
|
||||
pub source_channel_name: Option<String>,
|
||||
pub source_chat_id: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryUpsert {
|
||||
pub scope_kind: String,
|
||||
pub scope_key: String,
|
||||
pub namespace: String,
|
||||
pub memory_key: String,
|
||||
pub content: String,
|
||||
pub source_type: String,
|
||||
pub source_session_id: Option<String>,
|
||||
pub source_message_id: Option<String>,
|
||||
pub source_message_seq: Option<i64>,
|
||||
pub source_channel_name: Option<String>,
|
||||
pub source_chat_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStore {
|
||||
conn: Arc<Mutex<Connection>>,
|
||||
@ -122,6 +155,56 @@ impl SessionStore {
|
||||
ON skill_events(session_id, created_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_skill_events_type_created
|
||||
ON skill_events(event_type, created_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
scope_kind TEXT NOT NULL,
|
||||
scope_key TEXT NOT NULL,
|
||||
namespace TEXT NOT NULL,
|
||||
memory_key TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
source_type TEXT NOT NULL,
|
||||
source_session_id TEXT,
|
||||
source_message_id TEXT,
|
||||
source_message_seq INTEGER,
|
||||
source_channel_name TEXT,
|
||||
source_chat_id TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
UNIQUE(scope_kind, scope_key, namespace, memory_key)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_scope_updated
|
||||
ON memories(scope_kind, scope_key, updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_scope_namespace_updated
|
||||
ON memories(scope_kind, scope_key, namespace, updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_source_session
|
||||
ON memories(source_session_id, updated_at DESC);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
|
||||
namespace,
|
||||
memory_key,
|
||||
content,
|
||||
content='memories',
|
||||
content_rowid='rowid'
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
|
||||
INSERT INTO memories_fts(rowid, namespace, memory_key, content)
|
||||
VALUES (new.rowid, new.namespace, new.memory_key, new.content);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
|
||||
INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content)
|
||||
VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
|
||||
INSERT INTO memories_fts(memories_fts, rowid, namespace, memory_key, content)
|
||||
VALUES ('delete', old.rowid, old.namespace, old.memory_key, old.content);
|
||||
INSERT INTO memories_fts(rowid, namespace, memory_key, content)
|
||||
VALUES (new.rowid, new.namespace, new.memory_key, new.content);
|
||||
END;
|
||||
",
|
||||
)?;
|
||||
|
||||
@ -417,6 +500,246 @@ impl SessionStore {
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
pub fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
let existing: Option<(String, i64)> = tx
|
||||
.query_row(
|
||||
"
|
||||
SELECT id, created_at
|
||||
FROM memories
|
||||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
|
||||
",
|
||||
params![
|
||||
input.scope_kind,
|
||||
input.scope_key,
|
||||
input.namespace,
|
||||
input.memory_key,
|
||||
],
|
||||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
let (id, created_at) = existing
|
||||
.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
|
||||
|
||||
tx.execute(
|
||||
"
|
||||
INSERT INTO memories (
|
||||
id, scope_kind, scope_key, namespace, memory_key, content,
|
||||
source_type, source_session_id, source_message_id, source_message_seq,
|
||||
source_channel_name, source_chat_id, created_at, updated_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
|
||||
ON CONFLICT(scope_kind, scope_key, namespace, memory_key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
source_type = excluded.source_type,
|
||||
source_session_id = excluded.source_session_id,
|
||||
source_message_id = excluded.source_message_id,
|
||||
source_message_seq = excluded.source_message_seq,
|
||||
source_channel_name = excluded.source_channel_name,
|
||||
source_chat_id = excluded.source_chat_id,
|
||||
updated_at = excluded.updated_at
|
||||
",
|
||||
params![
|
||||
id,
|
||||
input.scope_kind,
|
||||
input.scope_key,
|
||||
input.namespace,
|
||||
input.memory_key,
|
||||
input.content,
|
||||
input.source_type,
|
||||
input.source_session_id,
|
||||
input.source_message_id,
|
||||
input.source_message_seq,
|
||||
input.source_channel_name,
|
||||
input.source_chat_id,
|
||||
created_at,
|
||||
now,
|
||||
],
|
||||
)?;
|
||||
|
||||
tx.commit()?;
|
||||
drop(conn);
|
||||
|
||||
self.get_memory(
|
||||
&input.scope_kind,
|
||||
&input.scope_key,
|
||||
&input.namespace,
|
||||
&input.memory_key,
|
||||
)?
|
||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||
}
|
||||
|
||||
pub fn get_memory(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: &str,
|
||||
memory_key: &str,
|
||||
) -> Result<Option<MemoryRecord>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||||
source_type, source_session_id, source_message_id, source_message_seq,
|
||||
source_channel_name, source_chat_id, created_at, updated_at
|
||||
FROM memories
|
||||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
|
||||
",
|
||||
)?;
|
||||
|
||||
stmt.query_row(
|
||||
params![scope_kind, scope_key, namespace, memory_key],
|
||||
map_memory_record,
|
||||
)
|
||||
.optional()
|
||||
.map_err(StorageError::from)
|
||||
}
|
||||
|
||||
pub fn list_memories(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let limit = limit.max(1) as i64;
|
||||
let mut memories = Vec::new();
|
||||
|
||||
if let Some(namespace) = namespace {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||||
source_type, source_session_id, source_message_id, source_message_seq,
|
||||
source_channel_name, source_chat_id, created_at, updated_at
|
||||
FROM memories
|
||||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3
|
||||
ORDER BY updated_at DESC, created_at DESC
|
||||
LIMIT ?4
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, scope_kind, scope_key, namespace, memory_key, content,
|
||||
source_type, source_session_id, source_message_id, source_message_seq,
|
||||
source_channel_name, source_chat_id, created_at, updated_at
|
||||
FROM memories
|
||||
WHERE scope_kind = ?1 AND scope_key = ?2
|
||||
ORDER BY updated_at DESC, created_at DESC
|
||||
LIMIT ?3
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![scope_kind, scope_key, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(memories)
|
||||
}
|
||||
|
||||
pub fn update_memory(
|
||||
&self,
|
||||
input: &MemoryUpsert,
|
||||
) -> Result<Option<MemoryRecord>, StorageError> {
|
||||
if self
|
||||
.get_memory(
|
||||
&input.scope_kind,
|
||||
&input.scope_key,
|
||||
&input.namespace,
|
||||
&input.memory_key,
|
||||
)?
|
||||
.is_none()
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
self.put_memory(input).map(Some)
|
||||
}
|
||||
|
||||
pub fn delete_memory(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: &str,
|
||||
memory_key: &str,
|
||||
) -> Result<bool, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let changed = conn.execute(
|
||||
"
|
||||
DELETE FROM memories
|
||||
WHERE scope_kind = ?1 AND scope_key = ?2 AND namespace = ?3 AND memory_key = ?4
|
||||
",
|
||||
params![scope_kind, scope_key, namespace, memory_key],
|
||||
)?;
|
||||
Ok(changed > 0)
|
||||
}
|
||||
|
||||
pub fn search_memories(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
query: &str,
|
||||
namespace: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let limit = limit.max(1) as i64;
|
||||
let query = quote_fts_query(query);
|
||||
let mut memories = Vec::new();
|
||||
|
||||
if let Some(namespace) = namespace {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
|
||||
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
|
||||
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
|
||||
FROM memories_fts f
|
||||
JOIN memories m ON m.rowid = f.rowid
|
||||
WHERE memories_fts MATCH ?1
|
||||
AND m.scope_kind = ?2
|
||||
AND m.scope_key = ?3
|
||||
AND m.namespace = ?4
|
||||
ORDER BY bm25(memories_fts), m.updated_at DESC
|
||||
LIMIT ?5
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
|
||||
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
|
||||
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
|
||||
FROM memories_fts f
|
||||
JOIN memories m ON m.rowid = f.rowid
|
||||
WHERE memories_fts MATCH ?1
|
||||
AND m.scope_kind = ?2
|
||||
AND m.scope_key = ?3
|
||||
ORDER BY bm25(memories_fts), m.updated_at DESC
|
||||
LIMIT ?4
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(memories)
|
||||
}
|
||||
|
||||
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
|
||||
@ -479,6 +802,25 @@ fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEven
|
||||
})
|
||||
}
|
||||
|
||||
fn map_memory_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<MemoryRecord> {
|
||||
Ok(MemoryRecord {
|
||||
id: row.get(0)?,
|
||||
scope_kind: row.get(1)?,
|
||||
scope_key: row.get(2)?,
|
||||
namespace: row.get(3)?,
|
||||
memory_key: row.get(4)?,
|
||||
content: row.get(5)?,
|
||||
source_type: row.get(6)?,
|
||||
source_session_id: row.get(7)?,
|
||||
source_message_id: row.get(8)?,
|
||||
source_message_seq: row.get(9)?,
|
||||
source_channel_name: row.get(10)?,
|
||||
source_chat_id: row.get(11)?,
|
||||
created_at: row.get(12)?,
|
||||
updated_at: row.get(13)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
if !has_column(conn, "sessions", "reset_cutoff_seq")? {
|
||||
conn.execute(
|
||||
@ -580,6 +922,10 @@ fn current_timestamp() -> i64 {
|
||||
.as_millis() as i64
|
||||
}
|
||||
|
||||
fn quote_fts_query(query: &str) -> String {
|
||||
format!("\"{}\"", query.replace('"', "\"\""))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -797,4 +1143,101 @@ mod tests {
|
||||
assert_eq!(session_events[0].skill_name.as_deref(), Some("code-review"));
|
||||
assert_eq!(session_events[0].payload["source"], "project");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_roundtrip_with_source_fields() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
|
||||
let saved = store
|
||||
.put_memory(&MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: "feishu:user-1".to_string(),
|
||||
namespace: "profile".to_string(),
|
||||
memory_key: "language".to_string(),
|
||||
content: "Rust".to_string(),
|
||||
source_type: "message".to_string(),
|
||||
source_session_id: Some("feishu:chat-1".to_string()),
|
||||
source_message_id: Some("msg-1".to_string()),
|
||||
source_message_seq: Some(7),
|
||||
source_channel_name: Some("feishu".to_string()),
|
||||
source_chat_id: Some("chat-1".to_string()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(saved.content, "Rust");
|
||||
assert_eq!(saved.source_type, "message");
|
||||
assert_eq!(saved.source_session_id.as_deref(), Some("feishu:chat-1"));
|
||||
assert_eq!(saved.source_message_id.as_deref(), Some("msg-1"));
|
||||
assert_eq!(saved.source_message_seq, Some(7));
|
||||
|
||||
let fetched = store
|
||||
.get_memory("user", "feishu:user-1", "profile", "language")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(fetched.id, saved.id);
|
||||
assert_eq!(fetched.source_chat_id.as_deref(), Some("chat-1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_fts_tracks_upsert_and_delete() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
|
||||
store
|
||||
.put_memory(&MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: "feishu:user-1".to_string(),
|
||||
namespace: "preferences".to_string(),
|
||||
memory_key: "editor".to_string(),
|
||||
content: "Prefers rust-analyzer and cargo test output".to_string(),
|
||||
source_type: "message".to_string(),
|
||||
source_session_id: Some("feishu:chat-2".to_string()),
|
||||
source_message_id: Some("msg-2".to_string()),
|
||||
source_message_seq: Some(3),
|
||||
source_channel_name: Some("feishu".to_string()),
|
||||
source_chat_id: Some("chat-2".to_string()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let hits = store
|
||||
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
|
||||
.unwrap();
|
||||
assert_eq!(hits.len(), 1);
|
||||
assert_eq!(hits[0].memory_key, "editor");
|
||||
|
||||
store
|
||||
.put_memory(&MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: "feishu:user-1".to_string(),
|
||||
namespace: "preferences".to_string(),
|
||||
memory_key: "editor".to_string(),
|
||||
content: "Prefers clippy diagnostics".to_string(),
|
||||
source_type: "message".to_string(),
|
||||
source_session_id: Some("feishu:chat-3".to_string()),
|
||||
source_message_id: Some("msg-3".to_string()),
|
||||
source_message_seq: Some(4),
|
||||
source_channel_name: Some("feishu".to_string()),
|
||||
source_chat_id: Some("chat-3".to_string()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let old_hits = store
|
||||
.search_memories("user", "feishu:user-1", "rust-analyzer", None, 10)
|
||||
.unwrap();
|
||||
assert!(old_hits.is_empty());
|
||||
|
||||
let new_hits = store
|
||||
.search_memories("user", "feishu:user-1", "clippy", None, 10)
|
||||
.unwrap();
|
||||
assert_eq!(new_hits.len(), 1);
|
||||
|
||||
let deleted = store
|
||||
.delete_memory("user", "feishu:user-1", "preferences", "editor")
|
||||
.unwrap();
|
||||
assert!(deleted);
|
||||
|
||||
let hits_after_delete = store
|
||||
.search_memories("user", "feishu:user-1", "clippy", None, 10)
|
||||
.unwrap();
|
||||
assert!(hits_after_delete.is_empty());
|
||||
}
|
||||
}
|
||||
313
src/tools/memory_manage.rs
Normal file
313
src/tools/memory_manage.rs
Normal file
@ -0,0 +1,313 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::storage::{MemoryRecord, MemoryUpsert, SessionStore};
|
||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||||
|
||||
pub struct MemoryManageTool {
|
||||
store: Arc<SessionStore>,
|
||||
}
|
||||
|
||||
impl MemoryManageTool {
|
||||
pub fn new(store: Arc<SessionStore>) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryManageTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_manage"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Manage user memories stored in SQLite. Supports actions: list, get, put, update, delete. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["list", "get", "put", "update", "delete"],
|
||||
"description": "Management action to perform"
|
||||
},
|
||||
"namespace": {
|
||||
"type": "string",
|
||||
"description": "Memory namespace, such as profile, preferences, or tasks"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Memory key within the namespace"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Memory content for put/update"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of memories to list",
|
||||
"minimum": 1,
|
||||
"default": 20
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
Ok(error_result("memory_manage requires tool context"))
|
||||
}
|
||||
|
||||
async fn execute_with_context(
|
||||
&self,
|
||||
context: &ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|value| value.as_str()) {
|
||||
Some(action) => action,
|
||||
None => return Ok(error_result("Missing required parameter: action")),
|
||||
};
|
||||
|
||||
let scope_key = match scope_key_from_context(context) {
|
||||
Ok(scope_key) => scope_key,
|
||||
Err(result) => return Ok(result),
|
||||
};
|
||||
|
||||
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
||||
let key = args.get("key").and_then(|value| value.as_str());
|
||||
|
||||
let payload = match action {
|
||||
"list" => {
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(|value| value.as_u64())
|
||||
.unwrap_or(20) as usize;
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories("user", &scope_key, namespace, limit)?;
|
||||
json!({
|
||||
"count": memories.len(),
|
||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
||||
})
|
||||
}
|
||||
"get" => {
|
||||
let namespace = match namespace {
|
||||
Some(namespace) => namespace,
|
||||
None => return Ok(error_result("Missing required parameter: namespace")),
|
||||
};
|
||||
let key = match key {
|
||||
Some(key) => key,
|
||||
None => return Ok(error_result("Missing required parameter: key")),
|
||||
};
|
||||
|
||||
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
||||
Some(memory) => memory_to_json(memory),
|
||||
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
|
||||
}
|
||||
}
|
||||
"put" => {
|
||||
let input = match build_memory_upsert(context, &scope_key, &args, true) {
|
||||
Ok(input) => input,
|
||||
Err(result) => return Ok(result),
|
||||
};
|
||||
memory_to_json(self.store.put_memory(&input)?)
|
||||
}
|
||||
"update" => {
|
||||
let input = match build_memory_upsert(context, &scope_key, &args, false) {
|
||||
Ok(input) => input,
|
||||
Err(result) => return Ok(result),
|
||||
};
|
||||
|
||||
match self.store.update_memory(&input)? {
|
||||
Some(memory) => memory_to_json(memory),
|
||||
None => {
|
||||
return Ok(error_result(&format!(
|
||||
"memory '{}.{}' not found",
|
||||
input.namespace, input.memory_key
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
"delete" => {
|
||||
let namespace = match namespace {
|
||||
Some(namespace) => namespace,
|
||||
None => return Ok(error_result("Missing required parameter: namespace")),
|
||||
};
|
||||
let key = match key {
|
||||
Some(key) => key,
|
||||
None => return Ok(error_result("Missing required parameter: key")),
|
||||
};
|
||||
|
||||
let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?;
|
||||
if !deleted {
|
||||
return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key)));
|
||||
}
|
||||
|
||||
json!({
|
||||
"status": "deleted",
|
||||
"namespace": namespace,
|
||||
"key": key,
|
||||
})
|
||||
}
|
||||
_ => return Ok(error_result("Unsupported action")),
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&payload)?,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn build_memory_upsert(
|
||||
context: &ToolContext,
|
||||
scope_key: &str,
|
||||
args: &serde_json::Value,
|
||||
allow_put: bool,
|
||||
) -> Result<MemoryUpsert, ToolResult> {
|
||||
let namespace = match args.get("namespace").and_then(|value| value.as_str()) {
|
||||
Some(namespace) => namespace,
|
||||
None => return Err(error_result("Missing required parameter: namespace")),
|
||||
};
|
||||
let key = match args.get("key").and_then(|value| value.as_str()) {
|
||||
Some(key) => key,
|
||||
None => return Err(error_result("Missing required parameter: key")),
|
||||
};
|
||||
let content = match args.get("content").and_then(|value| value.as_str()) {
|
||||
Some(content) => content,
|
||||
None => return Err(error_result("Missing required parameter: content")),
|
||||
};
|
||||
|
||||
let source_type = if context.message_id.is_some() {
|
||||
"message"
|
||||
} else if allow_put {
|
||||
"manual"
|
||||
} else {
|
||||
"session"
|
||||
};
|
||||
|
||||
Ok(MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
memory_key: key.to_string(),
|
||||
content: content.to_string(),
|
||||
source_type: source_type.to_string(),
|
||||
source_session_id: context.session_id.clone(),
|
||||
source_message_id: context.message_id.clone(),
|
||||
source_message_seq: context.message_seq,
|
||||
source_channel_name: context.channel_name.clone(),
|
||||
source_chat_id: context.chat_id.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
||||
let channel_name = context
|
||||
.channel_name
|
||||
.as_deref()
|
||||
.ok_or_else(|| error_result("memory_manage requires channel_name in tool context"))?;
|
||||
let sender_id = context
|
||||
.sender_id
|
||||
.as_deref()
|
||||
.ok_or_else(|| error_result("memory_manage requires sender_id in tool context"))?;
|
||||
Ok(format!("{}:{}", channel_name, sender_id))
|
||||
}
|
||||
|
||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
||||
json!({
|
||||
"id": memory.id,
|
||||
"scope_kind": memory.scope_kind,
|
||||
"scope_key": memory.scope_key,
|
||||
"namespace": memory.namespace,
|
||||
"key": memory.memory_key,
|
||||
"content": memory.content,
|
||||
"source_type": memory.source_type,
|
||||
"source_session_id": memory.source_session_id,
|
||||
"source_message_id": memory.source_message_id,
|
||||
"source_message_seq": memory.source_message_seq,
|
||||
"source_channel_name": memory.source_channel_name,
|
||||
"source_chat_id": memory.source_chat_id,
|
||||
"created_at": memory.created_at,
|
||||
"updated_at": memory.updated_at,
|
||||
})
|
||||
}
|
||||
|
||||
fn error_result(message: &str) -> ToolResult {
|
||||
ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(message.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_manage_put_and_get() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = MemoryManageTool::new(store);
|
||||
let context = ToolContext {
|
||||
channel_name: Some("feishu".to_string()),
|
||||
sender_id: Some("user-1".to_string()),
|
||||
chat_id: Some("chat-1".to_string()),
|
||||
session_id: Some("feishu:chat-1".to_string()),
|
||||
message_id: Some("msg-1".to_string()),
|
||||
message_seq: Some(1),
|
||||
};
|
||||
|
||||
let put = tool
|
||||
.execute_with_context(
|
||||
&context,
|
||||
json!({
|
||||
"action": "put",
|
||||
"namespace": "profile",
|
||||
"key": "language",
|
||||
"content": "Rust"
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(put.success);
|
||||
|
||||
let get = tool
|
||||
.execute_with_context(
|
||||
&context,
|
||||
json!({
|
||||
"action": "get",
|
||||
"namespace": "profile",
|
||||
"key": "language"
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(get.success);
|
||||
assert!(get.output.contains("Rust"));
|
||||
assert!(get.output.contains("msg-1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_manage_requires_context() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = MemoryManageTool::new(store);
|
||||
|
||||
let result = tool
|
||||
.execute_with_context(
|
||||
&ToolContext::default(),
|
||||
json!({
|
||||
"action": "list"
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("channel_name"));
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod http_request;
|
||||
pub mod memory_manage;
|
||||
pub mod registry;
|
||||
pub mod schema;
|
||||
pub mod skill_manage;
|
||||
@ -16,8 +17,9 @@ pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use memory_manage::MemoryManageTool;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use skill_manage::{SkillListTool, SkillManageTool};
|
||||
pub use traits::{Tool, ToolResult};
|
||||
pub use traits::{Tool, ToolContext, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
@ -7,6 +7,16 @@ pub struct ToolResult {
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ToolContext {
|
||||
pub channel_name: Option<String>,
|
||||
pub sender_id: Option<String>,
|
||||
pub chat_id: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
pub message_seq: Option<i64>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync + 'static {
|
||||
fn name(&self) -> &str;
|
||||
@ -14,6 +24,14 @@ pub trait Tool: Send + Sync + 'static {
|
||||
fn parameters_schema(&self) -> serde_json::Value;
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
|
||||
|
||||
async fn execute_with_context(
|
||||
&self,
|
||||
_context: &ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
self.execute(args).await
|
||||
}
|
||||
|
||||
/// Whether this tool is side-effect free and safe to parallelize.
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
|
||||
@ -117,3 +117,55 @@ fn test_clear_history_with_session_id_serialization() {
|
||||
assert!(json.contains(r#""type":"clear_history""#));
|
||||
assert!(json.contains(r#""session_id":"session-1""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_outbound_serialization() {
|
||||
let msg = WsOutbound::ToolCall {
|
||||
id: "msg-1".to_string(),
|
||||
tool_call_id: "call-1".to_string(),
|
||||
tool_name: "calculator".to_string(),
|
||||
arguments: serde_json::json!({"expression": "1 + 1"}),
|
||||
content: "调用工具: calculator".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_call""#));
|
||||
assert!(json.contains(r#""tool_name":"calculator""#));
|
||||
assert!(json.contains(r#""expression":"1 + 1""#));
|
||||
|
||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||
match decoded {
|
||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
}
|
||||
other => panic!("unexpected decoded variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_result_outbound_serialization() {
|
||||
let msg = WsOutbound::ToolResult {
|
||||
id: "msg-2".to_string(),
|
||||
tool_call_id: "call-1".to_string(),
|
||||
tool_name: "calculator".to_string(),
|
||||
content: "工具结果: calculator\n\n2".to_string(),
|
||||
role: "tool".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_result""#));
|
||||
assert!(json.contains(r#""tool_name":"calculator""#));
|
||||
|
||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||
match decoded {
|
||||
WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert!(content.contains('2'));
|
||||
}
|
||||
other => panic!("unexpected decoded variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user