Refactor agent and storage components to introduce SkillProvider and repository patterns
- Introduced `SkillProvider` trait to abstract skill-related functionalities. - Replaced `SkillRuntime` with `EmptySkillProvider` in `AgentLoop` for default behavior. - Updated `AgentFactory` to accept `SkillProvider` instead of `SkillRuntime`. - Created `SessionHistory` struct to manage chat histories and interactions. - Added `MemoryRepository`, `SchedulerJobRepository`, and `SkillEventRepository` traits for better storage abstraction. - Refactored tools to use new repository traits for memory and scheduler management. - Cleaned up session management logic by consolidating chat history handling into `SessionHistory`. Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
6756a3d0ae
commit
396504dffb
@ -2,11 +2,11 @@ use crate::bus::ChatMessage;
|
||||
use crate::bus::message::ToolMessageState;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::domain::messages::{ContentBlock, ToolCall};
|
||||
use crate::domain::tools::Tool;
|
||||
use crate::observability::{
|
||||
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
||||
};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
use async_trait::async_trait;
|
||||
@ -296,7 +296,7 @@ pub struct AgentLoop {
|
||||
provider_config: LLMProviderConfig,
|
||||
provider: Box<dyn LLMProvider>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
skills: Arc<dyn SkillProvider>,
|
||||
skill_event_sink: Option<Arc<dyn SkillEventSink>>,
|
||||
tool_context: ToolContext,
|
||||
observer: Option<Arc<dyn Observer>>,
|
||||
@ -326,6 +326,37 @@ pub trait SkillEventSink: Send + Sync + 'static {
|
||||
fn record_skill_event(&self, event: SkillEvent);
|
||||
}
|
||||
|
||||
pub trait SkillProvider: Send + Sync + 'static {
|
||||
fn system_index_prompt(&self) -> Option<String>;
|
||||
|
||||
fn skill_tool_definition(&self) -> Option<Tool>;
|
||||
|
||||
fn activation_payload(&self, name: &str) -> Result<String, String>;
|
||||
|
||||
fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String>;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct EmptySkillProvider;
|
||||
|
||||
impl SkillProvider for EmptySkillProvider {
|
||||
fn system_index_prompt(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
fn skill_tool_definition(&self) -> Option<Tool> {
|
||||
None
|
||||
}
|
||||
|
||||
fn activation_payload(&self, name: &str) -> Result<String, String> {
|
||||
Err(format!("skill '{}' not found", name))
|
||||
}
|
||||
|
||||
fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String> {
|
||||
Err(format!("skill '{}' not found", name))
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
@ -336,7 +367,7 @@ impl AgentLoop {
|
||||
provider_config,
|
||||
provider,
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
skills: Arc::new(SkillRuntime::default()),
|
||||
skills: Arc::new(EmptySkillProvider),
|
||||
skill_event_sink: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
@ -357,7 +388,7 @@ impl AgentLoop {
|
||||
provider_config,
|
||||
provider,
|
||||
tools,
|
||||
skills: Arc::new(SkillRuntime::default()),
|
||||
skills: Arc::new(EmptySkillProvider),
|
||||
skill_event_sink: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
@ -366,10 +397,10 @@ impl AgentLoop {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_tools_and_skills(
|
||||
pub fn with_tools_and_skill_provider(
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
skills: Arc<dyn SkillProvider>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
let provider = create_provider(provider_config.clone())
|
||||
|
||||
@ -3,5 +3,6 @@ pub mod context_compressor;
|
||||
|
||||
pub use agent_loop::{
|
||||
AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler, SkillEvent, SkillEventSink,
|
||||
SkillProvider,
|
||||
};
|
||||
pub use context_compressor::ContextCompressor;
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::{AgentError, AgentLoop};
|
||||
use crate::agent::{AgentError, AgentLoop, SkillProvider};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionStore, persistent_session_id};
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
|
||||
@ -11,7 +10,7 @@ use super::skill_event_sink::PersistentSkillEventSink;
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct AgentFactory {
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
skills: Arc<dyn SkillProvider>,
|
||||
store: Arc<SessionStore>,
|
||||
}
|
||||
|
||||
@ -26,7 +25,7 @@ pub(crate) struct AgentBuildRequest<'a> {
|
||||
impl AgentFactory {
|
||||
pub(crate) fn new(
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
skills: Arc<dyn SkillProvider>,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
@ -38,7 +37,7 @@ impl AgentFactory {
|
||||
|
||||
pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
|
||||
let session_id = persistent_session_id(request.channel_name, request.chat_id);
|
||||
AgentLoop::with_tools_and_skills(
|
||||
AgentLoop::with_tools_and_skill_provider(
|
||||
request.provider_config,
|
||||
self.tools.clone(),
|
||||
self.skills.clone(),
|
||||
|
||||
@ -16,6 +16,7 @@ pub mod provider_config_service;
|
||||
pub mod scheduled_agent_task_service;
|
||||
pub mod session;
|
||||
pub mod session_factory;
|
||||
pub mod session_history;
|
||||
pub mod session_lifecycle;
|
||||
pub mod session_message_service;
|
||||
pub mod session_pool;
|
||||
|
||||
@ -6,7 +6,7 @@ use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::scheduler::ScheduledAgentTaskOptions;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
use crate::storage::{SessionRecord, SessionStore};
|
||||
use crate::tools::ToolRegistry;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -29,33 +29,22 @@ use super::prompt_injector::PromptInjector;
|
||||
use super::provider_config_service::ProviderConfigService;
|
||||
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
|
||||
use super::session_factory::SessionFactory;
|
||||
use super::session_history::SessionHistory;
|
||||
use super::session_lifecycle::SessionLifecycleService;
|
||||
use super::session_message_service::SessionMessageService;
|
||||
use super::tool_registry_factory::ToolRegistryFactory;
|
||||
|
||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||
if content.chars().count() > max_chars {
|
||||
preview.push_str("...");
|
||||
}
|
||||
preview.replace('\n', "\\n")
|
||||
}
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||
pub struct Session {
|
||||
pub id: Uuid,
|
||||
pub channel_name: String,
|
||||
/// 按 chat_id 路由到不同会话历史,支持多用户多会话
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
compression_in_flight: HashSet<String>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
skills: Arc<SkillRuntime>,
|
||||
agent_factory: AgentFactory,
|
||||
prompt_injector: PromptInjector,
|
||||
compressor: ContextCompressor,
|
||||
store: Arc<SessionStore>,
|
||||
history: SessionHistory,
|
||||
}
|
||||
|
||||
pub struct BusToolCallEmitter {
|
||||
@ -140,103 +129,61 @@ impl Session {
|
||||
) -> Result<Self, AgentError> {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
channel_name,
|
||||
chat_histories: HashMap::new(),
|
||||
compression_in_flight: HashSet::new(),
|
||||
channel_name: channel_name.clone(),
|
||||
user_tx,
|
||||
provider_config: provider_config.clone(),
|
||||
skills,
|
||||
agent_factory,
|
||||
prompt_injector,
|
||||
compressor: ContextCompressor::from_provider_config(&provider_config),
|
||||
store,
|
||||
history: SessionHistory::new(channel_name, prompt_injector, store),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn persistent_session_id(&self, chat_id: &str) -> String {
|
||||
persistent_session_id(&self.channel_name, chat_id)
|
||||
self.history.persistent_session_id(chat_id)
|
||||
}
|
||||
|
||||
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||||
self.store
|
||||
.ensure_channel_session(&self.channel_name, chat_id)
|
||||
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
|
||||
self.history.ensure_persistent_session(chat_id)
|
||||
}
|
||||
|
||||
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if self.chat_histories.contains_key(chat_id) {
|
||||
return self.ensure_initial_agent_prompt(chat_id);
|
||||
}
|
||||
|
||||
let history = self
|
||||
.store
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
self.ensure_initial_agent_prompt(chat_id)?;
|
||||
Ok(())
|
||||
self.history.ensure_chat_loaded(chat_id)
|
||||
}
|
||||
|
||||
pub fn ensure_agent_prompt_before_user_message(
|
||||
&mut self,
|
||||
chat_id: &str,
|
||||
) -> Result<(), AgentError> {
|
||||
self.ensure_chat_loaded(chat_id)?;
|
||||
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
let prompt_injector = self.prompt_injector.clone();
|
||||
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
|
||||
self.append_persisted_message(chat_id, message)
|
||||
})
|
||||
self.history
|
||||
.ensure_agent_prompt_before_user_message(chat_id)
|
||||
}
|
||||
|
||||
/// 获取或创建指定 chat_id 的会话历史
|
||||
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||||
self.chat_histories
|
||||
.entry(chat_id.to_string())
|
||||
.or_insert_with(Vec::new)
|
||||
self.history.get_or_create_history(chat_id)
|
||||
}
|
||||
|
||||
/// 获取指定 chat_id 的会话历史(不创建)
|
||||
pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
|
||||
self.chat_histories.get(chat_id)
|
||||
self.history.get_history(chat_id)
|
||||
}
|
||||
|
||||
/// 使用完整消息追加到历史
|
||||
pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||
let history = self.get_or_create_history(chat_id);
|
||||
history.push(message);
|
||||
self.history.add_message(chat_id, message);
|
||||
}
|
||||
|
||||
pub fn remove_history(&mut self, chat_id: &str) {
|
||||
self.chat_histories.remove(chat_id);
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
self.history.remove_history(chat_id);
|
||||
}
|
||||
|
||||
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||
}
|
||||
|
||||
self.store
|
||||
.clear_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||||
self.history.clear_chat_history(chat_id)
|
||||
}
|
||||
|
||||
pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
|
||||
}
|
||||
|
||||
self.store
|
||||
.reset_session(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
|
||||
self.history.reset_chat_context(chat_id)
|
||||
}
|
||||
|
||||
/// 将消息写入内存与持久化层
|
||||
@ -245,14 +192,7 @@ impl Session {
|
||||
chat_id: &str,
|
||||
message: ChatMessage,
|
||||
) -> Result<(), AgentError> {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
self.store
|
||||
.append_message(&session_id, &message)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("append message persistence error: {}", err))
|
||||
})?;
|
||||
self.add_message(chat_id, message);
|
||||
Ok(())
|
||||
self.history.append_persisted_message(chat_id, message)
|
||||
}
|
||||
|
||||
pub fn append_persisted_messages<I>(
|
||||
@ -263,10 +203,7 @@ impl Session {
|
||||
where
|
||||
I: IntoIterator<Item = ChatMessage>,
|
||||
{
|
||||
for message in messages {
|
||||
self.append_persisted_message(chat_id, message)?;
|
||||
}
|
||||
Ok(())
|
||||
self.history.append_persisted_messages(chat_id, messages)
|
||||
}
|
||||
|
||||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||||
@ -283,9 +220,9 @@ impl Session {
|
||||
.map(|message| message.id.as_str())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
|
||||
self.get_history(chat_id)
|
||||
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
||||
self.history.latest_user_message(chat_id)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -296,55 +233,19 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||||
self.latest_user_message(chat_id)
|
||||
.map(|current| {
|
||||
current.id == message.id
|
||||
|| (current.content == message.content
|
||||
&& current.timestamp == message.timestamp
|
||||
&& current.media_refs == message.media_refs)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
self.history.matches_current_user_turn(chat_id, message)
|
||||
}
|
||||
|
||||
pub(crate) fn stale_result_diagnostics(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
) -> (Option<&str>, Option<String>, bool, usize) {
|
||||
let latest_user = self.latest_user_message(chat_id);
|
||||
let latest_user_id = latest_user.map(|message| message.id.as_str());
|
||||
let latest_user_preview = latest_user.map(|message| preview_text(&message.content, 80));
|
||||
let compression_in_flight = self.compression_in_flight.contains(chat_id);
|
||||
let history_len = self
|
||||
.get_history(chat_id)
|
||||
.map(|history| history.len())
|
||||
.unwrap_or(0);
|
||||
|
||||
(
|
||||
latest_user_id,
|
||||
latest_user_preview,
|
||||
compression_in_flight,
|
||||
history_len,
|
||||
)
|
||||
self.history.stale_result_diagnostics(chat_id)
|
||||
}
|
||||
|
||||
/// 清除所有历史
|
||||
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||||
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||
self.chat_histories.clear();
|
||||
self.compression_in_flight.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||
|
||||
for chat_id in chat_ids {
|
||||
self.store
|
||||
.clear_messages(&self.persistent_session_id(&chat_id))
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("clear history persistence error: {}", err))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
self.history.clear_all_history()
|
||||
}
|
||||
|
||||
pub async fn send(&self, msg: WsOutbound) {
|
||||
@ -362,24 +263,19 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||||
self.compression_in_flight.insert(chat_id.to_string())
|
||||
self.history.try_start_background_compaction(chat_id)
|
||||
}
|
||||
|
||||
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
self.history.finish_background_compaction(chat_id);
|
||||
}
|
||||
|
||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history = self
|
||||
.store
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
Ok(())
|
||||
self.history.reload_chat_history(chat_id)
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self) -> Arc<SessionStore> {
|
||||
self.store.clone()
|
||||
self.history.store()
|
||||
}
|
||||
|
||||
pub fn record_skill_offer(&self, chat_id: &str) -> Result<(), AgentError> {
|
||||
@ -387,14 +283,12 @@ impl Session {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.store
|
||||
.append_skill_event(
|
||||
Some(&self.persistent_session_id(chat_id)),
|
||||
self.history.append_skill_event(
|
||||
chat_id,
|
||||
"offered",
|
||||
None,
|
||||
&self.skills.offered_event_payload(),
|
||||
)
|
||||
.map_err(|err| AgentError::Other(format!("append skill event error: {}", err)))
|
||||
}
|
||||
|
||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||
@ -427,22 +321,6 @@ impl Session {
|
||||
provider_config,
|
||||
})
|
||||
}
|
||||
|
||||
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history_is_empty = self
|
||||
.get_history(chat_id)
|
||||
.map(|history| history.is_empty())
|
||||
.unwrap_or(true);
|
||||
|
||||
if !history_is_empty {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let prompt_injector = self.prompt_injector.clone();
|
||||
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
|
||||
self.append_persisted_message(chat_id, message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||
@ -1049,7 +927,7 @@ mod tests {
|
||||
let session = session_manager.get("feishu").await.unwrap();
|
||||
let session_guard = session.lock().await;
|
||||
let persisted_messages = session_guard
|
||||
.store
|
||||
.store()
|
||||
.load_messages(&session_guard.persistent_session_id("chat-guard"))
|
||||
.unwrap();
|
||||
|
||||
|
||||
262
src/gateway/session_history.rs
Normal file
262
src/gateway/session_history.rs
Normal file
@ -0,0 +1,262 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
|
||||
use super::prompt_injector::PromptInjector;
|
||||
|
||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||
if content.chars().count() > max_chars {
|
||||
preview.push_str("...");
|
||||
}
|
||||
preview.replace('\n', "\\n")
|
||||
}
|
||||
|
||||
pub(crate) struct SessionHistory {
|
||||
channel_name: String,
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
compression_in_flight: HashSet<String>,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
}
|
||||
|
||||
impl SessionHistory {
|
||||
pub(crate) fn new(
|
||||
channel_name: impl Into<String>,
|
||||
prompt_injector: PromptInjector,
|
||||
store: Arc<SessionStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel_name: channel_name.into(),
|
||||
chat_histories: HashMap::new(),
|
||||
compression_in_flight: HashSet::new(),
|
||||
prompt_injector,
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn persistent_session_id(&self, chat_id: &str) -> String {
|
||||
persistent_session_id(&self.channel_name, chat_id)
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_persistent_session(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
) -> Result<SessionRecord, AgentError> {
|
||||
self.store
|
||||
.ensure_channel_session(&self.channel_name, chat_id)
|
||||
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if self.chat_histories.contains_key(chat_id) {
|
||||
return self.ensure_initial_agent_prompt(chat_id);
|
||||
}
|
||||
|
||||
let history = self
|
||||
.store
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
self.ensure_initial_agent_prompt(chat_id)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_agent_prompt_before_user_message(
|
||||
&mut self,
|
||||
chat_id: &str,
|
||||
) -> Result<(), AgentError> {
|
||||
self.ensure_chat_loaded(chat_id)?;
|
||||
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
let prompt_injector = self.prompt_injector.clone();
|
||||
prompt_injector.ensure_reinjected_prompt(&session_id, |message| {
|
||||
self.append_persisted_message(chat_id, message)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||||
self.chat_histories.entry(chat_id.to_string()).or_default()
|
||||
}
|
||||
|
||||
pub(crate) fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
|
||||
self.chat_histories.get(chat_id)
|
||||
}
|
||||
|
||||
pub(crate) fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||
self.get_or_create_history(chat_id).push(message);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_history(&mut self, chat_id: &str) {
|
||||
self.chat_histories.remove(chat_id);
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
}
|
||||
|
||||
pub(crate) fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||
}
|
||||
|
||||
self.store
|
||||
.clear_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||||
}
|
||||
|
||||
pub(crate) fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
|
||||
}
|
||||
|
||||
self.store
|
||||
.reset_session(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
|
||||
}
|
||||
|
||||
pub(crate) fn append_persisted_message(
|
||||
&mut self,
|
||||
chat_id: &str,
|
||||
message: ChatMessage,
|
||||
) -> Result<(), AgentError> {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
self.store
|
||||
.append_message(&session_id, &message)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("append message persistence error: {}", err))
|
||||
})?;
|
||||
self.add_message(chat_id, message);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn append_persisted_messages<I>(
|
||||
&mut self,
|
||||
chat_id: &str,
|
||||
messages: I,
|
||||
) -> Result<(), AgentError>
|
||||
where
|
||||
I: IntoIterator<Item = ChatMessage>,
|
||||
{
|
||||
for message in messages {
|
||||
self.append_persisted_message(chat_id, message)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
|
||||
self.get_history(chat_id)
|
||||
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
||||
}
|
||||
|
||||
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||||
self.latest_user_message(chat_id)
|
||||
.map(|current| {
|
||||
current.id == message.id
|
||||
|| (current.content == message.content
|
||||
&& current.timestamp == message.timestamp
|
||||
&& current.media_refs == message.media_refs)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub(crate) fn stale_result_diagnostics(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
) -> (Option<&str>, Option<String>, bool, usize) {
|
||||
let latest_user = self.latest_user_message(chat_id);
|
||||
let latest_user_id = latest_user.map(|message| message.id.as_str());
|
||||
let latest_user_preview = latest_user.map(|message| preview_text(&message.content, 80));
|
||||
let compression_in_flight = self.compression_in_flight.contains(chat_id);
|
||||
let history_len = self
|
||||
.get_history(chat_id)
|
||||
.map(|history| history.len())
|
||||
.unwrap_or(0);
|
||||
|
||||
(
|
||||
latest_user_id,
|
||||
latest_user_preview,
|
||||
compression_in_flight,
|
||||
history_len,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||||
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||
self.chat_histories.clear();
|
||||
self.compression_in_flight.clear();
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||
|
||||
for chat_id in chat_ids {
|
||||
self.store
|
||||
.clear_messages(&self.persistent_session_id(&chat_id))
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("clear history persistence error: {}", err))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn try_start_background_compaction(&mut self, chat_id: &str) -> bool {
|
||||
self.compression_in_flight.insert(chat_id.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn finish_background_compaction(&mut self, chat_id: &str) {
|
||||
self.compression_in_flight.remove(chat_id);
|
||||
}
|
||||
|
||||
pub(crate) fn reload_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history = self
|
||||
.store
|
||||
.load_messages(&self.persistent_session_id(chat_id))
|
||||
.map_err(|err| AgentError::Other(format!("session history reload error: {}", err)))?;
|
||||
self.chat_histories.insert(chat_id.to_string(), history);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn store(&self) -> Arc<SessionStore> {
|
||||
self.store.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn append_skill_event(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
event_type: &str,
|
||||
skill_name: Option<&str>,
|
||||
payload: &serde_json::Value,
|
||||
) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.append_skill_event(
|
||||
Some(&self.persistent_session_id(chat_id)),
|
||||
event_type,
|
||||
skill_name,
|
||||
payload,
|
||||
)
|
||||
.map_err(|err| AgentError::Other(format!("append skill event error: {}", err)))
|
||||
}
|
||||
|
||||
fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||||
let history_is_empty = self
|
||||
.get_history(chat_id)
|
||||
.map(|history| history.is_empty())
|
||||
.unwrap_or(true);
|
||||
|
||||
if !history_is_empty {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let prompt_injector = self.prompt_injector.clone();
|
||||
prompt_injector.ensure_initial_prompt(history_is_empty, |message| {
|
||||
self.append_persisted_message(chat_id, message)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,22 +1,22 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::{SkillEvent, SkillEventSink};
|
||||
use crate::storage::SessionStore;
|
||||
use crate::storage::SkillEventRepository;
|
||||
|
||||
pub(crate) struct PersistentSkillEventSink {
|
||||
store: Arc<SessionStore>,
|
||||
events: Arc<dyn SkillEventRepository>,
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
impl PersistentSkillEventSink {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, session_id: String) -> Self {
|
||||
Self { store, session_id }
|
||||
pub(crate) fn new(events: Arc<dyn SkillEventRepository>, session_id: String) -> Self {
|
||||
Self { events, session_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl SkillEventSink for PersistentSkillEventSink {
|
||||
fn record_skill_event(&self, event: SkillEvent) {
|
||||
if let Err(err) = self.store.append_skill_event(
|
||||
if let Err(err) = self.events.append_skill_event(
|
||||
Some(&self.session_id),
|
||||
&event.event_type,
|
||||
event.skill_name.as_deref(),
|
||||
|
||||
@ -230,6 +230,24 @@ impl SkillRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::agent::SkillProvider for SkillRuntime {
|
||||
fn system_index_prompt(&self) -> Option<String> {
|
||||
SkillRuntime::system_index_prompt(self)
|
||||
}
|
||||
|
||||
fn skill_tool_definition(&self) -> Option<Tool> {
|
||||
SkillRuntime::skill_tool_definition(self)
|
||||
}
|
||||
|
||||
fn activation_payload(&self, name: &str) -> Result<String, String> {
|
||||
SkillRuntime::activation_payload(self, name)
|
||||
}
|
||||
|
||||
fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String> {
|
||||
SkillRuntime::activation_event_payload(self, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl SkillSource {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
|
||||
@ -7,9 +7,11 @@ use rusqlite::{Connection, OptionalExtension, params};
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
pub mod error;
|
||||
pub mod ports;
|
||||
pub mod records;
|
||||
|
||||
pub use error::StorageError;
|
||||
pub use ports::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||
pub use records::{
|
||||
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
|
||||
SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||||
|
||||
170
src/storage/ports.rs
Normal file
170
src/storage/ports.rs
Normal file
@ -0,0 +1,170 @@
|
||||
use super::{
|
||||
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobUpsert, SkillEventRecord,
|
||||
StorageError,
|
||||
};
|
||||
|
||||
pub trait MemoryRepository: Send + Sync + 'static {
|
||||
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError>;
|
||||
|
||||
fn update_memory(&self, input: &MemoryUpsert) -> Result<Option<MemoryRecord>, StorageError>;
|
||||
|
||||
fn delete_memory(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: &str,
|
||||
memory_key: &str,
|
||||
) -> Result<bool, StorageError>;
|
||||
|
||||
fn get_memory(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: &str,
|
||||
memory_key: &str,
|
||||
) -> Result<Option<MemoryRecord>, StorageError>;
|
||||
|
||||
fn list_memories(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryRecord>, StorageError>;
|
||||
|
||||
fn search_memories_any(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
queries: &[String],
|
||||
namespace: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryRecord>, StorageError>;
|
||||
}
|
||||
|
||||
pub trait SchedulerJobRepository: Send + Sync + 'static {
|
||||
fn upsert_scheduler_job(
|
||||
&self,
|
||||
input: &SchedulerJobUpsert,
|
||||
) -> Result<SchedulerJobRecord, StorageError>;
|
||||
|
||||
fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError>;
|
||||
|
||||
fn list_scheduler_jobs(
|
||||
&self,
|
||||
enabled_only: bool,
|
||||
) -> Result<Vec<SchedulerJobRecord>, StorageError>;
|
||||
|
||||
fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError>;
|
||||
}
|
||||
|
||||
pub trait SkillEventRepository: Send + Sync + 'static {
|
||||
fn append_skill_event(
|
||||
&self,
|
||||
session_id: Option<&str>,
|
||||
event_type: &str,
|
||||
skill_name: Option<&str>,
|
||||
payload: &serde_json::Value,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
fn list_skill_events(
|
||||
&self,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<SkillEventRecord>, StorageError>;
|
||||
}
|
||||
|
||||
impl MemoryRepository for super::SessionStore {
|
||||
fn put_memory(&self, input: &MemoryUpsert) -> Result<MemoryRecord, StorageError> {
|
||||
super::SessionStore::put_memory(self, input)
|
||||
}
|
||||
|
||||
fn update_memory(&self, input: &MemoryUpsert) -> Result<Option<MemoryRecord>, StorageError> {
|
||||
super::SessionStore::update_memory(self, input)
|
||||
}
|
||||
|
||||
fn delete_memory(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: &str,
|
||||
memory_key: &str,
|
||||
) -> Result<bool, StorageError> {
|
||||
super::SessionStore::delete_memory(self, scope_kind, scope_key, namespace, memory_key)
|
||||
}
|
||||
|
||||
fn get_memory(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: &str,
|
||||
memory_key: &str,
|
||||
) -> Result<Option<MemoryRecord>, StorageError> {
|
||||
super::SessionStore::get_memory(self, scope_kind, scope_key, namespace, memory_key)
|
||||
}
|
||||
|
||||
fn list_memories(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
namespace: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||||
super::SessionStore::list_memories(self, scope_kind, scope_key, namespace, limit)
|
||||
}
|
||||
|
||||
fn search_memories_any(
|
||||
&self,
|
||||
scope_kind: &str,
|
||||
scope_key: &str,
|
||||
queries: &[String],
|
||||
namespace: Option<&str>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<MemoryRecord>, StorageError> {
|
||||
super::SessionStore::search_memories_any(
|
||||
self, scope_kind, scope_key, queries, namespace, limit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl SchedulerJobRepository for super::SessionStore {
|
||||
fn upsert_scheduler_job(
|
||||
&self,
|
||||
input: &SchedulerJobUpsert,
|
||||
) -> Result<SchedulerJobRecord, StorageError> {
|
||||
super::SessionStore::upsert_scheduler_job(self, input)
|
||||
}
|
||||
|
||||
fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError> {
|
||||
super::SessionStore::get_scheduler_job(self, job_id)
|
||||
}
|
||||
|
||||
fn list_scheduler_jobs(
|
||||
&self,
|
||||
enabled_only: bool,
|
||||
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
||||
super::SessionStore::list_scheduler_jobs(self, enabled_only)
|
||||
}
|
||||
|
||||
fn delete_scheduler_job(&self, job_id: &str) -> Result<(), StorageError> {
|
||||
super::SessionStore::delete_scheduler_job(self, job_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl SkillEventRepository for super::SessionStore {
|
||||
fn append_skill_event(
|
||||
&self,
|
||||
session_id: Option<&str>,
|
||||
event_type: &str,
|
||||
skill_name: Option<&str>,
|
||||
payload: &serde_json::Value,
|
||||
) -> Result<(), StorageError> {
|
||||
super::SessionStore::append_skill_event(self, session_id, event_type, skill_name, payload)
|
||||
}
|
||||
|
||||
fn list_skill_events(
|
||||
&self,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<SkillEventRecord>, StorageError> {
|
||||
super::SessionStore::list_skill_events(self, session_id)
|
||||
}
|
||||
}
|
||||
@ -3,16 +3,16 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::storage::{MemoryRecord, MemoryUpsert, SessionStore};
|
||||
use crate::storage::{MemoryRecord, MemoryRepository, MemoryUpsert};
|
||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||||
|
||||
pub struct MemoryManageTool {
|
||||
store: Arc<SessionStore>,
|
||||
memories: Arc<dyn MemoryRepository>,
|
||||
}
|
||||
|
||||
impl MemoryManageTool {
|
||||
pub fn new(store: Arc<SessionStore>) -> Self {
|
||||
Self { store }
|
||||
pub fn new(memories: Arc<dyn MemoryRepository>) -> Self {
|
||||
Self { memories }
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Create, update, or delete long-term user memories stored in SQLite. Supports actions: put, update, delete. Use memory_search as the default retrieval path before answering most requests, and use memory_search for all retrieval actions including search, get, and list. Only call this tool when you have determined that a high-value long-term memory should be created, overwritten, updated, or deleted. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
||||
"Create, update, or delete long-term user memories in the configured memory repository. Supports actions: put, update, delete. Use memory_search as the default retrieval path before answering most requests, and use memory_search for all retrieval actions including search, get, and list. Only call this tool when you have determined that a high-value long-term memory should be created, overwritten, updated, or deleted. Memories are scoped to the current channel and sender, and record the originating session/message when available."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -80,7 +80,7 @@ impl Tool for MemoryManageTool {
|
||||
Ok(input) => input,
|
||||
Err(result) => return Ok(result),
|
||||
};
|
||||
memory_to_json(self.store.put_memory(&input)?)
|
||||
memory_to_json(self.memories.put_memory(&input)?)
|
||||
}
|
||||
"update" => {
|
||||
let input = match build_memory_upsert(context, &scope_key, &args, false) {
|
||||
@ -88,7 +88,7 @@ impl Tool for MemoryManageTool {
|
||||
Err(result) => return Ok(result),
|
||||
};
|
||||
|
||||
match self.store.update_memory(&input)? {
|
||||
match self.memories.update_memory(&input)? {
|
||||
Some(memory) => memory_to_json(memory),
|
||||
None => {
|
||||
return Ok(error_result(&format!(
|
||||
@ -109,7 +109,7 @@ impl Tool for MemoryManageTool {
|
||||
};
|
||||
|
||||
let deleted = self
|
||||
.store
|
||||
.memories
|
||||
.delete_memory("user", &scope_key, namespace, key)?;
|
||||
if !deleted {
|
||||
return Ok(error_result(&format!(
|
||||
@ -219,6 +219,7 @@ fn error_result(message: &str) -> ToolResult {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::SessionStore;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_manage_put_returns_saved_memory() {
|
||||
|
||||
@ -3,16 +3,16 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::storage::{MemoryRecord, SessionStore};
|
||||
use crate::storage::{MemoryRecord, MemoryRepository};
|
||||
use crate::tools::traits::{Tool, ToolContext, ToolResult};
|
||||
|
||||
pub struct MemorySearchTool {
|
||||
store: Arc<SessionStore>,
|
||||
memories: Arc<dyn MemoryRepository>,
|
||||
}
|
||||
|
||||
impl MemorySearchTool {
|
||||
pub fn new(store: Arc<SessionStore>) -> Self {
|
||||
Self { store }
|
||||
pub fn new(memories: Arc<dyn MemoryRepository>) -> Self {
|
||||
Self { memories }
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ impl Tool for MemorySearchTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search and read long-term user memories stored in SQLite. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory."
|
||||
"Search and read long-term user memories from the configured memory repository. This is the default entry point for memory retrieval and should usually be the first memory tool you call at the start of a request, unless the request is clearly a simple greeting, a one-off calculation, or a direct fact question that does not depend on user history. Use it to recall prior preferences, stable facts, historical decisions, and ongoing task context. If the request also needs other independent read-only tools, you may call memory_search in the same round alongside them. This tool is read-only and supports three actions: search for multi-keyword recall, get for exact namespace/key lookup, and list for browsing recent memories. Prefer this tool over memory_manage whenever you only need to retrieve memory."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -91,7 +91,7 @@ impl Tool for MemorySearchTool {
|
||||
.and_then(|value| value.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
let memories = self
|
||||
.store
|
||||
.memories
|
||||
.list_memories("user", &scope_key, namespace, limit)?;
|
||||
json!({
|
||||
"count": memories.len(),
|
||||
@ -117,7 +117,7 @@ impl Tool for MemorySearchTool {
|
||||
.and_then(|value| value.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
let memories = self
|
||||
.store
|
||||
.memories
|
||||
.search_memories_any("user", &scope_key, &queries, namespace, limit)?;
|
||||
json!({
|
||||
"queries": queries,
|
||||
@ -135,7 +135,10 @@ impl Tool for MemorySearchTool {
|
||||
None => return Ok(error_result("Missing required parameter: key")),
|
||||
};
|
||||
|
||||
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
||||
match self
|
||||
.memories
|
||||
.get_memory("user", &scope_key, namespace, key)?
|
||||
{
|
||||
Some(memory) => memory_to_json(memory),
|
||||
None => {
|
||||
return Ok(error_result(&format!(
|
||||
@ -202,6 +205,7 @@ fn error_result(message: &str) -> ToolResult {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::SessionStore;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_search_search_and_get() {
|
||||
|
||||
@ -5,18 +5,20 @@ use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::config::SchedulerSchedule;
|
||||
use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore};
|
||||
use crate::storage::{
|
||||
SchedulerJobRecord, SchedulerJobRepository, SchedulerJobState, SchedulerJobUpsert,
|
||||
};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct SchedulerManageTool {
|
||||
store: Arc<SessionStore>,
|
||||
jobs: Arc<dyn SchedulerJobRepository>,
|
||||
known_agents: Arc<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl SchedulerManageTool {
|
||||
pub fn new(store: Arc<SessionStore>, known_agents: HashSet<String>) -> Self {
|
||||
pub fn new(jobs: Arc<dyn SchedulerJobRepository>, known_agents: HashSet<String>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
jobs,
|
||||
known_agents: Arc::new(known_agents),
|
||||
}
|
||||
}
|
||||
@ -29,7 +31,7 @@ impl Tool for SchedulerManageTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Manage DB-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs persist in SQLite and are executed by the scheduler runtime. When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing."
|
||||
"Manage repository-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -116,30 +118,30 @@ impl Tool for SchedulerManageTool {
|
||||
.get("enabled_only")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(false);
|
||||
let jobs = self.store.list_scheduler_jobs(enabled_only)?;
|
||||
let jobs = self.jobs.list_scheduler_jobs(enabled_only)?;
|
||||
json!(jobs.iter().map(record_to_json).collect::<Vec<_>>())
|
||||
}
|
||||
"get" => {
|
||||
let id = require_str(&args, "id")?;
|
||||
match self.store.get_scheduler_job(id)? {
|
||||
match self.jobs.get_scheduler_job(id)? {
|
||||
Some(record) => record_to_json(&record),
|
||||
None => return Ok(error_result(&format!("scheduler job '{}' not found", id))),
|
||||
}
|
||||
}
|
||||
"put" => {
|
||||
let input = build_upsert(context, &args, &self.known_agents)?;
|
||||
let record = self.store.upsert_scheduler_job(&input)?;
|
||||
let record = self.jobs.upsert_scheduler_job(&input)?;
|
||||
record_to_json(&record)
|
||||
}
|
||||
"delete" => {
|
||||
let id = require_str(&args, "id")?;
|
||||
self.store.delete_scheduler_job(id)?;
|
||||
self.jobs.delete_scheduler_job(id)?;
|
||||
json!({"status": "deleted", "id": id})
|
||||
}
|
||||
"pause" => {
|
||||
let id = require_str(&args, "id")?;
|
||||
let record = self
|
||||
.store
|
||||
.jobs
|
||||
.get_scheduler_job(id)?
|
||||
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
|
||||
let mut input = record_to_upsert(&record);
|
||||
@ -147,13 +149,13 @@ impl Tool for SchedulerManageTool {
|
||||
input.state = SchedulerJobState::Paused;
|
||||
input.paused_at = Some(current_timestamp());
|
||||
input.next_fire_at = None;
|
||||
let saved = self.store.upsert_scheduler_job(&input)?;
|
||||
let saved = self.jobs.upsert_scheduler_job(&input)?;
|
||||
record_to_json(&saved)
|
||||
}
|
||||
"resume" => {
|
||||
let id = require_str(&args, "id")?;
|
||||
let record = self
|
||||
.store
|
||||
.jobs
|
||||
.get_scheduler_job(id)?
|
||||
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
|
||||
let mut input = record_to_upsert(&record);
|
||||
@ -162,7 +164,7 @@ impl Tool for SchedulerManageTool {
|
||||
input.paused_at = None;
|
||||
input.completed_at = None;
|
||||
input.next_fire_at = None;
|
||||
let saved = self.store.upsert_scheduler_job(&input)?;
|
||||
let saved = self.jobs.upsert_scheduler_job(&input)?;
|
||||
record_to_json(&saved)
|
||||
}
|
||||
_ => return Ok(error_result("Unsupported action")),
|
||||
@ -431,6 +433,7 @@ fn current_timestamp() -> i64 {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::SessionStore;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scheduler_manage_put_and_get() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user