Compare commits

...

4 Commits

11 changed files with 1388 additions and 33 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ AGENTS.md
CLAUDE.md
Cargo.lock
.playwright-cli/
.venv

View File

@ -2,6 +2,23 @@ PicoBot
Skills (initial implementation)
Agent profile injection
PicoBot maintains a persistent agent profile file at ~/.picobot/agent/AGENT.md.
Behavior:
- The directory ~/.picobot/agent is created automatically when needed.
- If AGENT.md does not exist yet, PicoBot creates it with a default profile.
- When the active conversation is empty or has just been reset, AGENT.md is loaded as the first system message in the active history.
- After every configured number of user turns, PicoBot injects the latest AGENT.md content again before the next user message is appended.
Config:
- Set gateway.agent_prompt_reinject_every in ~/.picobot/config.json.
- Default value is 100.
- Set it to 0 to disable periodic reinjection.
This profile is persisted in session history, while the skills index system prompt is still injected dynamically by AgentLoop.
PicoBot now supports filesystem skills.
Skill discovery locations:

View File

@ -48,9 +48,11 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
2. `SessionManager` 定位到对应 channel 的运行时 `Session`
3. `Session::ensure_persistent_session(chat_id)` 确保数据库里有对应会话记录。
4. `Session::ensure_chat_loaded(chat_id)` 在内存中没有历史时,从 `messages` 表加载该会话全部历史。
5. 新的用户消息先写入 `messages`,再放入内存历史。
6. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`
7. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
5. 如果当前活动段历史为空,系统会从 `~/.picobot/agent/AGENT.md` 读取 Agent 基本设定,并先写入一条 `system` 消息。
6. 在新的用户消息进入前,系统会检查当前活动段的 `user_turn_count` 是否刚跨过配置项 `gateway.agent_prompt_reinject_every` 指定的下一轮阈值;如果跨过,就再次把 `AGENT.md` 写入一条新的 `system` 消息。
7. 新的用户消息先写入 `messages`,再放入内存历史。
8. Agent 执行后产生的 assistant/tool 消息按实际顺序继续写入 `messages`
9. 下次进程重启或 session 过期后,可从数据库完整恢复上下文。
## 3. 会话标识规则
@ -86,6 +88,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
| `deleted_at` | `INTEGER` | 删除时间 | 预留字段,当前读取逻辑会过滤该字段,但当前删除实现是物理删除 |
| `message_count` | `INTEGER NOT NULL DEFAULT 0` | 消息数 | 追加消息时自增,清空历史时重置 |
| `reset_cutoff_seq` | `INTEGER NOT NULL DEFAULT 0` | 逻辑重置切点 | `/reset` 后默认只恢复 `seq > reset_cutoff_seq` 的活动段 |
| `user_turn_count` | `INTEGER NOT NULL DEFAULT 0` | 当前活动段用户轮次数 | 只在追加 `role = user` 消息时递增,清空历史和 `/reset` 时归零 |
| `agent_prompt_reinjection_count` | `INTEGER NOT NULL DEFAULT 0` | AGENT.md 周期重注入次数 | 每完成一次“达到配置阈值后的下一轮前注入”就递增,清空历史和 `/reset` 时归零 |
索引:
@ -171,6 +175,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
其中第 6 步很重要:归档会话一旦收到新消息,会自动恢复为活跃态。
另外,只有 `role = user` 的消息会递增 `user_turn_count`;`system`、`assistant`、`tool` 消息不会影响周期注入阈值的判定。
### 6.3 读取历史
`load_messages(session_id)` 会按 `seq ASC` 读取当前活动段历史,并把 JSON 字段反序列化回 `ChatMessage`。活动段的定义是:
@ -236,6 +242,8 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 删除该会话在 `messages` 中的所有记录
- 将 `sessions.message_count` 重置为 0
- 将 `sessions.reset_cutoff_seq` 重置为 0
- 将 `sessions.user_turn_count` 重置为 0
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
- 更新 `updated_at`、`last_active_at`
- 保留会话本身
@ -247,11 +255,15 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 不删除 `messages` 中的任何记录
- 将当前会话的 `MAX(seq)` 写入 `sessions.reset_cutoff_seq`
- 将 `sessions.user_turn_count` 重置为 0
- 将 `sessions.agent_prompt_reinjection_count` 重置为 0
- 更新 `updated_at`、`last_active_at`
- 后续默认恢复和发给模型的历史,只包含这次重置之后新增的消息
这适合“开始新对话,但保留完整历史以便审计或未来检索”的场景。
由于 AGENT.md 注入消息也会持久化,`/reset` 前的 Agent 设定消息仍会保留在完整历史中,但不会继续出现在新的活动段。下一次活动段首次加载时,系统会重新读取当前版本的 `~/.picobot/agent/AGENT.md`,并把它作为新的首条系统消息写入活动段。
### 8.5 删除会话
`delete_session(session_id)`

View File

@ -9,6 +9,7 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::{ToolContext, ToolRegistry};
use serde::Deserialize;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
@ -20,6 +21,12 @@ use std::time::Instant;
const MAX_TOOL_RESULT_CHARS: usize = 16_000;
/// Minimum characters to keep when truncating
const TRUNCATION_SUFFIX_LEN: usize = 200;
/// System prompt added to every main-loop LLM request; it tells the model when
/// and how to use the `memory_manage` tool for long-term memory.
const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespacepreferences、profile、tasks、decisions。若需要写入优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。";
/// System prompt for the follow-up extraction request; it asks the model to
/// answer with a JSON array of `{namespace, key, content}` objects (or `[]`).
const MEMORY_EXTRACTION_SYSTEM_PROMPT: &str = "你负责从一小段对话中提取值得长期记忆的信息。只返回 JSON 数组,不要输出解释或 Markdown。数组元素格式为 {\"namespace\": string, \"key\": string, \"content\": string}。只有在内容属于长期偏好、稳定事实、用户纠正、持续任务上下文或明确决策时才输出;否则返回 []。namespace 只能使用 preferences、profile、tasks、decisions。key 用简短稳定的 snake_case 英文标识。不要保存一次性工具结果、临时列表、敏感信息或猜测。最多输出 2 条。";
/// Value written to the `reasoning_effort` entry of `model_extra` for the
/// extraction request (applied to `openai`-type providers only).
const MEMORY_EXTRACTION_REASONING_EFFORT: &str = "none";
/// `max_tokens` limit sent with the extraction request.
const MEMORY_EXTRACTION_MAX_TOKENS: u32 = 192;
/// Number of most recent user/assistant messages included in the extraction context.
const MEMORY_EXTRACTION_CONTEXT_MESSAGES: usize = 4;
/// Upper bound on the number of normalized candidates kept per extraction.
const MEMORY_EXTRACTION_MAX_CANDIDATES: usize = 2;
/// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
@ -221,6 +228,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
/// AgentLoop - Stateless agent that processes messages with tool calling support.
/// History is managed externally by SessionManager.
pub struct AgentLoop {
provider_config: LLMProviderConfig,
provider: Box<dyn LLMProvider>,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
@ -238,6 +246,19 @@ pub struct AgentProcessResult {
pub emitted_messages: Vec<ChatMessage>,
}
/// One long-term memory entry proposed by the extraction model, deserialized
/// from the JSON it returns.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
struct MemoryCandidate {
// Memory namespace; mapped to a canonical name during normalization.
namespace: String,
// Identifier of the entry inside its namespace; rewritten to snake_case during normalization.
key: String,
// Text to remember; trimmed during normalization.
content: String,
}
/// Alternative response shape accepted by the parser: a JSON object that wraps
/// the candidate list under a `memories` key instead of a bare array.
#[derive(Debug, Deserialize)]
struct MemoryCandidateEnvelope {
// A missing `memories` key deserializes to an empty list.
#[serde(default)]
memories: Vec<MemoryCandidate>,
}
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage);
@ -246,10 +267,11 @@ pub trait EmittedMessageHandler: Send + Sync + 'static {
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config)
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider_config,
provider,
tools: Arc::new(ToolRegistry::new()),
skills: Arc::new(SkillRuntime::default()),
@ -264,10 +286,11 @@ impl AgentLoop {
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config)
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider_config,
provider,
tools,
skills: Arc::new(SkillRuntime::default()),
@ -286,10 +309,11 @@ impl AgentLoop {
skills: Arc<SkillRuntime>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config)
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
Ok(Self {
provider_config,
provider,
tools,
skills,
@ -342,6 +366,7 @@ impl AgentLoop {
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
let mut memory_write_occurred = false;
for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)]
@ -352,6 +377,7 @@ impl AgentLoop {
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.push(Message::system(MEMORY_AUTOSAVE_SYSTEM_PROMPT));
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
// Build request
@ -386,6 +412,9 @@ impl AgentLoop {
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let assistant_message = ChatMessage::assistant(response.content);
if !memory_write_occurred {
self.maybe_extract_and_store_memories(&messages, &assistant_message).await?;
}
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
@ -407,6 +436,14 @@ impl AgentLoop {
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
if response
.tool_calls
.iter()
.zip(tool_results.iter())
.any(|(tool_call, result)| did_successfully_write_memory(tool_call, result))
{
memory_write_occurred = true;
}
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Log function call with name and arguments
@ -470,6 +507,7 @@ impl AgentLoop {
if let Some(skill_prompt) = self.skills.system_index_prompt() {
messages_for_llm.push(Message::system(skill_prompt));
}
messages_for_llm.push(Message::system(MEMORY_AUTOSAVE_SYSTEM_PROMPT));
messages_for_llm.extend(messages.iter().map(chat_message_to_llm_message));
let request = ChatCompletionRequest {
@ -482,6 +520,9 @@ impl AgentLoop {
match (*self.provider).chat(request).await {
Ok(response) => {
let assistant_message = ChatMessage::assistant(response.content);
if !memory_write_occurred {
self.maybe_extract_and_store_memories(&messages, &assistant_message).await?;
}
emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult {
final_response: assistant_message,
@ -698,6 +739,227 @@ impl AgentLoop {
}
}
/// Reports whether `tool_call` was a `memory_manage` invocation with action
/// `put` or `update` whose execution `result` succeeded.
fn did_successfully_write_memory(tool_call: &ToolCall, result: &ToolExecutionOutcome) -> bool {
    let is_successful_memory_call = result.success && tool_call.name == "memory_manage";
    if !is_successful_memory_call {
        return false;
    }

    // Only write actions count; reads such as `search` do not.
    let action = tool_call
        .arguments
        .get("action")
        .and_then(|value| value.as_str());
    matches!(action, Some("put" | "update"))
}
/// Returns a copy of `provider_config` tuned for the memory-extraction call.
///
/// For providers whose `provider_type` is `"openai"`, the `reasoning_effort`
/// entry of `model_extra` is set to `MEMORY_EXTRACTION_REASONING_EFFORT`;
/// every other provider type gets an unmodified clone.
fn lightweight_provider_config(provider_config: &LLMProviderConfig) -> LLMProviderConfig {
    let mut lightweight = provider_config.clone();
    if lightweight.provider_type != "openai" {
        return lightweight;
    }

    let effort = serde_json::Value::String(MEMORY_EXTRACTION_REASONING_EFFORT.to_string());
    lightweight
        .model_extra
        .insert("reasoning_effort".to_string(), effort);
    lightweight
}
/// Builds the user-message text sent to the memory-extraction model.
///
/// The text starts with a fixed instruction line, followed by up to
/// `MEMORY_EXTRACTION_CONTEXT_MESSAGES` of the most recent `user`/`assistant`
/// messages in chronological order (`role: content`), and finally the final
/// response as `assistant_final: content`. Empty contents are skipped and long
/// contents are truncated.
fn build_memory_extraction_context(messages: &[ChatMessage], final_response: &ChatMessage) -> String {
    // Pick the newest conversational turns by walking backwards, then flip the
    // selection so the transcript reads oldest-first.
    let mut selected: Vec<&ChatMessage> = Vec::new();
    for message in messages.iter().rev() {
        if selected.len() == MEMORY_EXTRACTION_CONTEXT_MESSAGES {
            break;
        }
        if matches!(message.role.as_str(), "user" | "assistant") {
            selected.push(message);
        }
    }
    selected.reverse();

    let mut transcript = String::from("请判断这轮对话中是否有值得长期记忆的信息:");
    for message in selected {
        let body = message.content.trim();
        if !body.is_empty() {
            transcript.push('\n');
            transcript.push_str(&format!(
                "{}: {}",
                message.role,
                truncate_text_for_memory_extraction(body)
            ));
        }
    }

    let closing = final_response.content.trim();
    if !closing.is_empty() {
        transcript.push('\n');
        transcript.push_str(&format!(
            "assistant_final: {}",
            truncate_text_for_memory_extraction(closing)
        ));
    }

    transcript
}
/// Limits `text` to 240 characters (not bytes) for the extraction prompt.
///
/// Text at or under the limit is returned unchanged; longer text is cut after
/// the 240th character and suffixed with `...`.
fn truncate_text_for_memory_extraction(text: &str) -> String {
    const MAX_CHARS: usize = 240;

    // The byte offset of character number MAX_CHARS + 1 (if it exists) is
    // exactly where the first MAX_CHARS characters end.
    match text.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}
/// Parses the extraction model's raw reply into normalized memory candidates.
///
/// A surrounding Markdown code fence (```` ```json ```` or ```` ``` ````) is
/// stripped first. The payload may be either a bare JSON array of candidates
/// or an object with a `memories` array; anything else yields an empty list.
fn parse_memory_candidates(raw: &str) -> Vec<MemoryCandidate> {
    let trimmed = raw.trim();

    // Try the more specific fence first; fall back to the unfenced text.
    let mut json_payload = trimmed;
    for fence in ["```json", "```"] {
        let unfenced = trimmed
            .strip_prefix(fence)
            .and_then(|rest| rest.strip_suffix("```"));
        if let Some(inner) = unfenced {
            json_payload = inner.trim();
            break;
        }
    }

    match serde_json::from_str::<Vec<MemoryCandidate>>(json_payload) {
        Ok(list) => normalize_memory_candidates(list),
        Err(_) => match serde_json::from_str::<MemoryCandidateEnvelope>(json_payload) {
            Ok(envelope) => normalize_memory_candidates(envelope.memories),
            Err(_) => Vec::new(),
        },
    }
}
/// Cleans up raw candidates and keeps at most `MEMORY_EXTRACTION_MAX_CANDIDATES`.
///
/// A candidate is dropped when its namespace is not recognized, its
/// canonicalized key is empty, or its trimmed content is shorter than 8 bytes.
/// Surviving candidates keep their input order.
fn normalize_memory_candidates(candidates: Vec<MemoryCandidate>) -> Vec<MemoryCandidate> {
    candidates
        .into_iter()
        .filter_map(|candidate| {
            let namespace = canonical_memory_namespace(&candidate.namespace)?;
            let key = canonical_memory_key(&candidate.key);
            let content = candidate.content.trim().to_string();
            // `len()` is a byte length, so the threshold is 8 bytes of UTF-8.
            let acceptable = !key.is_empty() && content.len() >= 8;
            acceptable.then(|| MemoryCandidate {
                namespace,
                key,
                content,
            })
        })
        .take(MEMORY_EXTRACTION_MAX_CANDIDATES)
        .collect()
}
/// Maps a namespace name (case-insensitive, surrounding whitespace ignored,
/// singular or plural) to its canonical form.
///
/// Returns `None` for anything other than the four supported namespaces.
fn canonical_memory_namespace(namespace: &str) -> Option<String> {
    // (canonical name, accepted alias)
    const NAMESPACES: [(&str, &str); 4] = [
        ("preferences", "preference"),
        ("profile", "profiles"),
        ("tasks", "task"),
        ("decisions", "decision"),
    ];

    let wanted = namespace.trim().to_lowercase();
    NAMESPACES
        .iter()
        .find(|entry| wanted == entry.0 || wanted == entry.1)
        .map(|entry| entry.0.to_string())
}
/// Converts an arbitrary key into a lowercase snake_case identifier.
///
/// ASCII letters and digits are kept (lowercased); runs of `_`, `-`, `/` and
/// spaces collapse into a single `_`; every other character is dropped.
/// Leading and trailing underscores are removed from the result.
fn canonical_memory_key(key: &str) -> String {
    let mut slug = String::new();
    let mut last_was_separator = false;

    for ch in key.trim().chars() {
        match ch {
            c if c.is_ascii_alphanumeric() => {
                last_was_separator = false;
                slug.push(c.to_ascii_lowercase());
            }
            '_' | '-' | '/' | ' ' => {
                if !last_was_separator {
                    last_was_separator = true;
                    slug.push('_');
                }
            }
            // Dropped characters deliberately leave the separator state as is.
            _ => {}
        }
    }

    slug.trim_matches('_').to_string()
}
impl AgentLoop {
async fn maybe_extract_and_store_memories(
&self,
messages: &[ChatMessage],
final_response: &ChatMessage,
) -> Result<usize, AgentError> {
let Some(tool) = self.tools.get("memory_manage") else {
return Ok(0);
};
if self.tool_context.channel_name.is_none() || self.tool_context.sender_id.is_none() {
return Ok(0);
}
let context = build_memory_extraction_context(messages, final_response);
let provider = create_provider(lightweight_provider_config(&self.provider_config))
.map_err(|err| AgentError::ProviderCreation(err.to_string()))?;
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_EXTRACTION_SYSTEM_PROMPT),
Message::user(context),
],
temperature: Some(0.0),
max_tokens: Some(MEMORY_EXTRACTION_MAX_TOKENS),
tools: None,
};
let response = provider
.chat(request)
.await
.map_err(|err| AgentError::LlmError(err.to_string()))?;
let candidates = parse_memory_candidates(&response.content);
if candidates.is_empty() {
tracing::debug!("No auto-save memory candidates extracted from final response");
return Ok(0);
}
tracing::info!(candidate_count = candidates.len(), candidates = ?candidates, "Auto-save memory candidates extracted");
let mut saved = 0;
for candidate in candidates {
let result = tool
.execute_with_context(
&self.tool_context,
serde_json::json!({
"action": "put",
"namespace": candidate.namespace,
"key": candidate.key,
"content": candidate.content,
}),
)
.await
.map_err(|err| AgentError::Other(format!("auto-save memory tool error: {}", err)))?;
if result.success {
saved += 1;
} else {
tracing::warn!(error = result.error.as_deref().unwrap_or("unknown"), "Auto-save memory write failed");
}
}
tracing::info!(saved_count = saved, "Auto-save memory extraction completed");
Ok(saved)
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -773,6 +1035,73 @@ mod tests {
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
}
/// Only a successful `memory_manage` call with a write action counts as a memory write.
#[test]
fn test_did_successfully_write_memory_only_accepts_successful_put_or_update() {
    let memory_call = |id: &str, action: &str| ToolCall {
        id: id.to_string(),
        name: "memory_manage".to_string(),
        arguments: serde_json::json!({ "action": action }),
    };
    let put_call = memory_call("call_1", "put");
    let search_call = memory_call("call_2", "search");

    let succeeded = ToolExecutionOutcome::success("ok".to_string());
    let failed = ToolExecutionOutcome::failure("err".to_string(), Some("boom".to_string()));

    // Successful write is accepted.
    assert!(did_successfully_write_memory(&put_call, &succeeded));
    // Failed write is rejected.
    assert!(!did_successfully_write_memory(&put_call, &failed));
    // Successful read-only action is rejected.
    assert!(!did_successfully_write_memory(&search_call, &succeeded));
}
/// A fenced `{"memories": [...]}` reply is unwrapped, normalized and capped at two entries.
#[test]
fn test_parse_memory_candidates_normalizes_and_limits_results() {
    let raw = r#"```json
{"memories": [
{"namespace": "preference", "key": "Email Folder Preference", "content": "用户提到邮件时默认查看代收邮箱而不是收件箱。"},
{"namespace": "decision", "key": "mailbox strategy", "content": "后续默认先查看代收邮箱。"},
{"namespace": "tasks", "key": "short", "content": "太短"}
]}
```"#;

    let parsed = parse_memory_candidates(raw);
    let summary: Vec<(&str, &str)> = parsed
        .iter()
        .map(|candidate| (candidate.namespace.as_str(), candidate.key.as_str()))
        .collect();

    assert_eq!(
        summary,
        vec![
            ("preferences", "email_folder_preference"),
            ("decisions", "mailbox_strategy"),
        ]
    );
}
/// The extraction context contains only the four newest user/assistant
/// messages plus the final response; system and tool messages are excluded.
#[test]
fn test_build_memory_extraction_context_uses_recent_user_and_assistant_messages() {
    let messages = vec![
        ChatMessage::system("system"),
        ChatMessage::tool("call-1", "calculator", "2"),
        ChatMessage::user("first user"),
        ChatMessage::assistant("first assistant"),
        ChatMessage::user("second user"),
        ChatMessage::assistant("second assistant"),
        ChatMessage::user("third user"),
    ];
    let final_response = ChatMessage::assistant("final assistant");

    let context = build_memory_extraction_context(&messages, &final_response);

    // Five conversational messages exist, so the oldest one falls outside the
    // four-message window while the next one is still inside it.
    assert!(!context.contains("user: first user"));
    assert!(context.contains("assistant: first assistant"));
    assert!(context.contains("user: second user"));
    assert!(context.contains("assistant: second assistant"));
    assert!(context.contains("user: third user"));
    assert!(context.contains("assistant_final: final assistant"));
    assert!(!context.contains("tool:"));
}
}
#[derive(Debug)]

View File

@ -412,16 +412,19 @@ impl FeishuChannel {
return Err(ChannelError::Other(format!("File download failed {}: {}", status, error_text)));
}
let response_headers = resp.headers().clone();
let data = resp.bytes().await
.map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))?
.to_vec();
let extension = match file_type {
"audio" => "mp3",
"video" => "mp4",
_ => "bin",
};
let filename = format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension);
let filename = infer_download_filename(
content_json,
&response_headers,
message_id,
file_key,
file_type,
);
let file_path = media_dir.join(&filename);
tokio::fs::write(&file_path, &data).await
@ -437,6 +440,15 @@ impl FeishuChannel {
Ok((format!("[{}: {}]", file_type, filename), Some(media_item)))
}
/// Builds a generic download filename of the form
/// `{message_id}_{file_key prefix}.{extension}`.
///
/// The extension is `mp3` for `"audio"`, `mp4` for `"video"` and `bin` for
/// every other `file_type`. The prefix is at most the first 8 bytes of
/// `file_key`, shortened to the nearest character boundary so that a
/// multi-byte character at the cut point cannot cause a slicing panic.
fn fallback_download_filename(message_id: &str, file_key: &str, file_type: &str) -> String {
    let extension = match file_type {
        "audio" => "mp3",
        "video" => "mp4",
        _ => "bin",
    };

    // Offset 0 is always a character boundary, so this loop terminates.
    let mut prefix_end = file_key.len().min(8);
    while !file_key.is_char_boundary(prefix_end) {
        prefix_end -= 1;
    }

    format!("{}_{}.{}", message_id, &file_key[..prefix_end], extension)
}
/// Upload image to Feishu and return the image_key
async fn upload_image(&self, file_path: &str) -> Result<String, ChannelError> {
let token = self.get_tenant_access_token().await?;
@ -1920,9 +1932,83 @@ impl FeishuChannel {
}
}
/// Chooses the on-disk filename for a downloaded attachment.
///
/// When an original filename can be found (message content first, then the
/// `Content-Disposition` header) and is non-empty after sanitizing, the result
/// is `{message_id}_{sanitized name}`. Otherwise the generic fallback name
/// derived from `message_id`, `file_key` and `file_type` is used.
fn infer_download_filename(
    content_json: &serde_json::Value,
    headers: &reqwest::header::HeaderMap,
    message_id: &str,
    file_key: &str,
    file_type: &str,
) -> String {
    let preferred = extract_original_file_name(content_json, headers)
        .map(|original| sanitize_download_file_name(&original))
        .filter(|sanitized| !sanitized.is_empty());

    match preferred {
        Some(sanitized) => format!("{}_{}", message_id, sanitized),
        None => FeishuChannel::fallback_download_filename(message_id, file_key, file_type),
    }
}
/// Looks up the attachment's original filename.
///
/// The first of the keys `file_name`, `filename`, `name` that holds a string
/// in `content_json` is used when it is non-blank after trimming; otherwise
/// the `Content-Disposition` header is consulted.
fn extract_original_file_name(
    content_json: &serde_json::Value,
    headers: &reqwest::header::HeaderMap,
) -> Option<String> {
    let first_string = ["file_name", "filename", "name"]
        .into_iter()
        .find_map(|key| content_json.get(key).and_then(|value| value.as_str()));

    match first_string.map(str::trim) {
        Some(name) if !name.is_empty() => Some(name.to_owned()),
        _ => extract_file_name_from_content_disposition(headers),
    }
}
/// Extracts a filename from the `Content-Disposition` response header.
///
/// Returns `None` when the header is absent, is not visible ASCII, or carries
/// no non-empty filename parameter. When both parameters are present, the
/// extended `filename*` parameter is preferred over plain `filename`
/// regardless of their order, as RFC 6266 recommends. UTF-8 `filename*`
/// values are percent-decoded.
///
/// The returned name is not sanitized; callers must sanitize it before using
/// it as a path component.
fn extract_file_name_from_content_disposition(
    headers: &reqwest::header::HeaderMap,
) -> Option<String> {
    let header = headers
        .get(reqwest::header::CONTENT_DISPOSITION)
        .and_then(|value| value.to_str().ok())?;

    // Remember the first plain `filename=` but keep scanning: a later
    // `filename*=` takes precedence.
    let mut plain_name: Option<String> = None;
    for segment in header.split(';').map(str::trim) {
        if let Some(value) = segment.strip_prefix("filename*=") {
            let decoded = decode_extended_file_name(value.trim_matches('"'));
            if !decoded.is_empty() {
                return Some(decoded);
            }
        } else if let Some(value) = segment.strip_prefix("filename=") {
            let cleaned = value.trim_matches('"').trim();
            if plain_name.is_none() && !cleaned.is_empty() {
                plain_name = Some(cleaned.to_string());
            }
        }
    }

    plain_name
}

/// Decodes an RFC 8187 extended value of the form `charset'language'value`.
///
/// Input that does not have this three-part shape is returned unchanged. The
/// value part is percent-decoded only for the UTF-8 charset; if decoding does
/// not yield valid UTF-8, or yields control characters, the still-encoded
/// value part is returned instead.
fn decode_extended_file_name(ext_value: &str) -> String {
    let mut parts = ext_value.splitn(3, '\'');
    let (charset, encoded) = match (parts.next(), parts.next(), parts.next()) {
        (Some(charset), Some(_language), Some(encoded)) => (charset, encoded),
        _ => return ext_value.to_string(),
    };
    if !charset.eq_ignore_ascii_case("utf-8") {
        return encoded.to_string();
    }

    let raw = encoded.as_bytes();
    let hex_at = |pos: usize| raw.get(pos).and_then(|byte| (*byte as char).to_digit(16));
    let mut bytes = Vec::with_capacity(raw.len());
    let mut index = 0;
    while index < raw.len() {
        if raw[index] == b'%' {
            if let (Some(high), Some(low)) = (hex_at(index + 1), hex_at(index + 2)) {
                // Two hex digits are at most 0xFF, so the cast cannot truncate.
                bytes.push((high * 16 + low) as u8);
                index += 3;
                continue;
            }
        }
        // Not a valid escape: keep the byte literally.
        bytes.push(raw[index]);
        index += 1;
    }

    match String::from_utf8(bytes) {
        Ok(name) if !name.chars().any(char::is_control) => name,
        _ => encoded.to_string(),
    }
}
/// Makes a filename safe to use as a single path component.
///
/// Each of the characters `/ \ : * ? " < > |` is replaced with `_`; then
/// leading/trailing dots are removed, followed by leading/trailing whitespace.
fn sanitize_download_file_name(file_name: &str) -> String {
    const FORBIDDEN: [char; 9] = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'];

    let mut replaced = String::with_capacity(file_name.len());
    for ch in file_name.chars() {
        replaced.push(if FORBIDDEN.contains(&ch) { '_' } else { ch });
    }

    replaced.trim_matches('.').trim().to_string()
}
#[cfg(test)]
mod tests {
use super::{FeishuChannel, MsgFormat};
use super::{extract_file_name_from_content_disposition, infer_download_filename, sanitize_download_file_name, FeishuChannel, MsgFormat};
#[test]
fn markdown_post_uses_md_tag() {
@ -1945,6 +2031,65 @@ mod tests {
let content = "intro\n## heading";
assert_eq!(FeishuChannel::detect_msg_format(content), MsgFormat::Interactive);
}
/// A `file_name` in the message content wins over every other source.
#[test]
fn infer_download_filename_prefers_original_file_name() {
    let headers = reqwest::header::HeaderMap::new();
    let content = serde_json::json!({
        "file_key": "file_key_123",
        "file_name": "demo-archive.zip"
    });

    assert_eq!(
        infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"),
        "om_123_demo-archive.zip"
    );
}
/// Without a name in the message content, the `Content-Disposition` filename is used.
#[test]
fn infer_download_filename_uses_content_disposition_when_message_lacks_name() {
    let disposition =
        reqwest::header::HeaderValue::from_static("attachment; filename=meeting-notes.zip");
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(reqwest::header::CONTENT_DISPOSITION, disposition);
    let content = serde_json::json!({
        "file_key": "file_key_123"
    });

    assert_eq!(
        infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"),
        "om_123_meeting-notes.zip"
    );
}
/// With no name available anywhere, the generic `.bin` fallback name is produced.
#[test]
fn infer_download_filename_falls_back_to_bin_without_name() {
    let headers = reqwest::header::HeaderMap::new();
    let content = serde_json::json!({
        "file_key": "file_key_123"
    });

    assert_eq!(
        infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"),
        "om_123_file_key.bin"
    );
}
/// Path separators become underscores and leading dots are stripped.
#[test]
fn sanitize_download_file_name_replaces_path_separators() {
    assert_eq!(
        sanitize_download_file_name("../../demo/archive.zip"),
        "_.._demo_archive.zip"
    );
}
/// The extended `filename*=charset''value` form is understood.
#[test]
fn extract_file_name_from_content_disposition_supports_filename_star() {
    let disposition =
        reqwest::header::HeaderValue::from_static("attachment; filename*=UTF-8''archive.zip");
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(reqwest::header::CONTENT_DISPOSITION, disposition);

    assert_eq!(
        extract_file_name_from_content_disposition(&headers).as_deref(),
        Some("archive.zip")
    );
}
}
#[async_trait]

View File

@ -136,6 +136,8 @@ pub struct GatewayConfig {
pub port: u16,
#[serde(default, rename = "session_ttl_hours")]
pub session_ttl_hours: Option<u64>,
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
pub agent_prompt_reinject_every: u64,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -156,12 +158,17 @@ fn default_gateway_url() -> String {
"ws://127.0.0.1:19876/ws".to_string()
}
/// Default for `gateway.agent_prompt_reinject_every` when the config omits it.
fn default_agent_prompt_reinject_every() -> u64 {
    const DEFAULT_REINJECT_EVERY: u64 = 100;
    DEFAULT_REINJECT_EVERY
}
/// Gateway defaults: default host and port, no explicit session TTL, and the
/// default agent-prompt reinjection interval.
impl Default for GatewayConfig {
    fn default() -> Self {
        let host = default_gateway_host();
        let port = default_gateway_port();
        let agent_prompt_reinject_every = default_agent_prompt_reinject_every();

        Self {
            host,
            port,
            session_ttl_hours: None,
            agent_prompt_reinject_every,
        }
    }
}
@ -344,7 +351,8 @@ mod tests {
},
"gateway": {
"host": "0.0.0.0",
"port": 19876
"port": 19876,
"agent_prompt_reinject_every": 120
}
}"#,
)
@ -387,5 +395,39 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.gateway.host, "0.0.0.0");
assert_eq!(config.gateway.port, 19876);
assert_eq!(config.gateway.agent_prompt_reinject_every, 120);
}
/// A config file without a `gateway` section falls back to the default
/// reinjection interval of 100.
#[test]
fn test_gateway_config_defaults_agent_prompt_reinject_every() {
    let config_json = r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#;

    let file = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(file.path(), config_json).unwrap();

    let config_path = file.path().to_str().unwrap();
    let config = Config::load(config_path).unwrap();

    assert_eq!(config.gateway.agent_prompt_reinject_every, 100);
}
}

View File

@ -29,10 +29,16 @@ impl GatewayState {
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
let session_manager = SessionManager::new(session_ttl_hours, provider_config, skills)?;
let session_manager = SessionManager::new(
session_ttl_hours,
agent_prompt_reinject_every,
provider_config,
skills,
)?;
let channel_manager = ChannelManager::new();
let bus = channel_manager.bus();

View File

@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
@ -7,15 +8,22 @@ use uuid::Uuid;
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
use crate::providers::{create_provider, ChatCompletionRequest, Message};
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::storage::{MemoryRecord, SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, MemoryManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry,
WebFetchTool,
};
/// Default agent profile written to the agent prompt file when that file does
/// not exist yet.
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
/// System prompt for the keyword-generation request used by related-memory
/// search; it asks for a JSON array of short Chinese and English keywords.
const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语优先使用最容易命中记忆的核心检索词不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。";
/// Leading text of the system message that carries related memories.
/// NOTE(review): presumably prepended by `format_related_memory_system_prompt` — confirm.
const RELATED_MEMORY_SYSTEM_PROMPT_PREFIX: &str = "找到相关的记忆。你必须优先参考这些记忆,并在后续推理中把它们当作当前会话的补充上下文;若与用户本轮明确要求冲突,以用户本轮要求为准。";
/// Value written to the `reasoning_effort` entry of `model_extra` for the
/// keyword-generation request (applied to `openai`-type providers only).
const MEMORY_KEYWORD_REASONING_EFFORT: &str = "none";
/// NOTE(review): presumably the maximum length of one generated keyword, in
/// characters; its use is not visible here — confirm.
const MEMORY_KEYWORD_MAX_CHARS: usize = 32;
/// Sessions are isolated per channel: each channel has exactly one Session.
/// Histories are isolated per chat_id and are managed by the owning Session.
pub struct Session {
@ -29,6 +37,7 @@ pub struct Session {
skills: Arc<SkillRuntime>,
compressor: ContextCompressor,
store: Arc<SessionStore>,
agent_prompt_reinject_every: i64,
}
pub struct BusToolCallEmitter {
@ -79,6 +88,7 @@ impl Session {
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
) -> Result<Self, AgentError> {
Ok(Self {
id: Uuid::new_v4(),
@ -90,6 +100,7 @@ impl Session {
skills,
compressor: ContextCompressor::new(provider_config.token_limit),
store,
agent_prompt_reinject_every: agent_prompt_reinject_every as i64,
})
}
@ -105,7 +116,7 @@ impl Session {
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
if self.chat_histories.contains_key(chat_id) {
return Ok(());
return self.ensure_initial_agent_prompt(chat_id);
}
let history = self
@ -113,6 +124,37 @@ impl Session {
.load_messages(&self.persistent_session_id(chat_id))
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
self.chat_histories.insert(chat_id.to_string(), history);
self.ensure_initial_agent_prompt(chat_id)?;
Ok(())
}
/// Re-injects the agent profile as a system message when a reinjection is due.
///
/// Must run before the next user message is appended. After making sure
/// the chat is loaded, it compares the number of completed reinjection
/// intervals (`active user turns / agent_prompt_reinject_every`) with the
/// reinjection count stored on the session record. If more intervals have
/// elapsed than reinjections were recorded, the current agent prompt is
/// appended and persisted, and the reinjection is recorded in the store.
///
/// Nothing is injected when the interval is not positive, no user turn has
/// happened yet, or the agent prompt file is blank.
///
/// # Errors
/// Returns `AgentError::Other` when the session record is missing or a
/// store/prompt operation fails.
pub fn ensure_agent_prompt_before_user_message(&mut self, chat_id: &str) -> Result<(), AgentError> {
    self.ensure_chat_loaded(chat_id)?;

    let session_id = self.persistent_session_id(chat_id);
    let session_record = self
        .store
        .get_session(&session_id)
        .map_err(|err| AgentError::Other(format!("get session error: {}", err)))?
        .ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
    let active_user_turns = self
        .store
        .count_active_user_messages(&session_id)
        .map_err(|err| AgentError::Other(format!("count active user messages error: {}", err)))?;

    let interval = self.agent_prompt_reinject_every;
    if interval <= 0 || active_user_turns <= 0 {
        return Ok(());
    }

    let completed_intervals = active_user_turns / interval;
    if completed_intervals <= session_record.agent_prompt_reinjection_count {
        return Ok(());
    }

    let Some(agent_prompt) = load_agent_prompt()? else {
        return Ok(());
    };
    self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
    self.store
        .mark_agent_prompt_reinjected(&session_id)
        .map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?;

    Ok(())
}
@ -264,6 +306,51 @@ impl Session {
})
})
}
/// Seeds an empty chat history with the agent profile.
///
/// When the history for `chat_id` is missing or empty and the agent prompt
/// file has non-blank content, that content is appended and persisted as a
/// system message. A non-empty history is left untouched.
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
    let history_is_empty = self
        .get_history(chat_id)
        .map_or(true, |history| history.is_empty());

    if history_is_empty {
        if let Some(agent_prompt) = load_agent_prompt()? {
            self.append_persisted_message(chat_id, ChatMessage::system(agent_prompt))?;
        }
    }

    Ok(())
}
}
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
if !path.exists() {
fs::write(&path, DEFAULT_AGENT_PROMPT)
.map_err(|err| AgentError::Other(format!("create agent prompt file error: {}", err)))?;
}
let content = fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
let trimmed = content.trim();
if trimmed.is_empty() {
return Ok(None);
}
Ok(Some(trimmed.to_string()))
}
/// Resolves the agent profile location: `<home>/.picobot/agent/AGENT.md`.
///
/// # Errors
/// Returns `AgentError::Other` when the home directory cannot be determined.
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
    match dirs::home_dir() {
        Some(mut path) => {
            path.push(".picobot");
            path.push("agent");
            path.push("AGENT.md");
            Ok(path)
        }
        None => Err(AgentError::Other("home directory not found".to_string())),
    }
}
/// SessionManager manages all Sessions and routes to them by channel_name.
@ -274,6 +361,7 @@ pub struct SessionManager {
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
agent_prompt_reinject_every: u64,
}
struct SessionManagerInner {
@ -331,6 +419,7 @@ pub(crate) fn handle_in_chat_command(
impl SessionManager {
pub fn new(
session_ttl_hours: u64,
agent_prompt_reinject_every: u64,
provider_config: LLMProviderConfig,
skills: Arc<SkillRuntime>,
) -> Result<Self, AgentError> {
@ -353,6 +442,7 @@ impl SessionManager {
tools: Arc::new(default_tools(skills.clone(), store.clone())),
skills,
store,
agent_prompt_reinject_every,
})
}
@ -447,6 +537,7 @@ impl SessionManager {
self.tools.clone(),
self.skills.clone(),
self.store.clone(),
self.agent_prompt_reinject_every,
)
.await?;
let arc = Arc::new(Mutex::new(session));
@ -523,6 +614,8 @@ impl SessionManager {
)]);
}
session_guard.ensure_agent_prompt_before_user_message(chat_id)?;
// 添加用户消息到历史
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
@ -534,7 +627,42 @@ impl SessionManager {
session_guard.append_persisted_message(chat_id, user_message)?;
// 获取完整历史
let history = session_guard.get_or_create_history(chat_id).clone();
let mut history = session_guard.get_or_create_history(chat_id).clone();
tracing::info!(
channel = %channel_name,
chat_id = %chat_id,
sender_id = %sender_id,
history_len = history.len(),
"Starting synchronous related memory search"
);
if let Some(memory_prompt) = build_related_memory_prompt(
session_guard.provider_config().clone(),
self.store.clone(),
channel_name.to_string(),
sender_id.to_string(),
chat_id.to_string(),
history.clone(),
)
.await?
{
tracing::info!(
channel = %channel_name,
chat_id = %chat_id,
sender_id = %sender_id,
prompt_len = memory_prompt.len(),
"Injecting related memory system prompt before agent processing"
);
let memory_message = ChatMessage::system(memory_prompt);
session_guard.append_persisted_message(chat_id, memory_message.clone())?;
history.push(memory_message);
} else {
tracing::info!(
channel = %channel_name,
chat_id = %chat_id,
sender_id = %sender_id,
"No related memory prompt generated before agent processing"
);
}
// 压缩历史(如果需要)
let history = session_guard.compressor
@ -590,6 +718,233 @@ impl SessionManager {
}
}
/// Builds the optional "related memories" system prompt for the current turn.
///
/// Generates search keywords from `history`, runs one memory search for the
/// given channel/sender, and formats the hits into a prompt string.
///
/// Returns `Ok(None)` when no keywords were generated or the search found
/// nothing.
///
/// # Errors
///
/// Propagates any error from keyword generation or from the memory search.
async fn build_related_memory_prompt(
    provider_config: LLMProviderConfig,
    store: Arc<SessionStore>,
    channel_name: String,
    sender_id: String,
    chat_id: String,
    history: Vec<ChatMessage>,
) -> Result<Option<String>, AgentError> {
    let keywords = generate_memory_search_keywords(provider_config, &history).await?;
    let keyword_count = keywords.len();

    tracing::info!(
        channel = %channel_name,
        chat_id = %chat_id,
        sender_id = %sender_id,
        keyword_count = keyword_count,
        keywords = ?keywords,
        "Generated memory search keywords"
    );

    // Nothing to search for: skip the store round-trip entirely.
    if keyword_count == 0 {
        return Ok(None);
    }

    let memories =
        search_related_memories(store, &channel_name, &sender_id, &chat_id, &keywords).await?;
    let memory_count = memories.len();

    if memory_count == 0 {
        tracing::info!(
            channel = %channel_name,
            chat_id = %chat_id,
            sender_id = %sender_id,
            keyword_count = keyword_count,
            "Related memory search returned no matches"
        );
        Ok(None)
    } else {
        tracing::info!(
            channel = %channel_name,
            chat_id = %chat_id,
            sender_id = %sender_id,
            keyword_count = keyword_count,
            memory_count = memory_count,
            "Related memory search produced matches"
        );
        let prompt = format_related_memory_system_prompt(&keywords, &memories);
        Ok(Some(prompt))
    }
}
/// Asks the configured LLM provider for memory-search keywords.
///
/// Sends a two-message request (the keyword system prompt plus a compact view
/// of the recent conversation) with temperature 0 and a 128-token cap, then
/// parses the reply into a normalized keyword list.
///
/// For the `"openai"` provider type, a `reasoning_effort` entry is added to
/// the provider's extra model options before the provider is created.
///
/// # Errors
///
/// Returns `AgentError::ProviderCreation` if the provider cannot be built and
/// `AgentError::LlmError` if the chat request fails.
async fn generate_memory_search_keywords(
    mut provider_config: LLMProviderConfig,
    history: &[ChatMessage],
) -> Result<Vec<String>, AgentError> {
    let is_openai = provider_config.provider_type == "openai";
    if is_openai {
        let effort = serde_json::Value::String(MEMORY_KEYWORD_REASONING_EFFORT.to_string());
        provider_config
            .model_extra
            .insert("reasoning_effort".to_string(), effort);
    }

    let provider = create_provider(provider_config)
        .map_err(|err| AgentError::ProviderCreation(err.to_string()))?;

    let conversation = build_memory_keyword_context(history);
    let messages = vec![
        Message::system(MEMORY_KEYWORD_SYSTEM_PROMPT),
        Message::user(conversation),
    ];
    let request = ChatCompletionRequest {
        messages,
        temperature: Some(0.0),
        max_tokens: Some(128),
        tools: None,
    };

    let reply = provider.chat(request).await;
    let response = reply.map_err(|err| AgentError::LlmError(err.to_string()))?;

    Ok(parse_memory_keywords(&response.content))
}
/// Renders the most recent non-system messages as a compact text block that is
/// sent to the keyword-generation model.
///
/// At most 8 messages are included, in chronological order. Messages whose
/// trimmed content is empty are skipped, and content longer than 240
/// characters is cut to 240 characters followed by `...`.
fn build_memory_keyword_context(history: &[ChatMessage]) -> String {
    const MAX_MESSAGES: usize = 8;
    const MAX_CONTENT_CHARS: usize = 240;

    // Pick the newest messages first, then flip back to chronological order.
    let mut window: Vec<&ChatMessage> = history
        .iter()
        .rev()
        .filter(|message| message.role != "system")
        .take(MAX_MESSAGES)
        .collect();
    window.reverse();

    let mut lines = Vec::with_capacity(window.len() + 1);
    lines.push("请基于以下最近会话生成长期记忆搜索关键词:".to_string());

    lines.extend(window.into_iter().filter_map(|message| {
        let content = message.content.trim();
        if content.is_empty() {
            return None;
        }
        // The byte offset of the 241st character (if any) marks the cut point.
        let compact = match content.char_indices().nth(MAX_CONTENT_CHARS) {
            Some((cut, _)) => format!("{}...", &content[..cut]),
            None => content.to_string(),
        };
        Some(format!("{}: {}", message.role, compact))
    }));

    lines.join("\n")
}
/// Parses the keyword model's raw reply into a normalized keyword list.
///
/// The reply is first tried as a JSON array of strings. If that fails, it is
/// split on newlines and on ASCII / fullwidth commas and semicolons. Either
/// way the pieces go through `normalize_keywords`.
fn parse_memory_keywords(raw: &str) -> Vec<String> {
    if let Ok(keywords) = serde_json::from_str::<Vec<String>>(raw) {
        return normalize_keywords(keywords);
    }

    // Fullwidth separators are written as escapes so they survive tooling
    // that strips non-ASCII punctuation: U+FF0C is "," and U+FF1B is ";".
    let is_separator =
        |ch: char| matches!(ch, '\n' | ',' | '\u{FF0C}' | ';' | '\u{FF1B}');

    normalize_keywords(
        raw.split(is_separator)
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(ToOwned::to_owned)
            .collect(),
    )
}
/// Cleans, de-duplicates and caps a list of candidate keywords.
///
/// Each candidate is trimmed, stripped of surrounding quotes and square
/// brackets, and compacted via `compact_memory_keyword`. Empty results are
/// dropped, duplicates are removed case-insensitively (first spelling wins),
/// and at most 6 keywords are returned in their original order.
fn normalize_keywords(keywords: Vec<String>) -> Vec<String> {
    const MAX_KEYWORDS: usize = 6;

    // Lowercased forms already emitted, used for case-insensitive dedup.
    let mut seen = HashSet::new();

    keywords
        .into_iter()
        .map(|raw| {
            let stripped = raw
                .trim()
                .trim_matches('"')
                .trim_matches('[')
                .trim_matches(']')
                .trim();
            compact_memory_keyword(stripped)
        })
        .filter(|candidate| !candidate.is_empty())
        .filter(|candidate| seen.insert(candidate.to_lowercase()))
        .take(MAX_KEYWORDS)
        .collect()
}
fn compact_memory_keyword(candidate: &str) -> String {
let compact = candidate
.split_whitespace()
.next()
.unwrap_or(candidate)
.trim_matches(|ch: char| matches!(ch, '"' | '\'' | '[' | ']' | '。' | '' | ',' | '' | ';' | ':' | ''))
.trim();
if compact.is_empty() {
return String::new();
}
compact.chars().take(MEMORY_KEYWORD_MAX_CHARS).collect()
}
/// Looks up stored memories matching any of `keywords` for one user.
///
/// The search runs with scope kind `"user"` and scope key
/// `"{channel_name}:{sender_id}"`, no namespace filter, and a limit of 10.
/// `chat_id` is only used for log context.
///
/// # Errors
///
/// Returns `AgentError::Other` wrapping the store error when the search fails.
async fn search_related_memories(
    store: Arc<SessionStore>,
    channel_name: &str,
    sender_id: &str,
    chat_id: &str,
    keywords: &[String],
) -> Result<Vec<MemoryRecord>, AgentError> {
    let keyword_count = keywords.len();

    tracing::debug!(
        channel = %channel_name,
        chat_id = %chat_id,
        sender_id = %sender_id,
        keyword_count = keyword_count,
        keywords = ?keywords,
        "Searching related memories with a single batched FTS query"
    );

    let scope_key = [channel_name, sender_id].join(":");
    let merged = match store.search_memories_any("user", &scope_key, keywords, None, 10) {
        Ok(records) => records,
        Err(err) => {
            return Err(AgentError::Other(format!(
                "batched memory search error: {}",
                err
            )))
        }
    };

    tracing::info!(
        channel = %channel_name,
        chat_id = %chat_id,
        sender_id = %sender_id,
        keyword_count = keyword_count,
        deduped_memory_count = merged.len(),
        "Finished related memory search aggregation"
    );

    Ok(merged)
}
/// Formats matched memories as a system prompt.
///
/// The output has three header lines (the fixed prefix, the keywords joined
/// with ` / `, and a section label) followed by up to 8 numbered entries of
/// the form `N. [namespace / memory_key] content`, with newlines inside the
/// content replaced by spaces.
fn format_related_memory_system_prompt(keywords: &[String], memories: &[MemoryRecord]) -> String {
    const MAX_LISTED: usize = 8;

    let mut lines = Vec::with_capacity(3 + memories.len().min(MAX_LISTED));
    lines.push(RELATED_MEMORY_SYSTEM_PROMPT_PREFIX.to_string());
    lines.push(format!("检索关键词: {}", keywords.join(" / ")));
    lines.push("相关记忆: ".to_string());

    let entries = memories
        .iter()
        .take(MAX_LISTED)
        .zip(1usize..)
        .map(|(memory, ordinal)| {
            // Keep every memory on one line so the numbered list stays readable.
            let flattened = memory.content.replace('\n', " ");
            format!(
                "{}. [{} / {}] {}",
                ordinal,
                memory.namespace,
                memory.memory_key,
                flattened.trim()
            )
        });
    lines.extend(entries);

    lines.join("\n")
}
#[cfg(test)]
mod tests {
use super::*;
@ -633,6 +988,7 @@ mod tests {
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
@ -658,7 +1014,189 @@ mod tests {
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
1,
2,
);
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
// Loading a chat with no stored history must leave exactly one message in the
// in-memory history: a system message carrying the agent profile text.
// NOTE(review): if the profile is read from the real home directory, the
// content assertion depends on the local AGENT.md — confirm the test is isolated.
#[tokio::test]
async fn test_ensure_chat_loaded_injects_agent_prompt_as_first_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
// Reinjection interval in user turns (not reached in this test).
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
assert!(history[0].content.contains("PicoBot 代理配置"));
}
// With an interval of 100, the agent prompt is injected a second time once 100
// user messages have been appended, the store's reinjection counter becomes 1,
// and a repeated call without new user turns injects nothing further.
#[tokio::test]
async fn test_agent_prompt_reinjected_after_each_hundred_user_turns() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
// Reinject the agent prompt every 100 user turns.
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
// Simulate exactly one full interval of user turns.
for turn in 0..100 {
session
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
.unwrap();
}
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
// Initial prompt + one reinjection.
assert_eq!(system_messages, 2);
let stored = store
.get_session(&session.persistent_session_id("chat-1"))
.unwrap()
.unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 1);
// Calling again at the same turn count must not inject a third prompt.
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 2);
}
// With the interval set to 0, no periodic reinjection happens: after 100 user
// messages the history still holds only the initial system message.
#[tokio::test]
async fn test_agent_prompt_reinjection_can_be_disabled_by_config() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
// 0 disables periodic reinjection.
0,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
for turn in 0..100 {
session
.append_persisted_message("chat-1", ChatMessage::user(format!("user-{turn}")))
.unwrap();
}
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
let system_messages = history.iter().filter(|message| message.role == "system").count();
assert_eq!(system_messages, 1);
}
// After a `/reset` command, preparing for the next user message must leave the
// active history with a single system message (the agent prompt) and nothing
// from before the reset.
#[tokio::test]
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
// Reset the conversation through the in-chat command path.
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
session.ensure_agent_prompt_before_user_message("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
}
// Covers the three parsing behaviors of `parse_memory_keywords`: JSON input
// with case-insensitive de-duplication, the delimiter-splitting fallback, and
// compaction of multi-word entries down to their first token.
#[test]
fn test_parse_memory_keywords_handles_json_and_dedup() {
// JSON array: "rust" is dropped as a case-insensitive duplicate of "Rust".
let keywords = parse_memory_keywords("[\"Rust\", \"偏好\", \"rust\", \"自动化\"]");
assert_eq!(keywords, vec!["Rust", "偏好", "自动化"]);
// Not JSON: falls back to splitting on commas and newlines.
let fallback = parse_memory_keywords("Rust, 偏好\n自动化");
assert_eq!(fallback, vec!["Rust", "偏好", "自动化"]);
// Only the first whitespace-separated word of each entry is kept.
let compacted = parse_memory_keywords("[\"用户 身份 信息 长描述\", \"email_folder_preference details\"]");
assert_eq!(compacted, vec!["用户", "email_folder_preference"]);
}
// The formatted prompt must contain the fixed prefix text, the keywords joined
// with " / ", and one "[namespace / memory_key] content" line per memory.
#[test]
fn test_format_related_memory_system_prompt_includes_prefix_keywords_and_memory_lines() {
let prompt = format_related_memory_system_prompt(
&["Rust 偏好".to_string(), "审批项目".to_string()],
&[MemoryRecord {
id: "memory-1".to_string(),
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "profile".to_string(),
memory_key: "language".to_string(),
content: "用户偏好 Rust 和自动化工具".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-1".to_string()),
source_message_id: Some("msg-1".to_string()),
source_message_seq: Some(1),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-1".to_string()),
created_at: 1,
updated_at: 1,
}],
);
assert!(prompt.contains("找到相关的记忆"));
assert!(prompt.contains("Rust 偏好 / 审批项目"));
assert!(prompt.contains("[profile / language] 用户偏好 Rust 和自动化工具"));
}
}

View File

@ -59,6 +59,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
state.session_manager.tools(),
state.session_manager.skills(),
state.session_manager.store(),
state.config.gateway.agent_prompt_reinject_every,
)
.await
{
@ -210,6 +211,8 @@ async fn handle_inbound(
return Ok(());
}
session_guard.ensure_agent_prompt_before_user_message(&chat_id)?;
let user_message = session_guard.create_user_message(&content, Vec::new());
let user_message_id = user_message.id.clone();
session_guard.append_persisted_message(&chat_id, user_message)?;

View File

@ -40,6 +40,8 @@ pub struct SessionRecord {
pub deleted_at: Option<i64>,
pub message_count: i64,
pub reset_cutoff_seq: i64,
pub user_turn_count: i64,
pub agent_prompt_reinjection_count: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -113,7 +115,9 @@ impl SessionStore {
archived_at INTEGER,
deleted_at INTEGER,
message_count INTEGER NOT NULL DEFAULT 0,
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0
reset_cutoff_seq INTEGER NOT NULL DEFAULT 0,
user_turn_count INTEGER NOT NULL DEFAULT 0,
agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_sessions_channel_archived
@ -234,8 +238,9 @@ impl SessionStore {
"
INSERT INTO sessions (
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0)
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0)
",
params![id, title, id, now],
)?;
@ -261,8 +266,9 @@ impl SessionStore {
"
INSERT INTO sessions (
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0)
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
",
params![session_id, title, channel_name, chat_id, now],
)?;
@ -277,7 +283,8 @@ impl SessionStore {
"
SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count, reset_cutoff_seq
archived_at, deleted_at, message_count, reset_cutoff_seq,
user_turn_count, agent_prompt_reinjection_count
FROM sessions
WHERE id = ?1 AND deleted_at IS NULL
",
@ -298,7 +305,8 @@ impl SessionStore {
"
SELECT id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at,
archived_at, deleted_at, message_count, reset_cutoff_seq
archived_at, deleted_at, message_count, reset_cutoff_seq,
user_turn_count, agent_prompt_reinjection_count
FROM sessions
WHERE channel_name = ?1
AND deleted_at IS NULL
@ -354,7 +362,12 @@ impl SessionStore {
conn.execute(
"
UPDATE sessions
SET message_count = 0, updated_at = ?2, last_active_at = ?2, reset_cutoff_seq = 0
SET message_count = 0,
updated_at = ?2,
last_active_at = ?2,
reset_cutoff_seq = 0,
user_turn_count = 0,
agent_prompt_reinjection_count = 0
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now],
@ -379,7 +392,9 @@ impl SessionStore {
SET reset_cutoff_seq = ?2,
updated_at = ?3,
last_active_at = ?3,
archived_at = NULL
archived_at = NULL,
user_turn_count = 0,
agent_prompt_reinjection_count = 0
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, cutoff_seq, now],
@ -423,10 +438,31 @@ impl SessionStore {
)?;
let now = current_timestamp();
let is_user_message = message.role == "user";
tx.execute(
"
UPDATE sessions
SET message_count = message_count + 1,
user_turn_count = user_turn_count + ?3,
updated_at = ?2,
last_active_at = ?2,
archived_at = NULL
WHERE id = ?1 AND deleted_at IS NULL
",
params![session_id, now, if is_user_message { 1 } else { 0 }],
)?;
tx.commit()?;
Ok(())
}
pub fn mark_agent_prompt_reinjected(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
"
UPDATE sessions
SET agent_prompt_reinjection_count = agent_prompt_reinjection_count + 1,
updated_at = ?2,
last_active_at = ?2,
archived_at = NULL
@ -434,8 +470,6 @@ impl SessionStore {
",
params![session_id, now],
)?;
tx.commit()?;
Ok(())
}
@ -740,6 +774,67 @@ impl SessionStore {
Ok(memories)
}
/// Searches memories that match ANY of `queries` using one FTS `MATCH` query.
///
/// The queries are combined into a single expression by `quote_fts_or_query`;
/// if that expression is empty, an empty list is returned without touching
/// the tables. Results are restricted to `scope_kind` / `scope_key`, and to
/// `namespace` when one is given. They are ordered by `bm25` rank and then by
/// most recent `updated_at`, and capped at `limit` (raised to at least 1).
///
/// # Errors
///
/// Returns a `StorageError` if preparing the statement, running the query, or
/// mapping a row fails.
///
/// # Panics
///
/// Panics if the connection mutex is poisoned.
pub fn search_memories_any(
&self,
scope_kind: &str,
scope_key: &str,
queries: &[String],
namespace: Option<&str>,
limit: usize,
) -> Result<Vec<MemoryRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
// A limit of 0 is treated as 1.
let limit = limit.max(1) as i64;
let query = quote_fts_or_query(queries);
// No usable search terms: nothing can match.
if query.is_empty() {
return Ok(Vec::new());
}
let mut memories = Vec::new();
// The two branches differ only in the extra namespace filter and the
// resulting parameter positions.
if let Some(namespace) = namespace {
let mut stmt = conn.prepare(
"
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
FROM memories_fts f
JOIN memories m ON m.rowid = f.rowid
WHERE memories_fts MATCH ?1
AND m.scope_kind = ?2
AND m.scope_key = ?3
AND m.namespace = ?4
ORDER BY bm25(memories_fts), m.updated_at DESC
LIMIT ?5
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
for row in rows {
memories.push(row?);
}
} else {
let mut stmt = conn.prepare(
"
SELECT m.id, m.scope_kind, m.scope_key, m.namespace, m.memory_key, m.content,
m.source_type, m.source_session_id, m.source_message_id, m.source_message_seq,
m.source_channel_name, m.source_chat_id, m.created_at, m.updated_at
FROM memories_fts f
JOIN memories m ON m.rowid = f.rowid
WHERE memories_fts MATCH ?1
AND m.scope_kind = ?2
AND m.scope_key = ?3
ORDER BY bm25(memories_fts), m.updated_at DESC
LIMIT ?4
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
for row in rows {
memories.push(row?);
}
}
Ok(memories)
}
pub fn load_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
@ -750,6 +845,21 @@ impl SessionStore {
let conn = self.conn.lock().expect("session db mutex poisoned");
load_messages_after(&conn, session_id, 0)
}
/// Counts the `user` messages of a session whose `seq` is greater than the
/// session's active reset cutoff.
///
/// # Errors
///
/// Returns a `StorageError` if the cutoff lookup or the count query fails.
///
/// # Panics
///
/// Panics if the connection mutex is poisoned.
pub fn count_active_user_messages(&self, session_id: &str) -> Result<i64, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
// Messages at or below this sequence number are excluded from the count.
let cutoff_seq = active_reset_cutoff(&conn, session_id)?;
conn.query_row(
"
SELECT COUNT(*)
FROM messages
WHERE session_id = ?1 AND seq > ?2 AND role = 'user'
",
params![session_id, cutoff_seq],
|row| row.get(0),
)
.map_err(StorageError::from)
}
}
pub fn persistent_session_id(channel_name: &str, chat_id: &str) -> String {
@ -779,6 +889,8 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
deleted_at: row.get(9)?,
message_count: row.get(10)?,
reset_cutoff_seq: row.get(11)?,
user_turn_count: row.get(12)?,
agent_prompt_reinjection_count: row.get(13)?,
})
}
@ -829,6 +941,20 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
)?;
}
if !has_column(conn, "sessions", "user_turn_count")? {
conn.execute(
"ALTER TABLE sessions ADD COLUMN user_turn_count INTEGER NOT NULL DEFAULT 0",
[],
)?;
}
if !has_column(conn, "sessions", "agent_prompt_reinjection_count")? {
conn.execute(
"ALTER TABLE sessions ADD COLUMN agent_prompt_reinjection_count INTEGER NOT NULL DEFAULT 0",
[],
)?;
}
Ok(())
}
@ -926,6 +1052,16 @@ fn quote_fts_query(query: &str) -> String {
format!("\"{}\"", query.replace('"', "\"\""))
}
fn quote_fts_or_query(queries: &[String]) -> String {
queries
.iter()
.map(|query| query.trim())
.filter(|query| !query.is_empty())
.map(quote_fts_query)
.collect::<Vec<_>>()
.join(" OR ")
}
#[cfg(test)]
mod tests {
use super::*;
@ -947,6 +1083,8 @@ mod tests {
assert_eq!(session.chat_id, session.id);
assert_eq!(session.message_count, 0);
assert_eq!(session.reset_cutoff_seq, 0);
assert_eq!(session.user_turn_count, 0);
assert_eq!(session.agent_prompt_reinjection_count, 0);
let first = ChatMessage::user("hello");
let second = ChatMessage::assistant("world");
@ -957,6 +1095,8 @@ mod tests {
assert_eq!(stored.message_count, 2);
assert!(stored.archived_at.is_none());
assert_eq!(stored.reset_cutoff_seq, 0);
assert_eq!(stored.user_turn_count, 1);
assert_eq!(stored.agent_prompt_reinjection_count, 0);
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 2);
@ -984,6 +1124,8 @@ mod tests {
assert!(cleared.is_empty());
let cleared_session = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(cleared_session.message_count, 0);
assert_eq!(cleared_session.user_turn_count, 0);
assert_eq!(cleared_session.agent_prompt_reinjection_count, 0);
store.delete_session(&session.id).unwrap();
assert!(store.get_session(&session.id).unwrap().is_none());
@ -1036,6 +1178,8 @@ mod tests {
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.reset_cutoff_seq, 2);
assert_eq!(stored.user_turn_count, 0);
assert_eq!(stored.agent_prompt_reinjection_count, 0);
let active_messages = store.load_messages(&session.id).unwrap();
assert!(active_messages.is_empty());
@ -1049,6 +1193,9 @@ mod tests {
let active_messages = store.load_messages(&session.id).unwrap();
assert_eq!(active_messages.len(), 1);
assert_eq!(active_messages[0].content, "after");
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.user_turn_count, 1);
}
#[test]
@ -1091,6 +1238,42 @@ mod tests {
let store = SessionStore::from_connection(conn).unwrap();
let session = store.create_cli_session(Some("migrated")).unwrap();
assert_eq!(session.reset_cutoff_seq, 0);
assert_eq!(session.user_turn_count, 0);
assert_eq!(session.agent_prompt_reinjection_count, 0);
}
// `count_active_user_messages` counts only `user` rows, and only those written
// after the most recent `reset_session` call.
#[test]
fn test_count_active_user_messages_respects_reset_cutoff_seq() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("count-users")).unwrap();
// Mixed roles: only the two user messages should be counted.
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
// After a reset, earlier user messages no longer count.
store.reset_session(&session.id).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
}
// Each call to `mark_agent_prompt_reinjected` adds one to the session's
// `agent_prompt_reinjection_count`.
#[test]
fn test_mark_agent_prompt_reinjected_increments_counter() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("prompt")).unwrap();
store.mark_agent_prompt_reinjected(&session.id).unwrap();
store.mark_agent_prompt_reinjected(&session.id).unwrap();
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.agent_prompt_reinjection_count, 2);
}
#[test]
@ -1240,4 +1423,83 @@ mod tests {
.unwrap();
assert!(hits_after_delete.is_empty());
}
// A memory must be findable by its `memory_key` even when the query text does
// not appear in the memory's content.
#[test]
fn test_memory_search_matches_memory_key_field() {
let store = SessionStore::in_memory().unwrap();
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "email_folder_preference".to_string(),
content: "用户提到邮件时默认查看代收邮箱。".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-8".to_string()),
source_message_id: Some("msg-8".to_string()),
source_message_seq: Some(8),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-8".to_string()),
})
.unwrap();
// The query equals the key; the content is Chinese text without it.
let hits = store
.search_memories("user", "feishu:user-1", "email_folder_preference", None, 10)
.unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].memory_key, "email_folder_preference");
}
// A single `search_memories_any` call with two keywords must return both
// memories, each of which matches only one of the keywords.
#[test]
fn test_search_memories_any_matches_multiple_keywords_once() {
let store = SessionStore::in_memory().unwrap();
// Matches the "rust-analyzer" keyword only.
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "preferences".to_string(),
memory_key: "editor".to_string(),
content: "Prefers rust-analyzer and cargo test output".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-2".to_string()),
source_message_id: Some("msg-2".to_string()),
source_message_seq: Some(3),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-2".to_string()),
})
.unwrap();
// Matches the "clippy" keyword only, and lives in a different namespace.
store
.put_memory(&MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: "feishu:user-1".to_string(),
namespace: "tasks".to_string(),
memory_key: "quality".to_string(),
content: "Tracks clippy warnings before release".to_string(),
source_type: "message".to_string(),
source_session_id: Some("feishu:chat-3".to_string()),
source_message_id: Some("msg-3".to_string()),
source_message_seq: Some(4),
source_channel_name: Some("feishu".to_string()),
source_chat_id: Some("chat-3".to_string()),
})
.unwrap();
let hits = store
.search_memories_any(
"user",
"feishu:user-1",
&["rust-analyzer".to_string(), "clippy".to_string()],
None,
10,
)
.unwrap();
assert_eq!(hits.len(), 2);
assert!(hits.iter().any(|memory| memory.memory_key == "editor"));
assert!(hits.iter().any(|memory| memory.memory_key == "quality"));
}
}

View File

@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
}
fn description(&self) -> &str {
"Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information by keyword. Memories are scoped to the current channel and sender, and record the originating session/message when available."
"Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information. Search matches namespace, memory_key, and content. When searching, prefer bilingual queries that include both Chinese and English aliases, and include likely snake_case key terms when known. Memories are scoped to the current channel and sender, and record the originating session/message when available."
}
fn parameters_schema(&self) -> serde_json::Value {
@ -41,7 +41,7 @@ impl Tool for MemoryManageTool {
},
"query": {
"type": "string",
"description": "Keyword query for full-text memory search, such as a preference, fact, name, topic, or prior decision"
"description": "Keyword query for full-text memory search across namespace, memory_key, and content. Prefer concise bilingual keywords when possible, for example Chinese plus English aliases and likely snake_case key terms."
},
"key": {
"type": "string",