Compare commits
No commits in common. "65abf017a1ae4308e6fa263b148f730f2c8197d4" and "9cda2ab8d5a752d2fa9dafdf791c15f380c12554" have entirely different histories.
65abf017a1
...
9cda2ab8d5
@ -1,15 +1,15 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::bus::message::ToolMessageState;
|
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::observability::{
|
use crate::observability::{
|
||||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState,
|
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
||||||
};
|
};
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
use crate::tools::{ToolContext, ToolRegistry};
|
use crate::tools::{ToolContext, ToolRegistry};
|
||||||
|
use serde::Deserialize;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
@ -21,9 +21,12 @@ use std::time::Instant;
|
|||||||
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
|
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
|
||||||
/// Minimum characters to keep when truncating
|
/// Minimum characters to keep when truncating
|
||||||
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
||||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用长期记忆工具。读取记忆时,优先使用 memory_search:当你需要用户长期偏好、稳定事实、历史决策、持续任务上下文时,先 search;已知 namespace/key 时可用 get;需要浏览最近记忆时可用 list。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。";
|
const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions。若需要写入,优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。";
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
const MEMORY_EXTRACTION_SYSTEM_PROMPT: &str = "你负责从一小段对话中提取值得长期记忆的信息。只返回 JSON 数组,不要输出解释或 Markdown。数组元素格式为 {\"namespace\": string, \"key\": string, \"content\": string}。只有在内容属于长期偏好、稳定事实、用户纠正、持续任务上下文或明确决策时才输出;否则返回 []。namespace 只能使用 preferences、profile、tasks、decisions。key 用简短稳定的 snake_case 英文标识。不要保存一次性工具结果、临时列表、敏感信息或猜测。最多输出 2 条。";
|
||||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
const MEMORY_EXTRACTION_REASONING_EFFORT: &str = "none";
|
||||||
|
const MEMORY_EXTRACTION_MAX_TOKENS: u32 = 192;
|
||||||
|
const MEMORY_EXTRACTION_CONTEXT_MESSAGES: usize = 4;
|
||||||
|
const MEMORY_EXTRACTION_MAX_CANDIDATES: usize = 2;
|
||||||
|
|
||||||
/// Build content blocks from text and media paths
|
/// Build content blocks from text and media paths
|
||||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||||
@ -69,36 +72,28 @@ fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error
|
|||||||
/// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS.
|
/// Truncate tool result if it exceeds MAX_TOOL_RESULT_CHARS.
|
||||||
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
/// Preserves the end of the output as it often contains the conclusion/useful result.
|
||||||
fn truncate_tool_result(output: &str) -> String {
|
fn truncate_tool_result(output: &str) -> String {
|
||||||
let char_count = output.chars().count();
|
if output.len() <= MAX_TOOL_RESULT_CHARS {
|
||||||
if char_count <= MAX_TOOL_RESULT_CHARS {
|
|
||||||
return output.to_string();
|
return output.to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
let truncated_start_len = char_count.saturating_sub(TRUNCATION_SUFFIX_LEN);
|
let truncated_start_len = output.len().saturating_sub(TRUNCATION_SUFFIX_LEN);
|
||||||
if truncated_start_len > MAX_TOOL_RESULT_CHARS {
|
if truncated_start_len > MAX_TOOL_RESULT_CHARS {
|
||||||
// Even after removing suffix, still too long - take from beginning
|
// Even after removing suffix, still too long - take from beginning
|
||||||
let head_len = MAX_TOOL_RESULT_CHARS - 100;
|
|
||||||
let head: String = output.chars().take(head_len).collect();
|
|
||||||
format!(
|
format!(
|
||||||
"{}...\n\n[Output truncated - {} characters removed]",
|
"{}...\n\n[Output truncated - {} characters removed]",
|
||||||
head,
|
&output[..MAX_TOOL_RESULT_CHARS - 100],
|
||||||
char_count - MAX_TOOL_RESULT_CHARS + 100
|
output.len() - MAX_TOOL_RESULT_CHARS + 100
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// Keep most of the end which usually contains the useful result
|
// Keep most of the end which usually contains the useful result
|
||||||
let tail: String = output.chars().skip(truncated_start_len).collect();
|
|
||||||
format!(
|
format!(
|
||||||
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
||||||
truncated_start_len,
|
truncated_start_len,
|
||||||
tail
|
&output[truncated_start_len..]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_pending_tool_output(output: &str) -> Option<String> {
|
|
||||||
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Loop detection result.
|
/// Loop detection result.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
enum LoopDetectionResult {
|
enum LoopDetectionResult {
|
||||||
@ -251,6 +246,19 @@ pub struct AgentProcessResult {
|
|||||||
pub emitted_messages: Vec<ChatMessage>,
|
pub emitted_messages: Vec<ChatMessage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
|
||||||
|
struct MemoryCandidate {
|
||||||
|
namespace: String,
|
||||||
|
key: String,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MemoryCandidateEnvelope {
|
||||||
|
#[serde(default)]
|
||||||
|
memories: Vec<MemoryCandidate>,
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait EmittedMessageHandler: Send + Sync + 'static {
|
pub trait EmittedMessageHandler: Send + Sync + 'static {
|
||||||
async fn handle(&self, message: ChatMessage);
|
async fn handle(&self, message: ChatMessage);
|
||||||
@ -358,6 +366,7 @@ impl AgentLoop {
|
|||||||
// Track tool calls for loop detection
|
// Track tool calls for loop detection
|
||||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||||
let mut emitted_messages = Vec::new();
|
let mut emitted_messages = Vec::new();
|
||||||
|
let mut memory_write_occurred = false;
|
||||||
|
|
||||||
for iteration in 0..self.max_iterations {
|
for iteration in 0..self.max_iterations {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -368,7 +377,7 @@ impl AgentLoop {
|
|||||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||||
messages_for_llm.push(Message::system(skill_prompt));
|
messages_for_llm.push(Message::system(skill_prompt));
|
||||||
}
|
}
|
||||||
messages_for_llm.push(Message::system(MEMORY_TOOL_USAGE_SYSTEM_PROMPT));
|
messages_for_llm.push(Message::system(MEMORY_AUTOSAVE_SYSTEM_PROMPT));
|
||||||
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
||||||
|
|
||||||
// Build request
|
// Build request
|
||||||
@ -403,6 +412,9 @@ impl AgentLoop {
|
|||||||
// If no tool calls, this is the final response
|
// If no tool calls, this is the final response
|
||||||
if response.tool_calls.is_empty() {
|
if response.tool_calls.is_empty() {
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
|
if !memory_write_occurred {
|
||||||
|
self.maybe_extract_and_store_memories(&messages, &assistant_message).await?;
|
||||||
|
}
|
||||||
emitted_messages.push(assistant_message.clone());
|
emitted_messages.push(assistant_message.clone());
|
||||||
return Ok(AgentProcessResult {
|
return Ok(AgentProcessResult {
|
||||||
final_response: assistant_message,
|
final_response: assistant_message,
|
||||||
@ -424,6 +436,14 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// Execute tools and add results to messages
|
// Execute tools and add results to messages
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
|
if response
|
||||||
|
.tool_calls
|
||||||
|
.iter()
|
||||||
|
.zip(tool_results.iter())
|
||||||
|
.any(|(tool_call, result)| did_successfully_write_memory(tool_call, result))
|
||||||
|
{
|
||||||
|
memory_write_occurred = true;
|
||||||
|
}
|
||||||
|
|
||||||
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
|
||||||
// Log function call with name and arguments
|
// Log function call with name and arguments
|
||||||
@ -447,29 +467,19 @@ impl AgentLoop {
|
|||||||
"Loop warning: {}",
|
"Loop warning: {}",
|
||||||
msg
|
msg
|
||||||
);
|
);
|
||||||
let tool_message = ChatMessage::tool_with_state(
|
let tool_message = ChatMessage::tool(
|
||||||
tool_call.id.clone(),
|
tool_call.id.clone(),
|
||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
||||||
if result.state == ToolExecutionState::PendingUserAction {
|
|
||||||
ToolMessageState::PendingUserAction
|
|
||||||
} else {
|
|
||||||
ToolMessageState::Completed
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
messages.push(tool_message.clone());
|
messages.push(tool_message.clone());
|
||||||
emitted_messages.push(tool_message);
|
emitted_messages.push(tool_message);
|
||||||
}
|
}
|
||||||
LoopDetectionResult::Ok => {
|
LoopDetectionResult::Ok => {
|
||||||
let tool_message = ChatMessage::tool_with_state(
|
let tool_message = ChatMessage::tool(
|
||||||
tool_call.id.clone(),
|
tool_call.id.clone(),
|
||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
truncated_output,
|
truncated_output,
|
||||||
if result.state == ToolExecutionState::PendingUserAction {
|
|
||||||
ToolMessageState::PendingUserAction
|
|
||||||
} else {
|
|
||||||
ToolMessageState::Completed
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
messages.push(tool_message.clone());
|
messages.push(tool_message.clone());
|
||||||
emitted_messages.push(tool_message);
|
emitted_messages.push(tool_message);
|
||||||
@ -477,29 +487,6 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some((tool_call, pending_result)) = response
|
|
||||||
.tool_calls
|
|
||||||
.iter()
|
|
||||||
.zip(tool_results.iter())
|
|
||||||
.find(|(_, result)| result.state == ToolExecutionState::PendingUserAction)
|
|
||||||
{
|
|
||||||
let assistant_message = ChatMessage::assistant(format!(
|
|
||||||
"{}\n\n当前等待中的工具: {}",
|
|
||||||
pending_result
|
|
||||||
.output
|
|
||||||
.lines()
|
|
||||||
.next()
|
|
||||||
.filter(|line| !line.trim().is_empty())
|
|
||||||
.unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE),
|
|
||||||
tool_call.name,
|
|
||||||
));
|
|
||||||
emitted_messages.push(assistant_message.clone());
|
|
||||||
return Ok(AgentProcessResult {
|
|
||||||
final_response: assistant_message,
|
|
||||||
emitted_messages,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop continues to next iteration with updated messages
|
// Loop continues to next iteration with updated messages
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
||||||
@ -520,6 +507,7 @@ impl AgentLoop {
|
|||||||
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
if let Some(skill_prompt) = self.skills.system_index_prompt() {
|
||||||
messages_for_llm.push(Message::system(skill_prompt));
|
messages_for_llm.push(Message::system(skill_prompt));
|
||||||
}
|
}
|
||||||
|
messages_for_llm.push(Message::system(MEMORY_AUTOSAVE_SYSTEM_PROMPT));
|
||||||
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
@ -532,6 +520,9 @@ impl AgentLoop {
|
|||||||
match (*self.provider).chat(request).await {
|
match (*self.provider).chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
|
if !memory_write_occurred {
|
||||||
|
self.maybe_extract_and_store_memories(&messages, &assistant_message).await?;
|
||||||
|
}
|
||||||
emitted_messages.push(assistant_message.clone());
|
emitted_messages.push(assistant_message.clone());
|
||||||
Ok(AgentProcessResult {
|
Ok(AgentProcessResult {
|
||||||
final_response: assistant_message,
|
final_response: assistant_message,
|
||||||
@ -710,11 +701,7 @@ impl AgentLoop {
|
|||||||
match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await {
|
match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
if result.success {
|
if result.success {
|
||||||
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
|
|
||||||
ToolExecutionOutcome::pending(pending_output)
|
|
||||||
} else {
|
|
||||||
ToolExecutionOutcome::success(result.output)
|
ToolExecutionOutcome::success(result.output)
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
let error = result.error.unwrap_or_default();
|
let error = result.error.unwrap_or_default();
|
||||||
ToolExecutionOutcome::failure(
|
ToolExecutionOutcome::failure(
|
||||||
@ -752,6 +739,227 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn did_successfully_write_memory(tool_call: &ToolCall, result: &ToolExecutionOutcome) -> bool {
|
||||||
|
if !result.success || tool_call.name != "memory_manage" {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
matches!(
|
||||||
|
tool_call.arguments.get("action").and_then(|value| value.as_str()),
|
||||||
|
Some("put") | Some("update")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lightweight_provider_config(provider_config: &LLMProviderConfig) -> LLMProviderConfig {
|
||||||
|
let mut config = provider_config.clone();
|
||||||
|
if config.provider_type == "openai" {
|
||||||
|
config.model_extra.insert(
|
||||||
|
"reasoning_effort".to_string(),
|
||||||
|
serde_json::Value::String(MEMORY_EXTRACTION_REASONING_EFFORT.to_string()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_memory_extraction_context(messages: &[ChatMessage], final_response: &ChatMessage) -> String {
|
||||||
|
let recent_messages: Vec<&ChatMessage> = messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.filter(|message| matches!(message.role.as_str(), "user" | "assistant"))
|
||||||
|
.take(MEMORY_EXTRACTION_CONTEXT_MESSAGES)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut lines = vec!["请判断这轮对话中是否有值得长期记忆的信息:".to_string()];
|
||||||
|
for message in recent_messages {
|
||||||
|
let content = message.content.trim();
|
||||||
|
if content.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.push(format!(
|
||||||
|
"{}: {}",
|
||||||
|
message.role,
|
||||||
|
truncate_text_for_memory_extraction(content)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_content = final_response.content.trim();
|
||||||
|
if !final_content.is_empty() {
|
||||||
|
lines.push(format!(
|
||||||
|
"assistant_final: {}",
|
||||||
|
truncate_text_for_memory_extraction(final_content)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_text_for_memory_extraction(text: &str) -> String {
|
||||||
|
if text.chars().count() <= 240 {
|
||||||
|
return text.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let prefix: String = text.chars().take(240).collect();
|
||||||
|
format!("{}...", prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_memory_candidates(raw: &str) -> Vec<MemoryCandidate> {
|
||||||
|
let trimmed = raw.trim();
|
||||||
|
let json_payload = trimmed
|
||||||
|
.strip_prefix("```json")
|
||||||
|
.and_then(|value| value.strip_suffix("```"))
|
||||||
|
.map(str::trim)
|
||||||
|
.or_else(|| {
|
||||||
|
trimmed
|
||||||
|
.strip_prefix("```")
|
||||||
|
.and_then(|value| value.strip_suffix("```"))
|
||||||
|
.map(str::trim)
|
||||||
|
})
|
||||||
|
.unwrap_or(trimmed);
|
||||||
|
|
||||||
|
if let Ok(candidates) = serde_json::from_str::<Vec<MemoryCandidate>>(json_payload) {
|
||||||
|
return normalize_memory_candidates(candidates);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(envelope) = serde_json::from_str::<MemoryCandidateEnvelope>(json_payload) {
|
||||||
|
return normalize_memory_candidates(envelope.memories);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_memory_candidates(candidates: Vec<MemoryCandidate>) -> Vec<MemoryCandidate> {
|
||||||
|
let mut normalized = Vec::new();
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
let Some(namespace) = canonical_memory_namespace(&candidate.namespace) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let key = canonical_memory_key(&candidate.key);
|
||||||
|
let content = candidate.content.trim().to_string();
|
||||||
|
if key.is_empty() || content.len() < 8 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized.push(MemoryCandidate {
|
||||||
|
namespace,
|
||||||
|
key,
|
||||||
|
content,
|
||||||
|
});
|
||||||
|
|
||||||
|
if normalized.len() >= MEMORY_EXTRACTION_MAX_CANDIDATES {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
fn canonical_memory_namespace(namespace: &str) -> Option<String> {
|
||||||
|
match namespace.trim().to_lowercase().as_str() {
|
||||||
|
"preferences" | "preference" => Some("preferences".to_string()),
|
||||||
|
"profile" | "profiles" => Some("profile".to_string()),
|
||||||
|
"tasks" | "task" => Some("tasks".to_string()),
|
||||||
|
"decisions" | "decision" => Some("decisions".to_string()),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn canonical_memory_key(key: &str) -> String {
|
||||||
|
let mut normalized = String::new();
|
||||||
|
let mut prev_is_sep = false;
|
||||||
|
|
||||||
|
for ch in key.trim().chars() {
|
||||||
|
let mapped = if ch.is_ascii_alphanumeric() {
|
||||||
|
prev_is_sep = false;
|
||||||
|
Some(ch.to_ascii_lowercase())
|
||||||
|
} else if matches!(ch, '_' | '-' | '/' | ' ') {
|
||||||
|
if prev_is_sep {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
prev_is_sep = true;
|
||||||
|
Some('_')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(ch) = mapped {
|
||||||
|
normalized.push(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized.trim_matches('_').to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentLoop {
|
||||||
|
async fn maybe_extract_and_store_memories(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
final_response: &ChatMessage,
|
||||||
|
) -> Result<usize, AgentError> {
|
||||||
|
let Some(tool) = self.tools.get("memory_manage") else {
|
||||||
|
return Ok(0);
|
||||||
|
};
|
||||||
|
if self.tool_context.channel_name.is_none() || self.tool_context.sender_id.is_none() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let context = build_memory_extraction_context(messages, final_response);
|
||||||
|
let provider = create_provider(lightweight_provider_config(&self.provider_config))
|
||||||
|
.map_err(|err| AgentError::ProviderCreation(err.to_string()))?;
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message::system(MEMORY_EXTRACTION_SYSTEM_PROMPT),
|
||||||
|
Message::user(context),
|
||||||
|
],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(MEMORY_EXTRACTION_MAX_TOKENS),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider
|
||||||
|
.chat(request)
|
||||||
|
.await
|
||||||
|
.map_err(|err| AgentError::LlmError(err.to_string()))?;
|
||||||
|
let candidates = parse_memory_candidates(&response.content);
|
||||||
|
if candidates.is_empty() {
|
||||||
|
tracing::debug!("No auto-save memory candidates extracted from final response");
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(candidate_count = candidates.len(), candidates = ?candidates, "Auto-save memory candidates extracted");
|
||||||
|
|
||||||
|
let mut saved = 0;
|
||||||
|
for candidate in candidates {
|
||||||
|
let result = tool
|
||||||
|
.execute_with_context(
|
||||||
|
&self.tool_context,
|
||||||
|
serde_json::json!({
|
||||||
|
"action": "put",
|
||||||
|
"namespace": candidate.namespace,
|
||||||
|
"key": candidate.key,
|
||||||
|
"content": candidate.content,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|err| AgentError::Other(format!("auto-save memory tool error: {}", err)))?;
|
||||||
|
|
||||||
|
if result.success {
|
||||||
|
saved += 1;
|
||||||
|
} else {
|
||||||
|
tracing::warn!(error = result.error.as_deref().unwrap_or("unknown"), "Auto-save memory write failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(saved_count = saved, "Auto-save memory extraction completed");
|
||||||
|
Ok(saved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -829,20 +1037,70 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_truncate_tool_result_handles_utf8_char_boundaries() {
|
fn test_did_successfully_write_memory_only_accepts_successful_put_or_update() {
|
||||||
let input = "范".repeat(MAX_TOOL_RESULT_CHARS + 500);
|
let tool_call = ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "memory_manage".to_string(),
|
||||||
|
arguments: serde_json::json!({ "action": "put" }),
|
||||||
|
};
|
||||||
|
|
||||||
let output = truncate_tool_result(&input);
|
assert!(did_successfully_write_memory(
|
||||||
|
&tool_call,
|
||||||
|
&ToolExecutionOutcome::success("ok".to_string())
|
||||||
|
));
|
||||||
|
|
||||||
assert!(output.contains("Output truncated"));
|
let failed = ToolExecutionOutcome::failure("err".to_string(), Some("boom".to_string()));
|
||||||
assert!(output.is_char_boundary(output.len()));
|
assert!(!did_successfully_write_memory(&tool_call, &failed));
|
||||||
|
|
||||||
|
let search_call = ToolCall {
|
||||||
|
id: "call_2".to_string(),
|
||||||
|
name: "memory_manage".to_string(),
|
||||||
|
arguments: serde_json::json!({ "action": "search" }),
|
||||||
|
};
|
||||||
|
assert!(!did_successfully_write_memory(
|
||||||
|
&search_call,
|
||||||
|
&ToolExecutionOutcome::success("ok".to_string())
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_pending_tool_output() {
|
fn test_parse_memory_candidates_normalizes_and_limits_results() {
|
||||||
let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权");
|
let raw = r#"```json
|
||||||
assert_eq!(output.as_deref(), Some("请完成授权"));
|
{"memories": [
|
||||||
assert!(parse_pending_tool_output("normal output").is_none());
|
{"namespace": "preference", "key": "Email Folder Preference", "content": "用户提到邮件时默认查看代收邮箱而不是收件箱。"},
|
||||||
|
{"namespace": "decision", "key": "mailbox strategy", "content": "后续默认先查看代收邮箱。"},
|
||||||
|
{"namespace": "tasks", "key": "short", "content": "太短"}
|
||||||
|
]}
|
||||||
|
```"#;
|
||||||
|
|
||||||
|
let candidates = parse_memory_candidates(raw);
|
||||||
|
assert_eq!(candidates.len(), 2);
|
||||||
|
assert_eq!(candidates[0].namespace, "preferences");
|
||||||
|
assert_eq!(candidates[0].key, "email_folder_preference");
|
||||||
|
assert_eq!(candidates[1].namespace, "decisions");
|
||||||
|
assert_eq!(candidates[1].key, "mailbox_strategy");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_memory_extraction_context_uses_recent_user_and_assistant_messages() {
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::system("system"),
|
||||||
|
ChatMessage::tool("call-1", "calculator", "2"),
|
||||||
|
ChatMessage::user("first user"),
|
||||||
|
ChatMessage::assistant("first assistant"),
|
||||||
|
ChatMessage::user("second user"),
|
||||||
|
ChatMessage::assistant("second assistant"),
|
||||||
|
ChatMessage::user("third user"),
|
||||||
|
];
|
||||||
|
let final_response = ChatMessage::assistant("final assistant");
|
||||||
|
|
||||||
|
let context = build_memory_extraction_context(&messages, &final_response);
|
||||||
|
assert!(context.contains("user: first user") == false);
|
||||||
|
assert!(context.contains("user: second user"));
|
||||||
|
assert!(context.contains("assistant: second assistant"));
|
||||||
|
assert!(context.contains("user: third user"));
|
||||||
|
assert!(context.contains("assistant_final: final assistant"));
|
||||||
|
assert!(!context.contains("tool:"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,13 +3,6 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ToolMessageState {
|
|
||||||
Completed,
|
|
||||||
PendingUserAction,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// ContentBlock - Multimodal content representation (OpenAI-style)
|
// ContentBlock - Multimodal content representation (OpenAI-style)
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@ -79,8 +72,6 @@ pub struct ChatMessage {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_name: Option<String>,
|
pub tool_name: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_state: Option<ToolMessageState>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_calls: Option<Vec<ToolCall>>,
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,7 +85,6 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_state: None,
|
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,7 +98,6 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_state: None,
|
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -122,7 +111,6 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_state: None,
|
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,7 +124,6 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_state: None,
|
|
||||||
tool_calls: Some(tool_calls),
|
tool_calls: Some(tool_calls),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -150,21 +137,11 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_state: None,
|
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tool_with_state(
|
|
||||||
tool_call_id: impl Into<String>,
|
|
||||||
tool_name: impl Into<String>,
|
|
||||||
content: impl Into<String>,
|
|
||||||
tool_state: ToolMessageState,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
@ -173,7 +150,6 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
tool_name: Some(tool_name.into()),
|
tool_name: Some(tool_name.into()),
|
||||||
tool_state: Some(tool_state),
|
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -236,7 +212,6 @@ pub enum OutboundEventKind {
|
|||||||
AssistantResponse,
|
AssistantResponse,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolResult,
|
ToolResult,
|
||||||
ToolPending,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OutboundMessage {
|
impl OutboundMessage {
|
||||||
@ -319,33 +294,6 @@ impl OutboundMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_pending(
|
|
||||||
channel: impl Into<String>,
|
|
||||||
chat_id: impl Into<String>,
|
|
||||||
tool_call_id: impl Into<String>,
|
|
||||||
tool_name: impl Into<String>,
|
|
||||||
content: impl Into<String>,
|
|
||||||
reply_to: Option<String>,
|
|
||||||
metadata: HashMap<String, String>,
|
|
||||||
) -> Self {
|
|
||||||
let tool_name = tool_name.into();
|
|
||||||
let raw_content = content.into();
|
|
||||||
let content = format_tool_result_content(&tool_name, &raw_content);
|
|
||||||
Self {
|
|
||||||
channel: channel.into(),
|
|
||||||
chat_id: chat_id.into(),
|
|
||||||
content,
|
|
||||||
reply_to,
|
|
||||||
media: Vec::new(),
|
|
||||||
metadata,
|
|
||||||
event_kind: OutboundEventKind::ToolPending,
|
|
||||||
role: "tool".to_string(),
|
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
|
||||||
tool_name: Some(tool_name),
|
|
||||||
tool_arguments: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_chat_message(
|
pub fn from_chat_message(
|
||||||
channel: &str,
|
channel: &str,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
@ -380,26 +328,7 @@ impl OutboundMessage {
|
|||||||
)]
|
)]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
|
"tool" => Vec::new(),
|
||||||
ToolMessageState::Completed => vec![Self::tool_result(
|
|
||||||
channel.to_string(),
|
|
||||||
chat_id.to_string(),
|
|
||||||
message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
message.tool_name.clone().unwrap_or_default(),
|
|
||||||
message.content.clone(),
|
|
||||||
reply_to,
|
|
||||||
metadata.clone(),
|
|
||||||
)],
|
|
||||||
ToolMessageState::PendingUserAction => vec![Self::tool_pending(
|
|
||||||
channel.to_string(),
|
|
||||||
chat_id.to_string(),
|
|
||||||
message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
message.tool_name.clone().unwrap_or_default(),
|
|
||||||
message.content.clone(),
|
|
||||||
reply_to,
|
|
||||||
metadata.clone(),
|
|
||||||
)],
|
|
||||||
},
|
|
||||||
_ => Vec::new(),
|
_ => Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -448,7 +377,7 @@ fn current_timestamp() -> i64 {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState};
|
use super::{ChatMessage, OutboundEventKind, OutboundMessage};
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -489,7 +418,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_from_chat_message_includes_tool_result() {
|
fn test_from_chat_message_omits_tool_result() {
|
||||||
let message = ChatMessage::tool("call-9", "calculator", "2");
|
let message = ChatMessage::tool("call-9", "calculator", "2");
|
||||||
|
|
||||||
let outbound = OutboundMessage::from_chat_message(
|
let outbound = OutboundMessage::from_chat_message(
|
||||||
@ -500,28 +429,6 @@ mod tests {
|
|||||||
&message,
|
&message,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert!(outbound.is_empty());
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_from_chat_message_includes_tool_pending() {
|
|
||||||
let message = ChatMessage::tool_with_state(
|
|
||||||
"call-9",
|
|
||||||
"bash",
|
|
||||||
"等待你完成浏览器授权后再继续。",
|
|
||||||
ToolMessageState::PendingUserAction,
|
|
||||||
);
|
|
||||||
|
|
||||||
let outbound = OutboundMessage::from_chat_message(
|
|
||||||
"feishu",
|
|
||||||
"chat-1",
|
|
||||||
None,
|
|
||||||
&HashMap::new(),
|
|
||||||
&message,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -73,9 +73,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
WsOutbound::ToolResult { tool_name, content, .. } => {
|
WsOutbound::ToolResult { tool_name, content, .. } => {
|
||||||
input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?;
|
input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?;
|
||||||
}
|
}
|
||||||
WsOutbound::ToolPending { tool_name, content, resume_hint, .. } => {
|
|
||||||
input.write_output(&format!("Tool pending: {}\n{}\n{}\n", tool_name, content, resume_hint)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::Error { message, .. } => {
|
WsOutbound::Error { message, .. } => {
|
||||||
input.write_output(&format!("Error: {}", message)).await?;
|
input.write_output(&format!("Error: {}", message)).await?;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -134,8 +134,6 @@ pub struct GatewayConfig {
|
|||||||
pub host: String,
|
pub host: String,
|
||||||
#[serde(default = "default_gateway_port")]
|
#[serde(default = "default_gateway_port")]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
#[serde(default)]
|
|
||||||
pub show_tool_results: bool,
|
|
||||||
#[serde(default, rename = "session_ttl_hours")]
|
#[serde(default, rename = "session_ttl_hours")]
|
||||||
pub session_ttl_hours: Option<u64>,
|
pub session_ttl_hours: Option<u64>,
|
||||||
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
||||||
@ -169,7 +167,6 @@ impl Default for GatewayConfig {
|
|||||||
Self {
|
Self {
|
||||||
host: default_gateway_host(),
|
host: default_gateway_host(),
|
||||||
port: default_gateway_port(),
|
port: default_gateway_port(),
|
||||||
show_tool_results: false,
|
|
||||||
session_ttl_hours: None,
|
session_ttl_hours: None,
|
||||||
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
agent_prompt_reinject_every: default_agent_prompt_reinject_every(),
|
||||||
}
|
}
|
||||||
@ -398,7 +395,6 @@ mod tests {
|
|||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
assert_eq!(config.gateway.host, "0.0.0.0");
|
assert_eq!(config.gateway.host, "0.0.0.0");
|
||||||
assert_eq!(config.gateway.port, 19876);
|
assert_eq!(config.gateway.port, 19876);
|
||||||
assert!(!config.gateway.show_tool_results);
|
|
||||||
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
|
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -432,43 +428,6 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
assert!(!config.gateway.show_tool_results);
|
|
||||||
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
|
assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_gateway_config_can_enable_tool_results() {
|
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
|
||||||
std::fs::write(
|
|
||||||
file.path(),
|
|
||||||
r#"{
|
|
||||||
"providers": {
|
|
||||||
"aliyun": {
|
|
||||||
"type": "openai",
|
|
||||||
"base_url": "https://example.invalid/v1",
|
|
||||||
"api_key": "test-key",
|
|
||||||
"extra_headers": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"models": {
|
|
||||||
"qwen-plus": {
|
|
||||||
"model_id": "qwen-plus"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"agents": {
|
|
||||||
"default": {
|
|
||||||
"provider": "aliyun",
|
|
||||||
"model": "qwen-plus"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gateway": {
|
|
||||||
"show_tool_results": true
|
|
||||||
}
|
|
||||||
}"#,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
|
||||||
assert!(config.gateway.show_tool_results);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,14 +30,12 @@ impl GatewayState {
|
|||||||
// Session TTL from config (default 4 hours)
|
// Session TTL from config (default 4 hours)
|
||||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||||
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
||||||
let show_tool_results = config.gateway.show_tool_results;
|
|
||||||
|
|
||||||
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
show_tool_results,
|
|
||||||
provider_config,
|
provider_config,
|
||||||
skills,
|
skills,
|
||||||
)?;
|
)?;
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@ -8,16 +8,21 @@ use uuid::Uuid;
|
|||||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
|
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
|
||||||
|
use crate::providers::{create_provider, ChatCompletionRequest, Message};
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
use crate::storage::{MemoryRecord, SessionRecord, SessionStore, persistent_session_id};
|
||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||||
HttpRequestTool, MemoryManageTool, MemorySearchTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry,
|
HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry,
|
||||||
WebFetchTool,
|
WebFetchTool,
|
||||||
};
|
};
|
||||||
|
|
||||||
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot,一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
|
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot,一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
|
||||||
|
const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文:优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语,优先使用最容易命中记忆的核心检索词,不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。";
|
||||||
|
const RELATED_MEMORY_SYSTEM_PROMPT_PREFIX: &str = "找到相关的记忆。你必须优先参考这些记忆,并在后续推理中把它们当作当前会话的补充上下文;若与用户本轮明确要求冲突,以用户本轮要求为准。";
|
||||||
|
const MEMORY_KEYWORD_REASONING_EFFORT: &str = "none";
|
||||||
|
const MEMORY_KEYWORD_MAX_CHARS: usize = 32;
|
||||||
|
|
||||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||||
@ -357,7 +362,6 @@ pub struct SessionManager {
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
show_tool_results: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SessionManagerInner {
|
struct SessionManagerInner {
|
||||||
@ -372,7 +376,6 @@ fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolReg
|
|||||||
registry.register(FileReadTool::new());
|
registry.register(FileReadTool::new());
|
||||||
registry.register(FileWriteTool::new());
|
registry.register(FileWriteTool::new());
|
||||||
registry.register(FileEditTool::new());
|
registry.register(FileEditTool::new());
|
||||||
registry.register(MemorySearchTool::new(store.clone()));
|
|
||||||
registry.register(MemoryManageTool::new(store));
|
registry.register(MemoryManageTool::new(store));
|
||||||
registry.register(SkillListTool::new(skills.clone()));
|
registry.register(SkillListTool::new(skills.clone()));
|
||||||
registry.register(SkillManageTool::new(skills));
|
registry.register(SkillManageTool::new(skills));
|
||||||
@ -417,7 +420,6 @@ impl SessionManager {
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
session_ttl_hours: u64,
|
session_ttl_hours: u64,
|
||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
show_tool_results: bool,
|
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
@ -441,7 +443,6 @@ impl SessionManager {
|
|||||||
skills,
|
skills,
|
||||||
store,
|
store,
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
show_tool_results,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -625,7 +626,43 @@ impl SessionManager {
|
|||||||
let user_message_id = user_message.id.clone();
|
let user_message_id = user_message.id.clone();
|
||||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||||
|
|
||||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
// 获取完整历史
|
||||||
|
let mut history = session_guard.get_or_create_history(chat_id).clone();
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
history_len = history.len(),
|
||||||
|
"Starting synchronous related memory search"
|
||||||
|
);
|
||||||
|
if let Some(memory_prompt) = build_related_memory_prompt(
|
||||||
|
session_guard.provider_config().clone(),
|
||||||
|
self.store.clone(),
|
||||||
|
channel_name.to_string(),
|
||||||
|
sender_id.to_string(),
|
||||||
|
chat_id.to_string(),
|
||||||
|
history.clone(),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
prompt_len = memory_prompt.len(),
|
||||||
|
"Injecting related memory system prompt before agent processing"
|
||||||
|
);
|
||||||
|
let memory_message = ChatMessage::system(memory_prompt);
|
||||||
|
session_guard.append_persisted_message(chat_id, memory_message.clone())?;
|
||||||
|
history.push(memory_message);
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
"No related memory prompt generated before agent processing"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// 压缩历史(如果需要)
|
// 压缩历史(如果需要)
|
||||||
let history = session_guard.compressor
|
let history = session_guard.compressor
|
||||||
@ -647,10 +684,7 @@ impl SessionManager {
|
|||||||
result
|
result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|message| {
|
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
|
||||||
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
|
|
||||||
&& should_display_message_to_user(self.show_tool_results, message)
|
|
||||||
})
|
|
||||||
.flat_map(|message| {
|
.flat_map(|message| {
|
||||||
OutboundMessage::from_chat_message(
|
OutboundMessage::from_chat_message(
|
||||||
channel_name,
|
channel_name,
|
||||||
@ -684,16 +718,231 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
async fn build_related_memory_prompt(
|
||||||
if message.role != "tool" {
|
provider_config: LLMProviderConfig,
|
||||||
return true;
|
store: Arc<SessionStore>,
|
||||||
|
channel_name: String,
|
||||||
|
sender_id: String,
|
||||||
|
chat_id: String,
|
||||||
|
history: Vec<ChatMessage>,
|
||||||
|
) -> Result<Option<String>, AgentError> {
|
||||||
|
let keywords = generate_memory_search_keywords(provider_config, &history).await?;
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
keyword_count = keywords.len(),
|
||||||
|
keywords = ?keywords,
|
||||||
|
"Generated memory search keywords"
|
||||||
|
);
|
||||||
|
if keywords.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
show_tool_results
|
let memories = search_related_memories(
|
||||||
|| matches!(
|
store,
|
||||||
message.tool_state.as_ref().unwrap_or(&crate::bus::message::ToolMessageState::Completed),
|
&channel_name,
|
||||||
crate::bus::message::ToolMessageState::PendingUserAction
|
&sender_id,
|
||||||
|
&chat_id,
|
||||||
|
&keywords,
|
||||||
)
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if memories.is_empty() {
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
keyword_count = keywords.len(),
|
||||||
|
"Related memory search returned no matches"
|
||||||
|
);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
keyword_count = keywords.len(),
|
||||||
|
memory_count = memories.len(),
|
||||||
|
"Related memory search produced matches"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Some(format_related_memory_system_prompt(&keywords, &memories)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_memory_search_keywords(
|
||||||
|
mut provider_config: LLMProviderConfig,
|
||||||
|
history: &[ChatMessage],
|
||||||
|
) -> Result<Vec<String>, AgentError> {
|
||||||
|
if provider_config.provider_type == "openai" {
|
||||||
|
provider_config.model_extra.insert(
|
||||||
|
"reasoning_effort".to_string(),
|
||||||
|
serde_json::Value::String(MEMORY_KEYWORD_REASONING_EFFORT.to_string()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider = create_provider(provider_config)
|
||||||
|
.map_err(|err| AgentError::ProviderCreation(err.to_string()))?;
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message::system(MEMORY_KEYWORD_SYSTEM_PROMPT),
|
||||||
|
Message::user(build_memory_keyword_context(history)),
|
||||||
|
],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(128),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider
|
||||||
|
.chat(request)
|
||||||
|
.await
|
||||||
|
.map_err(|err| AgentError::LlmError(err.to_string()))?;
|
||||||
|
|
||||||
|
Ok(parse_memory_keywords(&response.content))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_memory_keyword_context(history: &[ChatMessage]) -> String {
|
||||||
|
let recent_messages: Vec<&ChatMessage> = history
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.filter(|message| message.role != "system")
|
||||||
|
.take(8)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut lines = vec!["请基于以下最近会话生成长期记忆搜索关键词:".to_string()];
|
||||||
|
for message in recent_messages {
|
||||||
|
let content = message.content.trim();
|
||||||
|
if content.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let compact = if content.chars().count() > 240 {
|
||||||
|
let prefix: String = content.chars().take(240).collect();
|
||||||
|
format!("{}...", prefix)
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
};
|
||||||
|
lines.push(format!("{}: {}", message.role, compact));
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_memory_keywords(raw: &str) -> Vec<String> {
|
||||||
|
if let Ok(keywords) = serde_json::from_str::<Vec<String>>(raw) {
|
||||||
|
return normalize_keywords(keywords);
|
||||||
|
}
|
||||||
|
|
||||||
|
normalize_keywords(
|
||||||
|
raw.split(|ch| matches!(ch, '\n' | ',' | ',' | ';' | ';'))
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|part| !part.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_keywords(keywords: Vec<String>) -> Vec<String> {
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
let mut normalized = Vec::new();
|
||||||
|
|
||||||
|
for keyword in keywords {
|
||||||
|
let candidate = keyword
|
||||||
|
.trim()
|
||||||
|
.trim_matches('"')
|
||||||
|
.trim_matches('[')
|
||||||
|
.trim_matches(']')
|
||||||
|
.trim()
|
||||||
|
.to_string();
|
||||||
|
let candidate = compact_memory_keyword(&candidate);
|
||||||
|
if candidate.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let key = candidate.to_lowercase();
|
||||||
|
if seen.insert(key) {
|
||||||
|
normalized.push(candidate);
|
||||||
|
}
|
||||||
|
|
||||||
|
if normalized.len() >= 6 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compact_memory_keyword(candidate: &str) -> String {
|
||||||
|
let compact = candidate
|
||||||
|
.split_whitespace()
|
||||||
|
.next()
|
||||||
|
.unwrap_or(candidate)
|
||||||
|
.trim_matches(|ch: char| matches!(ch, '"' | '\'' | '[' | ']' | '。' | ',' | ',' | ';' | ';' | ':' | ':'))
|
||||||
|
.trim();
|
||||||
|
|
||||||
|
if compact.is_empty() {
|
||||||
|
return String::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
compact.chars().take(MEMORY_KEYWORD_MAX_CHARS).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn search_related_memories(
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
channel_name: &str,
|
||||||
|
sender_id: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
keywords: &[String],
|
||||||
|
) -> Result<Vec<MemoryRecord>, AgentError> {
|
||||||
|
tracing::debug!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
keyword_count = keywords.len(),
|
||||||
|
keywords = ?keywords,
|
||||||
|
"Searching related memories with a single batched FTS query"
|
||||||
|
);
|
||||||
|
|
||||||
|
let scope_key = format!("{}:{}", channel_name, sender_id);
|
||||||
|
let merged = store
|
||||||
|
.search_memories_any("user", &scope_key, keywords, None, 10)
|
||||||
|
.map_err(|err| AgentError::Other(format!("batched memory search error: {}", err)))?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
sender_id = %sender_id,
|
||||||
|
keyword_count = keywords.len(),
|
||||||
|
deduped_memory_count = merged.len(),
|
||||||
|
"Finished related memory search aggregation"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(merged)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_related_memory_system_prompt(keywords: &[String], memories: &[MemoryRecord]) -> String {
|
||||||
|
let mut lines = vec![
|
||||||
|
RELATED_MEMORY_SYSTEM_PROMPT_PREFIX.to_string(),
|
||||||
|
format!("检索关键词: {}", keywords.join(" / ")),
|
||||||
|
"相关记忆: ".to_string(),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (index, memory) in memories.iter().take(8).enumerate() {
|
||||||
|
lines.push(format!(
|
||||||
|
"{}. [{} / {}] {}",
|
||||||
|
index + 1,
|
||||||
|
memory.namespace,
|
||||||
|
memory.memory_key,
|
||||||
|
memory.content.replace('\n', " ").trim()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.join("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -718,21 +967,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
|
||||||
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
|
||||||
let pending = ChatMessage::tool_with_state(
|
|
||||||
"call-2",
|
|
||||||
"bash",
|
|
||||||
"waiting",
|
|
||||||
crate::bus::message::ToolMessageState::PendingUserAction,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(!should_display_message_to_user(false, &completed));
|
|
||||||
assert!(should_display_message_to_user(false, &pending));
|
|
||||||
assert!(should_display_message_to_user(true, &completed));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_in_chat_command_aliases() {
|
fn test_parse_in_chat_command_aliases() {
|
||||||
assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation));
|
assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation));
|
||||||
@ -927,4 +1161,42 @@ mod tests {
|
|||||||
assert_eq!(history[0].role, "system");
|
assert_eq!(history[0].role, "system");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_memory_keywords_handles_json_and_dedup() {
|
||||||
|
let keywords = parse_memory_keywords("[\"Rust\", \"偏好\", \"rust\", \"自动化\"]");
|
||||||
|
assert_eq!(keywords, vec!["Rust", "偏好", "自动化"]);
|
||||||
|
|
||||||
|
let fallback = parse_memory_keywords("Rust, 偏好\n自动化");
|
||||||
|
assert_eq!(fallback, vec!["Rust", "偏好", "自动化"]);
|
||||||
|
|
||||||
|
let compacted = parse_memory_keywords("[\"用户 身份 信息 长描述\", \"email_folder_preference details\"]");
|
||||||
|
assert_eq!(compacted, vec!["用户", "email_folder_preference"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_format_related_memory_system_prompt_includes_prefix_keywords_and_memory_lines() {
|
||||||
|
let prompt = format_related_memory_system_prompt(
|
||||||
|
&["Rust 偏好".to_string(), "审批项目".to_string()],
|
||||||
|
&[MemoryRecord {
|
||||||
|
id: "memory-1".to_string(),
|
||||||
|
scope_kind: "user".to_string(),
|
||||||
|
scope_key: "feishu:user-1".to_string(),
|
||||||
|
namespace: "profile".to_string(),
|
||||||
|
memory_key: "language".to_string(),
|
||||||
|
content: "用户偏好 Rust 和自动化工具".to_string(),
|
||||||
|
source_type: "message".to_string(),
|
||||||
|
source_session_id: Some("feishu:chat-1".to_string()),
|
||||||
|
source_message_id: Some("msg-1".to_string()),
|
||||||
|
source_message_seq: Some(1),
|
||||||
|
source_channel_name: Some("feishu".to_string()),
|
||||||
|
source_chat_id: Some("chat-1".to_string()),
|
||||||
|
created_at: 1,
|
||||||
|
updated_at: 1,
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(prompt.contains("找到相关的记忆"));
|
||||||
|
assert!(prompt.contains("Rust 偏好 / 审批项目"));
|
||||||
|
assert!(prompt.contains("[profile / language] 用户偏好 Rust 和自动化工具"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,7 +6,7 @@ use axum::response::Response;
|
|||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use crate::agent::EmittedMessageHandler;
|
use crate::agent::EmittedMessageHandler;
|
||||||
use crate::bus::message::{format_tool_call_content, ToolMessageState};
|
use crate::bus::message::format_tool_call_content;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||||
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
||||||
@ -181,39 +181,11 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
|
"tool" => Vec::new(),
|
||||||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
|
||||||
id: message.id.clone(),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
}],
|
|
||||||
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
|
|
||||||
id: message.id.clone(),
|
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
|
||||||
tool_name: message.tool_name.clone().unwrap_or_default(),
|
|
||||||
content: message.content.clone(),
|
|
||||||
role: message.role.clone(),
|
|
||||||
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
|
|
||||||
}],
|
|
||||||
},
|
|
||||||
_ => Vec::new(),
|
_ => Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
|
||||||
if message.role != "tool" {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
show_tool_results
|
|
||||||
|| matches!(
|
|
||||||
message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed),
|
|
||||||
ToolMessageState::PendingUserAction
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_inbound(
|
async fn handle_inbound(
|
||||||
state: &Arc<GatewayState>,
|
state: &Arc<GatewayState>,
|
||||||
session: &Arc<Mutex<Session>>,
|
session: &Arc<Mutex<Session>>,
|
||||||
@ -272,10 +244,7 @@ async fn handle_inbound(
|
|||||||
for outbound in result
|
for outbound in result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|message| {
|
.filter(|message| !message.is_assistant_tool_call_message())
|
||||||
!message.is_assistant_tool_call_message()
|
|
||||||
&& should_display_message_to_user(state.config.gateway.show_tool_results, message)
|
|
||||||
})
|
|
||||||
.flat_map(ws_outbound_from_chat_message)
|
.flat_map(ws_outbound_from_chat_message)
|
||||||
{
|
{
|
||||||
let _ = session_guard.send(outbound).await;
|
let _ = session_guard.send(outbound).await;
|
||||||
@ -420,9 +389,8 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{should_display_message_to_user, ws_outbound_from_chat_message};
|
use super::ws_outbound_from_chat_message;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::bus::message::ToolMessageState;
|
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
@ -453,42 +421,11 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ws_outbound_from_chat_message_includes_tool_results() {
|
fn test_ws_outbound_from_chat_message_omits_tool_results() {
|
||||||
let message = ChatMessage::tool("call-1", "calculator", "2");
|
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
|
|
||||||
let outbound = ws_outbound_from_chat_message(&message);
|
let outbound = ws_outbound_from_chat_message(&message);
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert!(outbound.is_empty());
|
||||||
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
|
|
||||||
let message = ChatMessage::tool_with_state(
|
|
||||||
"call-1",
|
|
||||||
"bash",
|
|
||||||
"等待你完成授权后再继续。",
|
|
||||||
ToolMessageState::PendingUserAction,
|
|
||||||
);
|
|
||||||
|
|
||||||
let outbound = ws_outbound_from_chat_message(&message);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
|
||||||
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
|
||||||
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
|
||||||
let pending = ChatMessage::tool_with_state(
|
|
||||||
"call-2",
|
|
||||||
"bash",
|
|
||||||
"waiting",
|
|
||||||
ToolMessageState::PendingUserAction,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(!should_display_message_to_user(false, &completed));
|
|
||||||
assert!(should_display_message_to_user(false, &pending));
|
|
||||||
assert!(should_display_message_to_user(true, &completed));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,12 +5,6 @@
|
|||||||
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
pub enum ToolExecutionState {
|
|
||||||
Completed,
|
|
||||||
PendingUserAction,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Events emitted during agent and tool execution.
|
/// Events emitted during agent and tool execution.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ObserverEvent {
|
pub enum ObserverEvent {
|
||||||
@ -66,8 +60,6 @@ pub struct ToolExecutionOutcome {
|
|||||||
pub error_reason: Option<String>,
|
pub error_reason: Option<String>,
|
||||||
/// How long the tool took to execute.
|
/// How long the tool took to execute.
|
||||||
pub duration: Duration,
|
pub duration: Duration,
|
||||||
/// Whether the tool completed or is waiting for external user action.
|
|
||||||
pub state: ToolExecutionState,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolExecutionOutcome {
|
impl ToolExecutionOutcome {
|
||||||
@ -78,7 +70,6 @@ impl ToolExecutionOutcome {
|
|||||||
success: true,
|
success: true,
|
||||||
error_reason: None,
|
error_reason: None,
|
||||||
duration: Duration::ZERO,
|
duration: Duration::ZERO,
|
||||||
state: ToolExecutionState::Completed,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,18 +80,6 @@ impl ToolExecutionOutcome {
|
|||||||
success: true,
|
success: true,
|
||||||
error_reason: None,
|
error_reason: None,
|
||||||
duration,
|
duration,
|
||||||
state: ToolExecutionState::Completed,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a pending outcome with zero duration.
|
|
||||||
pub fn pending(output: String) -> Self {
|
|
||||||
Self {
|
|
||||||
output,
|
|
||||||
success: true,
|
|
||||||
error_reason: None,
|
|
||||||
duration: Duration::ZERO,
|
|
||||||
state: ToolExecutionState::PendingUserAction,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,7 +90,6 @@ impl ToolExecutionOutcome {
|
|||||||
success: false,
|
success: false,
|
||||||
error_reason,
|
error_reason,
|
||||||
duration: Duration::ZERO,
|
duration: Duration::ZERO,
|
||||||
state: ToolExecutionState::Completed,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +100,6 @@ impl ToolExecutionOutcome {
|
|||||||
success: false,
|
success: false,
|
||||||
error_reason,
|
error_reason,
|
||||||
duration,
|
duration,
|
||||||
state: ToolExecutionState::Completed,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -224,7 +201,6 @@ mod tests {
|
|||||||
assert_eq!(outcome.output, "output content");
|
assert_eq!(outcome.output, "output content");
|
||||||
assert!(outcome.error_reason.is_none());
|
assert!(outcome.error_reason.is_none());
|
||||||
assert_eq!(outcome.duration, Duration::ZERO);
|
assert_eq!(outcome.duration, Duration::ZERO);
|
||||||
assert_eq!(outcome.state, ToolExecutionState::Completed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -235,7 +211,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
assert!(outcome.success);
|
assert!(outcome.success);
|
||||||
assert_eq!(outcome.duration, Duration::from_millis(100));
|
assert_eq!(outcome.duration, Duration::from_millis(100));
|
||||||
assert_eq!(outcome.state, ToolExecutionState::Completed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -248,14 +223,6 @@ mod tests {
|
|||||||
assert_eq!(outcome.output, "error output");
|
assert_eq!(outcome.output, "error output");
|
||||||
assert_eq!(outcome.error_reason, Some("error reason".to_string()));
|
assert_eq!(outcome.error_reason, Some("error reason".to_string()));
|
||||||
assert_eq!(outcome.duration, Duration::ZERO);
|
assert_eq!(outcome.duration, Duration::ZERO);
|
||||||
assert_eq!(outcome.state, ToolExecutionState::Completed);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_tool_execution_outcome_pending() {
|
|
||||||
let outcome = ToolExecutionOutcome::pending("waiting".to_string());
|
|
||||||
assert!(outcome.success);
|
|
||||||
assert_eq!(outcome.state, ToolExecutionState::PendingUserAction);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -88,15 +88,6 @@ pub enum WsOutbound {
|
|||||||
content: String,
|
content: String,
|
||||||
role: String,
|
role: String,
|
||||||
},
|
},
|
||||||
#[serde(rename = "tool_pending")]
|
|
||||||
ToolPending {
|
|
||||||
id: String,
|
|
||||||
tool_call_id: String,
|
|
||||||
tool_name: String,
|
|
||||||
content: String,
|
|
||||||
role: String,
|
|
||||||
resume_hint: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "error")]
|
#[serde(rename = "error")]
|
||||||
Error { code: String, message: String },
|
Error { code: String, message: String },
|
||||||
#[serde(rename = "session_established")]
|
#[serde(rename = "session_established")]
|
||||||
|
|||||||
@ -1030,7 +1030,6 @@ fn load_messages_after(
|
|||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
tool_call_id: row.get(5)?,
|
tool_call_id: row.get(5)?,
|
||||||
tool_name: row.get(6)?,
|
tool_name: row.get(6)?,
|
||||||
tool_state: None,
|
|
||||||
tool_calls,
|
tool_calls,
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|||||||
@ -1,21 +1,17 @@
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::process::Stdio;
|
use std::process::Stdio;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader};
|
use tokio::io::AsyncReadExt;
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::time::timeout;
|
||||||
use tokio::time::{Instant, sleep_until};
|
|
||||||
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
|
||||||
const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
|
||||||
|
|
||||||
pub struct BashTool {
|
pub struct BashTool {
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
@ -65,67 +61,18 @@ impl BashTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn truncate_output(&self, output: &str) -> String {
|
fn truncate_output(&self, output: &str) -> String {
|
||||||
let char_count = output.chars().count();
|
if output.len() <= MAX_OUTPUT_CHARS {
|
||||||
if char_count <= MAX_OUTPUT_CHARS {
|
|
||||||
return output.to_string();
|
return output.to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
let half = MAX_OUTPUT_CHARS / 2;
|
let half = MAX_OUTPUT_CHARS / 2;
|
||||||
let head: String = output.chars().take(half).collect();
|
|
||||||
let tail: String = output
|
|
||||||
.chars()
|
|
||||||
.skip(char_count.saturating_sub(half))
|
|
||||||
.collect();
|
|
||||||
format!(
|
format!(
|
||||||
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
"{}...\n\n(... {} chars truncated ...)\n\n{}",
|
||||||
head,
|
&output[..half],
|
||||||
char_count - MAX_OUTPUT_CHARS,
|
output.len() - MAX_OUTPUT_CHARS,
|
||||||
tail
|
&output[output.len() - half..]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pending_output(&self, output: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"{}\n{}\n\n{}",
|
|
||||||
PENDING_USER_ACTION_MARKER,
|
|
||||||
USER_ACTION_HINT,
|
|
||||||
self.truncate_output(output.trim())
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn should_return_pending(&self, interactive: bool, output: &str) -> bool {
|
|
||||||
let normalized = output.to_lowercase();
|
|
||||||
let has_auth_phrase = [
|
|
||||||
"等待用户授权",
|
|
||||||
"等待授权",
|
|
||||||
"等待你授权",
|
|
||||||
"在浏览器中打开以下链接进行认证",
|
|
||||||
"open the following link",
|
|
||||||
"waiting for authorization",
|
|
||||||
"waiting for user authorization",
|
|
||||||
"waiting for approval",
|
|
||||||
"device/verify",
|
|
||||||
"user_code=",
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.any(|pattern| normalized.contains(pattern));
|
|
||||||
|
|
||||||
has_auth_phrase || (interactive && !output.trim().is_empty())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn drain_available_chunks(
|
|
||||||
rx: &mut mpsc::UnboundedReceiver<(bool, String)>,
|
|
||||||
stdout_buf: &Arc<Mutex<String>>,
|
|
||||||
stderr_buf: &Arc<Mutex<String>>,
|
|
||||||
) {
|
|
||||||
while let Ok((is_stderr, chunk)) = rx.try_recv() {
|
|
||||||
if is_stderr {
|
|
||||||
stderr_buf.lock().await.push_str(&chunk);
|
|
||||||
} else {
|
|
||||||
stdout_buf.lock().await.push_str(&chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for BashTool {
|
impl Default for BashTool {
|
||||||
@ -157,10 +104,6 @@ impl Tool for BashTool {
|
|||||||
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
|
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
|
||||||
"minimum": 1,
|
"minimum": 1,
|
||||||
"maximum": MAX_TIMEOUT_SECS
|
"maximum": MAX_TIMEOUT_SECS
|
||||||
},
|
|
||||||
"interactive": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["command"]
|
"required": ["command"]
|
||||||
@ -197,10 +140,6 @@ impl Tool for BashTool {
|
|||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.unwrap_or(self.timeout_secs)
|
.unwrap_or(self.timeout_secs)
|
||||||
.min(MAX_TIMEOUT_SECS);
|
.min(MAX_TIMEOUT_SECS);
|
||||||
let interactive = args
|
|
||||||
.get("interactive")
|
|
||||||
.and_then(|v| v.as_bool())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let cwd = self
|
let cwd = self
|
||||||
.working_dir
|
.working_dir
|
||||||
@ -208,29 +147,37 @@ impl Tool for BashTool {
|
|||||||
.map(|d| Path::new(d))
|
.map(|d| Path::new(d))
|
||||||
.unwrap_or_else(|| Path::new("."));
|
.unwrap_or_else(|| Path::new("."));
|
||||||
|
|
||||||
match self.run_command(command, cwd, timeout_secs, interactive).await {
|
let result = timeout(
|
||||||
Ok(output) => Ok(ToolResult {
|
Duration::from_secs(timeout_secs),
|
||||||
|
self.run_command(command, cwd),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(output)) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output,
|
output,
|
||||||
error: None,
|
error: None,
|
||||||
}),
|
}),
|
||||||
Err(e) => Ok(ToolResult {
|
Ok(Err(e)) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(e),
|
error: Some(e),
|
||||||
}),
|
}),
|
||||||
|
Err(_) => Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Command timed out after {} seconds",
|
||||||
|
timeout_secs
|
||||||
|
)),
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BashTool {
|
impl BashTool {
|
||||||
async fn run_command(
|
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> {
|
||||||
&self,
|
|
||||||
command: &str,
|
|
||||||
cwd: &Path,
|
|
||||||
timeout_secs: u64,
|
|
||||||
interactive: bool,
|
|
||||||
) -> Result<String, String> {
|
|
||||||
let mut cmd = Command::new("bash");
|
let mut cmd = Command::new("bash");
|
||||||
cmd.args(["-c", command])
|
cmd.args(["-c", command])
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
@ -239,112 +186,48 @@ impl BashTool {
|
|||||||
|
|
||||||
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
|
||||||
|
|
||||||
let stdout = child.stdout.take();
|
let mut stdout = Vec::new();
|
||||||
let stderr = child.stderr.take();
|
let mut stderr = Vec::new();
|
||||||
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>();
|
|
||||||
|
|
||||||
if let Some(stdout) = stdout {
|
if let Some(ref mut out) = child.stdout {
|
||||||
tokio::spawn(read_stream(stdout, false, tx.clone()));
|
out.read_to_end(&mut stdout)
|
||||||
}
|
.await
|
||||||
if let Some(stderr) = stderr {
|
.map_err(|e| format!("Failed to read stdout: {}", e))?;
|
||||||
tokio::spawn(read_stream(stderr, true, tx.clone()));
|
|
||||||
}
|
|
||||||
drop(tx);
|
|
||||||
|
|
||||||
let stdout_buf = Arc::new(Mutex::new(String::new()));
|
|
||||||
let stderr_buf = Arc::new(Mutex::new(String::new()));
|
|
||||||
let deadline = Instant::now() + Duration::from_secs(timeout_secs);
|
|
||||||
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
status = child.wait() => {
|
|
||||||
let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
|
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
|
||||||
while let Some((is_stderr, chunk)) = rx.recv().await {
|
|
||||||
if is_stderr {
|
|
||||||
stderr_buf.lock().await.push_str(&chunk);
|
|
||||||
} else {
|
|
||||||
stdout_buf.lock().await.push_str(&chunk);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
|
|
||||||
return Ok(self.truncate_output(&output));
|
|
||||||
}
|
|
||||||
Some((is_stderr, chunk)) = rx.recv() => {
|
|
||||||
if is_stderr {
|
|
||||||
stderr_buf.lock().await.push_str(&chunk);
|
|
||||||
} else {
|
|
||||||
stdout_buf.lock().await.push_str(&chunk);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
if let Some(ref mut err) = child.stderr {
|
||||||
if self.should_return_pending(interactive, &combined) {
|
err.read_to_end(&mut stderr)
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
.await
|
||||||
let _ = child.start_kill();
|
.map_err(|e| format!("Failed to read stderr: {}", e))?;
|
||||||
let _ = child.wait().await;
|
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
|
||||||
return Ok(self.pending_output(&combined));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = sleep_until(deadline) => {
|
|
||||||
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
|
|
||||||
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
|
|
||||||
let _ = child.start_kill();
|
|
||||||
let _ = child.wait().await;
|
|
||||||
if self.should_return_pending(interactive, &combined) {
|
|
||||||
return Ok(self.pending_output(&combined));
|
|
||||||
}
|
|
||||||
return Err(format!("Command timed out after {} seconds", timeout_secs));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_stream<R>(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>)
|
let status = child
|
||||||
where
|
.wait()
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
.await
|
||||||
{
|
.map_err(|e| format!("Failed to wait: {}", e))?;
|
||||||
let mut reader = BufReader::new(stream);
|
|
||||||
let mut line = String::new();
|
|
||||||
|
|
||||||
loop {
|
|
||||||
line.clear();
|
|
||||||
match reader.read_line(&mut line).await {
|
|
||||||
Ok(0) => break,
|
|
||||||
Ok(_) => {
|
|
||||||
let _ = tx.send((is_stderr, line.clone()));
|
|
||||||
}
|
|
||||||
Err(_) => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut remainder = String::new();
|
|
||||||
if reader.read_to_string(&mut remainder).await.is_ok() && !remainder.is_empty() {
|
|
||||||
let _ = tx.send((is_stderr, remainder));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_command_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
|
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
|
|
||||||
if !stdout.is_empty() {
|
if !stdout.is_empty() {
|
||||||
output.push_str(stdout);
|
let stdout_str = String::from_utf8_lossy(&stdout);
|
||||||
|
output.push_str(&stdout_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stderr.trim().is_empty() {
|
if !stderr.is_empty() {
|
||||||
|
let stderr_str = String::from_utf8_lossy(&stderr);
|
||||||
|
if !stderr_str.trim().is_empty() {
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
output.push_str("\n");
|
output.push_str("\n");
|
||||||
}
|
}
|
||||||
output.push_str("STDERR:\n");
|
output.push_str("STDERR:\n");
|
||||||
output.push_str(stderr);
|
output.push_str(&stderr_str);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(code) = exit_code {
|
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
|
||||||
output.push_str(&format!("\nExit code: {}", code));
|
|
||||||
}
|
|
||||||
|
|
||||||
output
|
Ok(self.truncate_output(&output))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -429,31 +312,4 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("timed out"));
|
assert!(result.error.unwrap().contains("timed out"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_pending_user_action_detection() {
|
|
||||||
let tool = BashTool::new();
|
|
||||||
let result = tool
|
|
||||||
.execute(json!({
|
|
||||||
"command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10",
|
|
||||||
"timeout": 1
|
|
||||||
}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(result.success);
|
|
||||||
assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
|
|
||||||
assert!(result.output.contains("等待用户授权"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_truncate_output_handles_utf8_char_boundaries() {
|
|
||||||
let tool = BashTool::new();
|
|
||||||
let input = "全".repeat(MAX_OUTPUT_CHARS + 100);
|
|
||||||
|
|
||||||
let output = tool.truncate_output(&input);
|
|
||||||
|
|
||||||
assert!(output.contains("chars truncated"));
|
|
||||||
assert!(output.is_char_boundary(output.len()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
"Create, update, or delete long-term user memories stored in SQLite. Supports actions: put, update, delete. Use memory_search for all retrieval, including search, get, and list. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
"Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information. Search matches namespace, memory_key, and content. When searching, prefer bilingual queries that include both Chinese and English aliases, and include likely snake_case key terms when known. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
@ -32,13 +32,17 @@ impl Tool for MemoryManageTool {
|
|||||||
"properties": {
|
"properties": {
|
||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["put", "update", "delete"],
|
"enum": ["list", "search", "get", "put", "update", "delete"],
|
||||||
"description": "Management action to perform. Use 'put' to create or overwrite, 'update' to modify an existing record, and 'delete' to remove one. Use memory_search for retrieval."
|
"description": "Management action to perform. Prefer 'search' for keyword lookup across stored memories, 'get' for an exact namespace/key lookup, and 'list' for browsing recent memories."
|
||||||
},
|
},
|
||||||
"namespace": {
|
"namespace": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Optional memory namespace filter, such as profile, preferences, or tasks"
|
"description": "Optional memory namespace filter, such as profile, preferences, or tasks"
|
||||||
},
|
},
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Keyword query for full-text memory search across namespace, memory_key, and content. Prefer concise bilingual keywords when possible, for example Chinese plus English aliases and likely snake_case key terms."
|
||||||
|
},
|
||||||
"key": {
|
"key": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Exact memory key within the namespace"
|
"description": "Exact memory key within the namespace"
|
||||||
@ -46,6 +50,12 @@ impl Tool for MemoryManageTool {
|
|||||||
"content": {
|
"content": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Memory content for put/update"
|
"description": "Memory content for put/update"
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of memories to return",
|
||||||
|
"minimum": 1,
|
||||||
|
"default": 20
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["action"]
|
"required": ["action"]
|
||||||
@ -72,9 +82,56 @@ impl Tool for MemoryManageTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
||||||
|
let query = args.get("query").and_then(|value| value.as_str());
|
||||||
let key = args.get("key").and_then(|value| value.as_str());
|
let key = args.get("key").and_then(|value| value.as_str());
|
||||||
|
|
||||||
let payload = match action {
|
let payload = match action {
|
||||||
|
"list" => {
|
||||||
|
let limit = args
|
||||||
|
.get("limit")
|
||||||
|
.and_then(|value| value.as_u64())
|
||||||
|
.unwrap_or(20) as usize;
|
||||||
|
let memories = self
|
||||||
|
.store
|
||||||
|
.list_memories("user", &scope_key, namespace, limit)?;
|
||||||
|
json!({
|
||||||
|
"count": memories.len(),
|
||||||
|
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"search" => {
|
||||||
|
let query = match query {
|
||||||
|
Some(query) if !query.trim().is_empty() => query,
|
||||||
|
_ => return Ok(error_result("Missing required parameter: query")),
|
||||||
|
};
|
||||||
|
let limit = args
|
||||||
|
.get("limit")
|
||||||
|
.and_then(|value| value.as_u64())
|
||||||
|
.unwrap_or(20) as usize;
|
||||||
|
let memories = self
|
||||||
|
.store
|
||||||
|
.search_memories("user", &scope_key, query, namespace, limit)?;
|
||||||
|
json!({
|
||||||
|
"query": query,
|
||||||
|
"count": memories.len(),
|
||||||
|
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"get" => {
|
||||||
|
let namespace = match namespace {
|
||||||
|
Some(namespace) => namespace,
|
||||||
|
None => return Ok(error_result("Missing required parameter: namespace")),
|
||||||
|
};
|
||||||
|
let key = match key {
|
||||||
|
Some(key) => key,
|
||||||
|
None => return Ok(error_result("Missing required parameter: key")),
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
||||||
|
Some(memory) => memory_to_json(memory),
|
||||||
|
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
|
||||||
|
}
|
||||||
|
}
|
||||||
"put" => {
|
"put" => {
|
||||||
let input = match build_memory_upsert(context, &scope_key, &args, true) {
|
let input = match build_memory_upsert(context, &scope_key, &args, true) {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
@ -216,7 +273,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_memory_manage_put_returns_saved_memory() {
|
async fn test_memory_manage_put_and_get() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = MemoryManageTool::new(store);
|
let tool = MemoryManageTool::new(store);
|
||||||
let context = ToolContext {
|
let context = ToolContext {
|
||||||
@ -241,8 +298,64 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(put.success);
|
assert!(put.success);
|
||||||
assert!(put.output.contains("Rust"));
|
|
||||||
assert!(put.output.contains("msg-1"));
|
let get = tool
|
||||||
|
.execute_with_context(
|
||||||
|
&context,
|
||||||
|
json!({
|
||||||
|
"action": "get",
|
||||||
|
"namespace": "profile",
|
||||||
|
"key": "language"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(get.success);
|
||||||
|
assert!(get.output.contains("Rust"));
|
||||||
|
assert!(get.output.contains("msg-1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_memory_manage_search() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = MemoryManageTool::new(store);
|
||||||
|
let context = ToolContext {
|
||||||
|
channel_name: Some("feishu".to_string()),
|
||||||
|
sender_id: Some("user-1".to_string()),
|
||||||
|
chat_id: Some("chat-1".to_string()),
|
||||||
|
session_id: Some("feishu:chat-1".to_string()),
|
||||||
|
message_id: Some("msg-1".to_string()),
|
||||||
|
message_seq: Some(1),
|
||||||
|
};
|
||||||
|
|
||||||
|
let put = tool
|
||||||
|
.execute_with_context(
|
||||||
|
&context,
|
||||||
|
json!({
|
||||||
|
"action": "put",
|
||||||
|
"namespace": "profile",
|
||||||
|
"key": "editor",
|
||||||
|
"content": "Prefers rust-analyzer over clippy hints"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(put.success);
|
||||||
|
|
||||||
|
let search = tool
|
||||||
|
.execute_with_context(
|
||||||
|
&context,
|
||||||
|
json!({
|
||||||
|
"action": "search",
|
||||||
|
"query": "rust-analyzer",
|
||||||
|
"limit": 5
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(search.success);
|
||||||
|
assert!(search.output.contains("rust-analyzer"));
|
||||||
|
assert!(search.output.contains("editor"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -263,30 +376,4 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("channel_name"));
|
assert!(result.error.unwrap().contains("channel_name"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_memory_manage_rejects_read_actions() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let tool = MemoryManageTool::new(store);
|
|
||||||
let context = ToolContext {
|
|
||||||
channel_name: Some("feishu".to_string()),
|
|
||||||
sender_id: Some("user-1".to_string()),
|
|
||||||
..ToolContext::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = tool
|
|
||||||
.execute_with_context(
|
|
||||||
&context,
|
|
||||||
json!({
|
|
||||||
"action": "get",
|
|
||||||
"namespace": "profile",
|
|
||||||
"key": "language"
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(!result.success);
|
|
||||||
assert!(result.error.unwrap().contains("Unsupported action"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@ -1,289 +0,0 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
use crate::storage::{MemoryRecord, SessionStore};
|
|
||||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
|
||||||
|
|
||||||
pub struct MemorySearchTool {
|
|
||||||
store: Arc<SessionStore>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MemorySearchTool {
|
|
||||||
pub fn new(store: Arc<SessionStore>) -> Self {
|
|
||||||
Self { store }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for MemorySearchTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"memory_search"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Search and read long-term user memories stored in SQLite. Use this tool when you need prior preferences, stable facts, historical decisions, or ongoing task context. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage when you only need to retrieve memory."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"action": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["search", "get", "list"],
|
|
||||||
"description": "Retrieval action. Use 'search' for multi-keyword recall, 'get' for an exact namespace/key read, and 'list' to browse recent memories."
|
|
||||||
},
|
|
||||||
"namespace": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Optional namespace filter, such as profile, preferences, tasks, or decisions. Required for get."
|
|
||||||
},
|
|
||||||
"queries": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"description": "Keyword queries for memory search. Provide multiple concise bilingual keywords, English aliases, and likely snake_case memory_key terms when known. Search matches any of the provided entries. Required for search.",
|
|
||||||
"minItems": 1
|
|
||||||
},
|
|
||||||
"key": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Exact memory key within the namespace. Required for get."
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Maximum number of memories to return",
|
|
||||||
"minimum": 1,
|
|
||||||
"default": 10
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["action"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
Ok(error_result("memory_search requires tool context"))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute_with_context(
|
|
||||||
&self,
|
|
||||||
context: &ToolContext,
|
|
||||||
args: serde_json::Value,
|
|
||||||
) -> anyhow::Result<ToolResult> {
|
|
||||||
let action = match args.get("action").and_then(|value| value.as_str()) {
|
|
||||||
Some(action) => action,
|
|
||||||
None => return Ok(error_result("Missing required parameter: action")),
|
|
||||||
};
|
|
||||||
|
|
||||||
let scope_key = match scope_key_from_context(context) {
|
|
||||||
Ok(scope_key) => scope_key,
|
|
||||||
Err(result) => return Ok(result),
|
|
||||||
};
|
|
||||||
|
|
||||||
let namespace = args.get("namespace").and_then(|value| value.as_str());
|
|
||||||
let key = args.get("key").and_then(|value| value.as_str());
|
|
||||||
|
|
||||||
let payload = match action {
|
|
||||||
"list" => {
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|value| value.as_u64())
|
|
||||||
.unwrap_or(10) as usize;
|
|
||||||
let memories = self.store.list_memories("user", &scope_key, namespace, limit)?;
|
|
||||||
json!({
|
|
||||||
"count": memories.len(),
|
|
||||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
"search" => {
|
|
||||||
let queries = match args.get("queries").and_then(|value| value.as_array()) {
|
|
||||||
Some(queries) => queries
|
|
||||||
.iter()
|
|
||||||
.filter_map(|value| value.as_str())
|
|
||||||
.map(str::trim)
|
|
||||||
.filter(|value| !value.is_empty())
|
|
||||||
.map(ToOwned::to_owned)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => return Ok(error_result("Missing required parameter: queries")),
|
|
||||||
};
|
|
||||||
if queries.is_empty() {
|
|
||||||
return Ok(error_result("Missing required parameter: queries"));
|
|
||||||
}
|
|
||||||
let limit = args
|
|
||||||
.get("limit")
|
|
||||||
.and_then(|value| value.as_u64())
|
|
||||||
.unwrap_or(10) as usize;
|
|
||||||
let memories = self
|
|
||||||
.store
|
|
||||||
.search_memories_any("user", &scope_key, &queries, namespace, limit)?;
|
|
||||||
json!({
|
|
||||||
"queries": queries,
|
|
||||||
"count": memories.len(),
|
|
||||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
"get" => {
|
|
||||||
let namespace = match namespace {
|
|
||||||
Some(namespace) => namespace,
|
|
||||||
None => return Ok(error_result("Missing required parameter: namespace")),
|
|
||||||
};
|
|
||||||
let key = match key {
|
|
||||||
Some(key) => key,
|
|
||||||
None => return Ok(error_result("Missing required parameter: key")),
|
|
||||||
};
|
|
||||||
|
|
||||||
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
|
||||||
Some(memory) => memory_to_json(memory),
|
|
||||||
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => return Ok(error_result("Unsupported action")),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: serde_json::to_string_pretty(&payload)?,
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_only(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn scope_key_from_context(context: &ToolContext) -> Result<String, ToolResult> {
|
|
||||||
let channel_name = context
|
|
||||||
.channel_name
|
|
||||||
.as_deref()
|
|
||||||
.ok_or_else(|| error_result("memory_search requires channel_name in tool context"))?;
|
|
||||||
let sender_id = context
|
|
||||||
.sender_id
|
|
||||||
.as_deref()
|
|
||||||
.ok_or_else(|| error_result("memory_search requires sender_id in tool context"))?;
|
|
||||||
Ok(format!("{}:{}", channel_name, sender_id))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn memory_to_json(memory: MemoryRecord) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"id": memory.id,
|
|
||||||
"scope_kind": memory.scope_kind,
|
|
||||||
"scope_key": memory.scope_key,
|
|
||||||
"namespace": memory.namespace,
|
|
||||||
"key": memory.memory_key,
|
|
||||||
"content": memory.content,
|
|
||||||
"source_type": memory.source_type,
|
|
||||||
"source_session_id": memory.source_session_id,
|
|
||||||
"source_message_id": memory.source_message_id,
|
|
||||||
"source_message_seq": memory.source_message_seq,
|
|
||||||
"source_channel_name": memory.source_channel_name,
|
|
||||||
"source_chat_id": memory.source_chat_id,
|
|
||||||
"created_at": memory.created_at,
|
|
||||||
"updated_at": memory.updated_at,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn error_result(message: &str) -> ToolResult {
|
|
||||||
ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(message.to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_memory_search_search_and_get() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
store
|
|
||||||
.put_memory(&crate::storage::MemoryUpsert {
|
|
||||||
scope_kind: "user".to_string(),
|
|
||||||
scope_key: "feishu:user-1".to_string(),
|
|
||||||
namespace: "preferences".to_string(),
|
|
||||||
memory_key: "language".to_string(),
|
|
||||||
content: "User prefers Chinese responses".to_string(),
|
|
||||||
source_type: "message".to_string(),
|
|
||||||
source_session_id: Some("feishu:chat-1".to_string()),
|
|
||||||
source_message_id: Some("msg-1".to_string()),
|
|
||||||
source_message_seq: Some(1),
|
|
||||||
source_channel_name: Some("feishu".to_string()),
|
|
||||||
source_chat_id: Some("chat-1".to_string()),
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let tool = MemorySearchTool::new(store);
|
|
||||||
let context = ToolContext {
|
|
||||||
channel_name: Some("feishu".to_string()),
|
|
||||||
sender_id: Some("user-1".to_string()),
|
|
||||||
chat_id: Some("chat-1".to_string()),
|
|
||||||
session_id: Some("feishu:chat-1".to_string()),
|
|
||||||
message_id: Some("msg-2".to_string()),
|
|
||||||
message_seq: Some(2),
|
|
||||||
};
|
|
||||||
|
|
||||||
let search = tool
|
|
||||||
.execute_with_context(
|
|
||||||
&context,
|
|
||||||
json!({
|
|
||||||
"action": "search",
|
|
||||||
"queries": ["Chinese", "language"],
|
|
||||||
"limit": 5
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(search.success);
|
|
||||||
assert!(search.output.contains("language"));
|
|
||||||
|
|
||||||
let get = tool
|
|
||||||
.execute_with_context(
|
|
||||||
&context,
|
|
||||||
json!({
|
|
||||||
"action": "get",
|
|
||||||
"namespace": "preferences",
|
|
||||||
"key": "language"
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(get.success);
|
|
||||||
assert!(get.output.contains("Chinese"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_memory_search_is_read_only_and_requires_context() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let tool = MemorySearchTool::new(store);
|
|
||||||
|
|
||||||
assert!(tool.read_only());
|
|
||||||
|
|
||||||
let result = tool
|
|
||||||
.execute_with_context(&ToolContext::default(), json!({ "action": "list" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
|
||||||
assert!(result.error.unwrap().contains("channel_name"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_memory_search_search_requires_queries() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let tool = MemorySearchTool::new(store);
|
|
||||||
let context = ToolContext {
|
|
||||||
channel_name: Some("feishu".to_string()),
|
|
||||||
sender_id: Some("user-1".to_string()),
|
|
||||||
..ToolContext::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = tool
|
|
||||||
.execute_with_context(&context, json!({ "action": "search", "queries": [] }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
|
||||||
assert!(result.error.unwrap().contains("queries"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -5,7 +5,6 @@ pub mod file_read;
|
|||||||
pub mod file_write;
|
pub mod file_write;
|
||||||
pub mod http_request;
|
pub mod http_request;
|
||||||
pub mod memory_manage;
|
pub mod memory_manage;
|
||||||
pub mod memory_search;
|
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod skill_manage;
|
pub mod skill_manage;
|
||||||
@ -19,7 +18,6 @@ pub use file_read::FileReadTool;
|
|||||||
pub use file_write::FileWriteTool;
|
pub use file_write::FileWriteTool;
|
||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
pub use memory_manage::MemoryManageTool;
|
pub use memory_manage::MemoryManageTool;
|
||||||
pub use memory_search::MemorySearchTool;
|
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use skill_manage::{SkillListTool, SkillManageTool};
|
pub use skill_manage::{SkillListTool, SkillManageTool};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user