diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 3b31ac1..3af02e4 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,11 +1,11 @@ use crate::bus::ChatMessage; -use crate::bus::message::ContentBlock; use crate::bus::message::ToolMessageState; use crate::config::LLMProviderConfig; +use crate::domain::messages::{ContentBlock, ToolCall}; use crate::observability::{ Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args, }; -use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider}; use crate::skills::SkillRuntime; use crate::storage::SessionStore; use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; diff --git a/src/bus/message.rs b/src/bus/message.rs index 75c03d8..cdf7ac3 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::providers::ToolCall; +use crate::domain::messages::ToolCall; pub const SYSTEM_CONTEXT_AGENT_PROMPT: &str = "agent_prompt"; pub const SYSTEM_CONTEXT_SCHEDULED_PROMPT: &str = "scheduled_system_prompt"; @@ -14,38 +14,6 @@ pub enum ToolMessageState { PendingUserAction, } -// ============================================================================ -// ContentBlock - Multimodal content representation (OpenAI-style) -// ============================================================================ - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ContentBlock { - #[serde(rename = "text")] - Text { text: String }, - #[serde(rename = "image_url")] - ImageUrl { image_url: ImageUrlBlock }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ImageUrlBlock { - pub url: String, -} - -impl ContentBlock { - pub fn text(content: impl Into) -> Self { - Self::Text { - text: content.into(), - } - } - - pub fn image_url(url: impl Into) -> Self { - Self::ImageUrl { - image_url: ImageUrlBlock { url: url.into() }, - } - } -} - // ============================================================================ // MediaItem - Media metadata for messages // ============================================================================ @@ -566,7 +534,7 @@ fn current_timestamp() -> i64 { #[cfg(test)] mod tests { use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState}; - use crate::providers::ToolCall; + use crate::domain::messages::ToolCall; use serde_json::json; use std::collections::HashMap; diff --git a/src/bus/mod.rs b/src/bus/mod.rs index 77e642c..c0325b0 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -1,11 +1,11 @@ pub mod dispatcher; pub mod message; +pub use crate::domain::messages::ContentBlock; pub use dispatcher::OutboundDispatcher; pub use message::{ - ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage, - SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION, - SYSTEM_CONTEXT_SCHEDULED_PROMPT, + ChatMessage, InboundMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT, + SYSTEM_CONTEXT_HISTORY_COMPACTION, SYSTEM_CONTEXT_SCHEDULED_PROMPT, }; use std::sync::Arc; diff --git a/src/domain/messages.rs b/src/domain/messages.rs new file mode 100644 index 0000000..174880b --- /dev/null +++ b/src/domain/messages.rs @@ -0,0 +1,36 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrlBlock }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageUrlBlock { + pub url: String, +} + +impl ContentBlock { + pub fn text(content: impl Into) -> Self { + Self::Text { + text: content.into(), + } + } + + pub fn image_url(url: impl Into) -> Self { + Self::ImageUrl { + image_url: ImageUrlBlock { url: url.into() }, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} diff --git a/src/domain/mod.rs b/src/domain/mod.rs new file mode 100644 index 0000000..ba63992 --- /dev/null +++ b/src/domain/mod.rs @@ -0,0 +1 @@ +pub mod messages; diff --git a/src/gateway/agent_task_executor.rs b/src/gateway/agent_task_executor.rs index a77376b..1b2e7ce 100644 --- a/src/gateway/agent_task_executor.rs +++ b/src/gateway/agent_task_executor.rs @@ -1,8 +1,13 @@ use crate::agent::AgentError; use crate::bus::OutboundMessage; +use crate::scheduler::{ + AgentTaskExecutor as SchedulerAgentTaskExecutor, MaintenanceExecutor, MaintenanceRunSummary, + ScheduledAgentTaskOptions, +}; +use async_trait::async_trait; use super::memory_maintenance::MemoryMaintenanceScopeResult; -use super::session::{ScheduledAgentTaskOptions, SessionManager}; +use super::session::SessionManager; #[derive(Clone)] pub struct AgentTaskExecutor { @@ -14,7 +19,7 @@ impl AgentTaskExecutor { Self { session_manager } } - pub(crate) async fn execute( + async fn execute_agent_task( &self, channel_name: &str, chat_id: &str, @@ -27,6 +32,21 @@ impl AgentTaskExecutor { } } +#[async_trait] +impl SchedulerAgentTaskExecutor for AgentTaskExecutor { + async fn execute( + &self, + channel_name: &str, + chat_id: &str, + prompt: &str, + options: ScheduledAgentTaskOptions, + ) -> anyhow::Result> { + self.execute_agent_task(channel_name, chat_id, prompt, options) + .await + .map_err(|error| anyhow::anyhow!(error.to_string())) + } +} + #[derive(Clone)] pub struct SchedulerMaintenanceService { session_manager: SessionManager, @@ -37,11 +57,11 @@ impl SchedulerMaintenanceService { Self { session_manager } } - pub(crate) async fn cleanup_expired_sessions(&self) -> usize { + async fn cleanup_sessions(&self) -> usize { self.session_manager.cleanup_expired_sessions().await } - pub(crate) async fn run_memory_maintenance_for_all_scopes( + async fn run_memory_maintenance( &self, updated_since: Option, ) -> Result, AgentError> { @@ -50,3 +70,33 @@ impl SchedulerMaintenanceService { .await } } + +#[async_trait] +impl MaintenanceExecutor for SchedulerMaintenanceService { + async fn cleanup_expired_sessions(&self) -> usize { + self.cleanup_sessions().await + } + + async fn run_memory_maintenance_for_all_scopes( + &self, + updated_since: Option, + ) -> anyhow::Result> { + self.run_memory_maintenance(updated_since) + .await + .map(|results| { + results + .into_iter() + .map(|result| MaintenanceRunSummary { + scope_key: result.scope_key, + user_facts: result.output.user_facts.len(), + preferences: result.output.preferences.len(), + behavior_patterns: result.output.behavior_patterns.len(), + merges: result.output.merges.len(), + conflicts: result.output.conflicts.len(), + low_value: result.output.low_value_ids.len(), + }) + .collect() + }) + .map_err(|error| anyhow::anyhow!(error.to_string())) + } +} diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 612119a..6e1435d 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -4,6 +4,7 @@ use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT; use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::config::LLMProviderConfig; use crate::protocol::WsOutbound; +use crate::scheduler::ScheduledAgentTaskOptions; use crate::skills::SkillRuntime; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::tools::ToolRegistry; @@ -66,15 +67,6 @@ pub struct BusToolCallEmitter { show_tool_results: bool, } -#[derive(Debug, Clone, Default)] -pub struct ScheduledAgentTaskOptions { - pub sender_id: Option, - pub fresh_session: bool, - pub system_prompt: Option, - pub metadata: HashMap, - pub agent: Option, -} - impl BusToolCallEmitter { pub fn new( bus: Arc, diff --git a/src/gateway/ws_adapter.rs b/src/gateway/ws_adapter.rs index c013a7b..c59a5fb 100644 --- a/src/gateway/ws_adapter.rs +++ b/src/gateway/ws_adapter.rs @@ -118,7 +118,7 @@ pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Ve #[cfg(test)] mod tests { use super::*; - use crate::providers::ToolCall; + use crate::domain::messages::ToolCall; use serde_json::json; #[test] diff --git a/src/lib.rs b/src/lib.rs index c99e424..9e5e8c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod channels; pub mod cli; pub mod client; pub mod config; +pub mod domain; pub mod gateway; pub mod logging; pub mod observability; diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index f5563f0..a2632fe 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -6,7 +6,7 @@ use std::time::Duration; use super::traits::Usage; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; -use crate::bus::message::ContentBlock; +use crate::domain::messages::ContentBlock; fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; diff --git a/src/providers/mod.rs b/src/providers/mod.rs index f4f0675..8e1f25a 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -6,9 +6,9 @@ pub use self::anthropic::AnthropicProvider; pub use self::openai::OpenAIProvider; use crate::config::LLMProviderConfig; +pub use crate::domain::messages::ToolCall; pub use traits::{ - ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, - ToolFunction, Usage, + ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolFunction, Usage, }; pub fn create_provider(config: LLMProviderConfig) -> Result, ProviderError> { diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 05a3f96..fa1329a 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -7,7 +7,7 @@ use std::time::Duration; use super::traits::Usage; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; -use crate::bus::message::ContentBlock; +use crate::domain::messages::ContentBlock; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; diff --git a/src/providers/traits.rs b/src/providers/traits.rs index a1ace3c..9fd8ae8 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,4 +1,4 @@ -use crate::bus::message::ContentBlock; +use crate::domain::messages::{ContentBlock, ToolCall}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -91,13 +91,6 @@ pub struct ToolFunction { pub parameters: serde_json::Value, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCall { - pub id: String, - pub name: String, - pub arguments: serde_json::Value, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionRequest { pub messages: Vec, diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index 2144bba..cf7d4ca 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; +use async_trait::async_trait; use chrono::{DateTime, Duration as ChronoDuration, TimeZone, Utc}; use chrono_tz::Tz; use tokio::sync::watch; @@ -11,37 +12,80 @@ use crate::config::{ SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerMisfirePolicy, SchedulerSchedule, }; -use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; -use crate::gateway::session::ScheduledAgentTaskOptions; use crate::storage::{ SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore, }; +#[derive(Debug, Clone, Default)] +pub struct ScheduledAgentTaskOptions { + pub sender_id: Option, + pub fresh_session: bool, + pub system_prompt: Option, + pub metadata: HashMap, + pub agent: Option, +} + +#[derive(Debug, Clone)] +pub struct MaintenanceRunSummary { + pub scope_key: String, + pub user_facts: usize, + pub preferences: usize, + pub behavior_patterns: usize, + pub merges: usize, + pub conflicts: usize, + pub low_value: usize, +} + +#[async_trait] +pub trait AgentTaskExecutor: Send + Sync { + async fn execute( + &self, + channel_name: &str, + chat_id: &str, + prompt: &str, + options: ScheduledAgentTaskOptions, + ) -> anyhow::Result>; +} + +#[async_trait] +pub trait MaintenanceExecutor: Send + Sync { + async fn cleanup_expired_sessions(&self) -> usize; + + async fn run_memory_maintenance_for_all_scopes( + &self, + updated_since: Option, + ) -> anyhow::Result>; +} + pub struct Scheduler { bus: Arc, config: SchedulerConfig, timezone: Tz, store: Arc, - agent_task_executor: AgentTaskExecutor, - maintenance_service: SchedulerMaintenanceService, + agent_task_executor: Arc, + maintenance_executor: Arc, } impl Scheduler { - pub fn new( + pub fn new( bus: Arc, config: SchedulerConfig, timezone: Tz, store: Arc, - agent_task_executor: AgentTaskExecutor, - maintenance_service: SchedulerMaintenanceService, - ) -> Self { + agent_task_executor: A, + maintenance_executor: M, + ) -> Self + where + A: AgentTaskExecutor + 'static, + M: MaintenanceExecutor + 'static, + { Self { bus, config, timezone, store, - agent_task_executor, - maintenance_service, + agent_task_executor: Arc::new(agent_task_executor), + maintenance_executor: Arc::new(maintenance_executor), } } @@ -171,11 +215,11 @@ impl Scheduler { self.bus.publish_outbound(message).await?; } SchedulerJobKind::InternalEvent => { - execute_internal_event(&self.maintenance_service, job).await?; + execute_internal_event(self.maintenance_executor.as_ref(), job).await?; } SchedulerJobKind::AgentTask => { let outbound_messages = execute_agent_task( - &self.agent_task_executor, + self.agent_task_executor.as_ref(), job, required_notification_chat_id(job, "agent_task")?, ) @@ -187,7 +231,8 @@ impl Scheduler { SchedulerJobKind::SilentAgentTask => { let execution_chat_id = resolve_execution_chat_id(job)?; if let Err(error) = - execute_agent_task(&self.agent_task_executor, job, &execution_chat_id).await + execute_agent_task(self.agent_task_executor.as_ref(), job, &execution_chat_id) + .await { if let Err(notify_error) = self.notify_silent_agent_task_failure(job, &error).await @@ -590,7 +635,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result { } async fn execute_internal_event( - maintenance_service: &SchedulerMaintenanceService, + maintenance_executor: &dyn MaintenanceExecutor, job: &RuntimeJob, ) -> anyhow::Result<()> { let event = job @@ -601,24 +646,24 @@ async fn execute_internal_event( match event { "session_cleanup" => { - let removed = maintenance_service.cleanup_expired_sessions().await; + let removed = maintenance_executor.cleanup_expired_sessions().await; tracing::info!(job_id = %job.id, removed, "Scheduler session cleanup completed"); Ok(()) } "memory_maintenance" => { - let results = maintenance_service + let results = maintenance_executor .run_memory_maintenance_for_all_scopes(job.last_fired_at) .await?; for result in &results { tracing::info!( job_id = %job.id, scope_key = %result.scope_key, - user_facts = result.output.user_facts.len(), - preferences = result.output.preferences.len(), - behavior_patterns = result.output.behavior_patterns.len(), - merges = result.output.merges.len(), - conflicts = result.output.conflicts.len(), - low_value = result.output.low_value_ids.len(), + user_facts = result.user_facts, + preferences = result.preferences, + behavior_patterns = result.behavior_patterns, + merges = result.merges, + conflicts = result.conflicts, + low_value = result.low_value, "Scheduler completed memory maintenance model run" ); } @@ -630,7 +675,7 @@ async fn execute_internal_event( } async fn execute_agent_task( - agent_task_executor: &AgentTaskExecutor, + agent_task_executor: &dyn AgentTaskExecutor, job: &RuntimeJob, execution_chat_id: &str, ) -> anyhow::Result> { @@ -649,7 +694,6 @@ async fn execute_agent_task( agent_task_executor .execute(channel_name, execution_chat_id, prompt, options) .await - .map_err(|error| anyhow::anyhow!(error.to_string())) } fn required_notification_chat_id<'a>( @@ -966,52 +1010,44 @@ impl TryFrom for SchedulerJobTarget { mod tests { use super::*; use crate::bus::MessageBus; - use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig}; - use crate::gateway::agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; - use crate::gateway::session::SessionManager; - use crate::skills::SkillRuntime; + use crate::config::BUILTIN_MEMORY_MAINTENANCE_JOB_ID; use crate::storage::{SchedulerJobUpsert, SessionStore}; - use std::collections::HashMap; - fn test_provider_config() -> LLMProviderConfig { - LLMProviderConfig { - provider_type: "openai".to_string(), - name: "default".to_string(), - base_url: "http://localhost".to_string(), - api_key: "test-key".to_string(), - extra_headers: HashMap::new(), - llm_timeout_secs: 30, - model_id: "test-model".to_string(), - temperature: Some(0.0), - max_tokens: None, - context_window_tokens: None, - model_extra: HashMap::new(), - max_tool_iterations: 4, - tool_result_max_chars: 20_000, - context_tool_result_trim_chars: 20_000, + #[derive(Clone)] + struct TestAgentTaskExecutor; + + #[async_trait::async_trait] + impl AgentTaskExecutor for TestAgentTaskExecutor { + async fn execute( + &self, + _channel_name: &str, + _chat_id: &str, + _prompt: &str, + _options: ScheduledAgentTaskOptions, + ) -> anyhow::Result> { + Ok(Vec::new()) } } - fn test_session_manager() -> SessionManager { - let provider_config = test_provider_config(); - SessionManager::new( - 4, - 100, - false, - "Asia/Shanghai".to_string(), - provider_config.clone(), - HashMap::from([("default".to_string(), provider_config)]), - Arc::new(SkillRuntime::default()), - ) - .unwrap() + #[derive(Clone)] + struct TestMaintenanceExecutor; + + #[async_trait::async_trait] + impl MaintenanceExecutor for TestMaintenanceExecutor { + async fn cleanup_expired_sessions(&self) -> usize { + 0 + } + + async fn run_memory_maintenance_for_all_scopes( + &self, + _updated_since: Option, + ) -> anyhow::Result> { + Ok(Vec::new()) + } } - fn test_scheduler_services() -> (AgentTaskExecutor, SchedulerMaintenanceService) { - let session_manager = test_session_manager(); - ( - AgentTaskExecutor::new(session_manager.clone()), - SchedulerMaintenanceService::new(session_manager), - ) + fn test_scheduler_services() -> (TestAgentTaskExecutor, TestMaintenanceExecutor) { + (TestAgentTaskExecutor, TestMaintenanceExecutor) } #[test] diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 2bdde00..42a0dc7 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1802,7 +1802,7 @@ fn quote_fts_or_query(queries: &[String]) -> String { mod tests { use super::*; use crate::bus::SYSTEM_CONTEXT_AGENT_PROMPT; - use crate::providers::ToolCall; + use crate::domain::messages::ToolCall; #[test] fn test_persistent_session_id_for_cli_and_channel() {