diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index 1988bbc..171c994 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -27,19 +27,6 @@ pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) } } -pub(crate) fn select_provider_config( - default_provider_config: &LLMProviderConfig, - provider_configs: &HashMap, - agent_name: Option<&str>, -) -> Result { - match agent_name.map(str::trim).filter(|value| !value.is_empty()) { - None | Some("default") => Ok(default_provider_config.clone()), - Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| { - AgentError::Other(format!("Scheduled agent '{}' not found", agent_name)) - }), - } -} - pub(crate) struct AgentExecutionService { show_tool_results: bool, } @@ -325,50 +312,6 @@ mod tests { use super::*; use crate::bus::ChatMessage; - fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig { - LLMProviderConfig { - provider_type: "openai".to_string(), - name: name.to_string(), - base_url: "http://localhost".to_string(), - api_key: "test-key".to_string(), - extra_headers: HashMap::new(), - llm_timeout_secs: 120, - model_id: model_id.to_string(), - temperature: Some(0.0), - max_tokens: Some(32), - context_window_tokens: None, - model_extra: HashMap::new(), - max_tool_iterations: 1, - tool_result_max_chars: 20_000, - context_tool_result_trim_chars: 20_000, - } - } - - #[test] - fn test_select_provider_config_uses_named_agent_override() { - let default_provider = test_provider_config_named("default-provider", "default-model"); - let provider_configs = HashMap::from([( - "planner".to_string(), - test_provider_config_named("planner-provider", "planner-model"), - )]); - - let selected = - select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap(); - assert_eq!(selected.name, "planner-provider"); - assert_eq!(selected.model_id, "planner-model"); - } - - #[test] - fn test_select_provider_config_falls_back_to_default() { - let default_provider = test_provider_config_named("default-provider", "default-model"); - let provider_configs = HashMap::new(); - - let selected = - select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap(); - assert_eq!(selected.name, "default-provider"); - assert_eq!(selected.model_id, "default-model"); - } - #[test] fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() { let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 ")); diff --git a/src/gateway/memory_maintenance_coordinator.rs b/src/gateway/memory_maintenance_coordinator.rs new file mode 100644 index 0000000..7e22c5f --- /dev/null +++ b/src/gateway/memory_maintenance_coordinator.rs @@ -0,0 +1,46 @@ +use std::sync::Arc; + +use crate::agent::AgentError; +use crate::storage::SessionStore; + +use super::memory_maintenance::{ + MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService, +}; +use super::provider_config_service::ProviderConfigService; + +#[derive(Clone)] +pub(crate) struct MemoryMaintenanceCoordinator { + store: Arc, + provider_configs: ProviderConfigService, +} + +impl MemoryMaintenanceCoordinator { + pub(crate) fn new(store: Arc, provider_configs: ProviderConfigService) -> Self { + Self { + store, + provider_configs, + } + } + + #[cfg_attr(not(test), allow(dead_code))] + pub(crate) async fn summarize_for_scope( + &self, + scope_key: &str, + ) -> Result, AgentError> { + self.service()?.summarize_for_scope(scope_key).await + } + + pub(crate) async fn run_for_all_scopes( + &self, + updated_since: Option, + ) -> Result, AgentError> { + self.service()?.run_for_all_scopes(updated_since).await + } + + fn service(&self) -> Result { + Ok(MemoryMaintenanceService::new( + self.store.clone(), + self.provider_configs.default_provider_config(), + )) + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index b131307..1d55299 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -6,13 +6,17 @@ pub mod compaction; pub mod execution; pub mod http; pub mod memory_maintenance; +pub mod memory_maintenance_coordinator; pub mod message_prepare; pub mod processor; pub mod prompt; pub mod prompt_injector; +pub mod provider_config_service; +pub mod scheduled_agent_task_service; pub mod session; pub mod session_factory; pub mod session_lifecycle; +pub mod session_message_service; pub mod session_pool; pub mod tool_registry_factory; pub mod ws; diff --git a/src/gateway/provider_config_service.rs b/src/gateway/provider_config_service.rs new file mode 100644 index 0000000..757a913 --- /dev/null +++ b/src/gateway/provider_config_service.rs @@ -0,0 +1,90 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::agent::AgentError; +use crate::config::LLMProviderConfig; + +#[derive(Clone)] +pub(crate) struct ProviderConfigService { + default_provider_config: LLMProviderConfig, + provider_configs: Arc>, +} + +impl ProviderConfigService { + pub(crate) fn new( + default_provider_config: LLMProviderConfig, + provider_configs: HashMap, + ) -> Self { + Self { + default_provider_config, + provider_configs: Arc::new(provider_configs), + } + } + + pub(crate) fn select(&self, agent_name: Option<&str>) -> Result { + match agent_name.map(str::trim).filter(|value| !value.is_empty()) { + None | Some("default") => Ok(self.default_provider_config.clone()), + Some(agent_name) => self + .provider_configs + .get(agent_name) + .cloned() + .ok_or_else(|| { + AgentError::Other(format!("Scheduled agent '{}' not found", agent_name)) + }), + } + } + + pub(crate) fn default_provider_config(&self) -> LLMProviderConfig { + self.default_provider_config.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig { + LLMProviderConfig { + provider_type: "openai".to_string(), + name: name.to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 120, + model_id: model_id.to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + context_window_tokens: None, + model_extra: HashMap::new(), + max_tool_iterations: 1, + tool_result_max_chars: 20_000, + context_tool_result_trim_chars: 20_000, + } + } + + #[test] + fn test_select_uses_named_agent_override() { + let default_provider = test_provider_config_named("default-provider", "default-model"); + let service = ProviderConfigService::new( + default_provider, + HashMap::from([( + "planner".to_string(), + test_provider_config_named("planner-provider", "planner-model"), + )]), + ); + + let selected = service.select(Some("planner")).unwrap(); + assert_eq!(selected.name, "planner-provider"); + assert_eq!(selected.model_id, "planner-model"); + } + + #[test] + fn test_select_falls_back_to_default() { + let default_provider = test_provider_config_named("default-provider", "default-model"); + let service = ProviderConfigService::new(default_provider, HashMap::new()); + + let selected = service.select(Some("default")).unwrap(); + assert_eq!(selected.name, "default-provider"); + assert_eq!(selected.model_id, "default-model"); + } +} diff --git a/src/gateway/scheduled_agent_task_service.rs b/src/gateway/scheduled_agent_task_service.rs new file mode 100644 index 0000000..733a195 --- /dev/null +++ b/src/gateway/scheduled_agent_task_service.rs @@ -0,0 +1,57 @@ +use crate::agent::AgentError; +use crate::bus::OutboundMessage; +use crate::scheduler::ScheduledAgentTaskOptions; + +use super::execution::{AgentExecutionService, ScheduledExecutionRequest}; +use super::provider_config_service::ProviderConfigService; +use super::session_lifecycle::SessionLifecycleService; + +#[derive(Clone)] +pub(crate) struct ScheduledAgentTaskService { + lifecycle: SessionLifecycleService, + provider_configs: ProviderConfigService, + show_tool_results: bool, +} + +impl ScheduledAgentTaskService { + pub(crate) fn new( + lifecycle: SessionLifecycleService, + provider_configs: ProviderConfigService, + show_tool_results: bool, + ) -> Self { + Self { + lifecycle, + provider_configs, + show_tool_results, + } + } + + pub(crate) async fn run( + &self, + channel_name: &str, + chat_id: &str, + prompt: &str, + options: ScheduledAgentTaskOptions, + ) -> Result, AgentError> { + let session = self.lifecycle.active_session(channel_name).await?; + let sender_id = options + .sender_id + .clone() + .unwrap_or_else(|| "scheduler".to_string()); + let provider_config = self.provider_configs.select(options.agent.as_deref())?; + + AgentExecutionService::new(self.show_tool_results) + .prepare_and_execute_scheduled_task(ScheduledExecutionRequest { + session, + channel_name, + chat_id, + prompt, + sender_id: &sender_id, + provider_config, + fresh_session: options.fresh_session, + system_prompt: options.system_prompt.as_deref(), + metadata: &options.metadata, + }) + .await + } +} diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 09b9d07..8ba98bc 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -16,22 +16,21 @@ use uuid::Uuid; use super::agent_factory::{AgentBuildRequest, AgentFactory}; use super::cli_session::CliSessionService; -use super::execution::{ - AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest, - select_provider_config, should_display_message_to_user, -}; +use super::execution::should_display_message_to_user; #[cfg(test)] use super::memory_maintenance::{ MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan, combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error, strip_json_code_fence, }; -use super::memory_maintenance::{ - MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService, -}; +use super::memory_maintenance::{MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult}; +use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator; use super::prompt_injector::PromptInjector; +use super::provider_config_service::ProviderConfigService; +use super::scheduled_agent_task_service::ScheduledAgentTaskService; use super::session_factory::SessionFactory; use super::session_lifecycle::SessionLifecycleService; +use super::session_message_service::SessionMessageService; use super::tool_registry_factory::ToolRegistryFactory; fn preview_text(content: &str, max_chars: usize) -> String { @@ -449,14 +448,15 @@ impl Session { /// SessionManager 管理所有 Session,按 channel_name 路由 #[derive(Clone)] pub struct SessionManager { - provider_config: LLMProviderConfig, - provider_configs: Arc>, tools: Arc, skills: Arc, store: Arc, show_tool_results: bool, lifecycle: SessionLifecycleService, cli_sessions: CliSessionService, + messages: SessionMessageService, + scheduled_tasks: ScheduledAgentTaskService, + memory_maintenance: MemoryMaintenanceCoordinator, } impl SessionManager { @@ -474,6 +474,8 @@ impl SessionManager { .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, ); let known_agents = provider_configs.keys().cloned().collect::>(); + let provider_configs = + ProviderConfigService::new(provider_config.clone(), provider_configs); if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) @@ -501,16 +503,25 @@ impl SessionManager { ); let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory); let cli_sessions = CliSessionService::new(store.clone()); + let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results); + let scheduled_tasks = ScheduledAgentTaskService::new( + lifecycle.clone(), + provider_configs.clone(), + show_tool_results, + ); + let memory_maintenance = + MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone()); Ok(Self { - provider_config, - provider_configs: Arc::new(provider_configs), tools, skills, store, show_tool_results, lifecycle, cli_sessions, + messages, + scheduled_tasks, + memory_maintenance, }) } @@ -539,34 +550,18 @@ impl SessionManager { &self, scope_key: &str, ) -> Result, AgentError> { - self.memory_maintenance_service()? - .summarize_for_scope(scope_key) - .await + self.memory_maintenance.summarize_for_scope(scope_key).await } pub(crate) async fn run_memory_maintenance_for_all_scopes( &self, updated_since: Option, ) -> Result, AgentError> { - self.memory_maintenance_service()? + self.memory_maintenance .run_for_all_scopes(updated_since) .await } - fn memory_maintenance_service(&self) -> Result { - Ok(MemoryMaintenanceService::new( - self.store.clone(), - self.provider_config_for_agent(None)?, - )) - } - - pub fn provider_config_for_agent( - &self, - agent_name: Option<&str>, - ) -> Result { - select_provider_config(&self.provider_config, &self.provider_configs, agent_name) - } - /// 确保 session 存在且未超时,超时则重建 pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { self.lifecycle.ensure_session(channel_name).await @@ -596,43 +591,16 @@ impl SessionManager { media: Vec, live_emitter: Option>, ) -> Result, AgentError> { - #[cfg(debug_assertions)] - { - tracing::debug!( - channel = %channel_name, - chat_id = %chat_id, - content_len = content.len(), - media_count = %media.len(), - "Routing message to agent" - ); - for (i, m) in media.iter().enumerate() { - tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message"); - } - } - - let session = self.lifecycle.active_session(channel_name).await?; - - let outbound_messages = AgentExecutionService::new(self.show_tool_results) - .prepare_and_execute_message(MessageExecutionRequest { - session: session.clone(), + self.messages + .handle_message( channel_name, sender_id, chat_id, content, media, live_emitter, - }) - .await?; - - #[cfg(debug_assertions)] - tracing::debug!( - channel = %channel_name, - chat_id = %chat_id, - outbound_count = outbound_messages.len(), - "Agent response sequence received" - ); - - Ok(outbound_messages) + ) + .await } pub async fn run_scheduled_agent_task( @@ -642,26 +610,8 @@ impl SessionManager { prompt: &str, options: ScheduledAgentTaskOptions, ) -> Result, AgentError> { - let session = self.lifecycle.active_session(channel_name).await?; - - let sender_id = options - .sender_id - .clone() - .unwrap_or_else(|| "scheduler".to_string()); - let provider_config = self.provider_config_for_agent(options.agent.as_deref())?; - - AgentExecutionService::new(self.show_tool_results) - .prepare_and_execute_scheduled_task(ScheduledExecutionRequest { - session: session.clone(), - channel_name, - chat_id, - prompt, - sender_id: &sender_id, - provider_config, - fresh_session: options.fresh_session, - system_prompt: options.system_prompt.as_deref(), - metadata: &options.metadata, - }) + self.scheduled_tasks + .run(channel_name, chat_id, prompt, options) .await } diff --git a/src/gateway/session_message_service.rs b/src/gateway/session_message_service.rs new file mode 100644 index 0000000..45473df --- /dev/null +++ b/src/gateway/session_message_service.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use crate::agent::{AgentError, EmittedMessageHandler}; +use crate::bus::{MediaItem, OutboundMessage}; + +use super::execution::{AgentExecutionService, MessageExecutionRequest}; +use super::session_lifecycle::SessionLifecycleService; + +#[derive(Clone)] +pub(crate) struct SessionMessageService { + lifecycle: SessionLifecycleService, + show_tool_results: bool, +} + +impl SessionMessageService { + pub(crate) fn new(lifecycle: SessionLifecycleService, show_tool_results: bool) -> Self { + Self { + lifecycle, + show_tool_results, + } + } + + pub(crate) async fn handle_message( + &self, + channel_name: &str, + sender_id: &str, + chat_id: &str, + content: &str, + media: Vec, + live_emitter: Option>, + ) -> Result, AgentError> { + #[cfg(debug_assertions)] + { + tracing::debug!( + channel = %channel_name, + chat_id = %chat_id, + content_len = content.len(), + media_count = %media.len(), + "Routing message to agent" + ); + for (i, m) in media.iter().enumerate() { + tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message"); + } + } + + let session = self.lifecycle.active_session(channel_name).await?; + let outbound_messages = AgentExecutionService::new(self.show_tool_results) + .prepare_and_execute_message(MessageExecutionRequest { + session, + channel_name, + sender_id, + chat_id, + content, + media, + live_emitter, + }) + .await?; + + #[cfg(debug_assertions)] + tracing::debug!( + channel = %channel_name, + chat_id = %chat_id, + outbound_count = outbound_messages.len(), + "Agent response sequence received" + ); + + Ok(outbound_messages) + } +}