191 lines
6.1 KiB
Rust

pub mod agent_factory;
pub mod agent_prompt_provider;
pub mod agent_task_executor;
pub mod cli_session;
pub mod command;
pub mod compaction;
pub mod execution;
pub mod http;
pub mod memory_maintenance;
pub mod memory_maintenance_coordinator;
pub mod message_prepare;
pub mod outbound_dispatcher;
pub mod processor;
pub mod prompt;
pub mod provider_config_service;
pub mod runtime;
pub mod scheduled_agent_task_service;
pub mod session;
pub mod session_factory;
pub mod session_history;
pub mod session_lifecycle;
pub mod session_message_sender;
pub mod session_message_service;
pub mod session_pool;
pub mod tool_registry_factory;
pub mod ws;
use axum::{Router, routing};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use crate::bus::MessageBus;
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
use outbound_dispatcher::OutboundDispatcher;
use processor::InboundProcessor;
use runtime::build_session_manager_with_sender;
use session_message_sender::BusSessionMessageSender;
use session::SessionManager;
/// Shared state for the gateway process.
///
/// Built once in [`run`] and handed (behind an [`Arc`]) to the HTTP/WebSocket router as its
/// state.
pub struct GatewayState {
    /// Configuration the gateway was built from.
    pub config: Config,
    /// Session manager used by the inbound processor and, when enabled, the scheduler.
    pub session_manager: SessionManager,
    /// Manager for the messaging channels; also the source of the `bus` field.
    pub channel_manager: ChannelManager,
    /// Message bus obtained from the channel manager. It is shared with the session message
    /// sender, the inbound processor, the outbound dispatcher and the scheduler.
    pub bus: Arc<MessageBus>,
}

impl GatewayState {
    /// Build the gateway state from a loaded [`Config`].
    ///
    /// This does the following:
    /// - resolves the `"default"` provider config and one provider config per configured agent;
    /// - builds the shared [`SkillRuntime`];
    /// - creates the [`ChannelManager`], whose bus also backs the session manager's
    ///   [`BusSessionMessageSender`];
    /// - assembles the [`SessionManager`].
    ///
    /// # Errors
    ///
    /// Returns an error if any provider config lookup fails (the `"default"` one is checked
    /// first). It also returns an error if the session manager cannot be built.
    pub fn from_config(config: Config) -> Result<Self, Box<dyn std::error::Error>> {
        // Fallback provider for the SessionManager.
        let provider_config = config.get_provider_config("default")?;

        // One provider config per agent; the first failing lookup aborts construction.
        let provider_configs = config
            .agents
            .keys()
            .map(|agent_name| {
                config
                    .get_provider_config(agent_name)
                    .map(|agent_config| (agent_name.clone(), agent_config))
            })
            .collect::<Result<HashMap<String, LLMProviderConfig>, _>>()?;

        let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));

        // The session manager sends messages through the same bus the channels use.
        let channel_manager = ChannelManager::new();
        let bus = channel_manager.bus();

        let session_manager = build_session_manager_with_sender(
            config.gateway.agent_prompt_reinject_every,
            config.gateway.show_tool_results,
            config.time.timezone.clone(),
            provider_config,
            provider_configs,
            skills,
            Arc::new(BusSessionMessageSender::new(bus.clone())),
            std::collections::HashSet::new(),
            // Chat history TTL (hours) taken from the gateway config.
            config.gateway.chat_history_ttl_hours,
        )?;

        Ok(Self {
            config,
            session_manager,
            channel_manager,
            bus,
        })
    }

    /// Start the message processing loops.
    ///
    /// This spawns two detached tokio tasks:
    /// - an [`InboundProcessor`] constructed from the shared bus and session manager;
    /// - an [`OutboundDispatcher`] on the same bus, with every channel the channel manager
    ///   currently reports already registered.
    ///
    /// Channel registration happens once, here. Channels added to the manager after this
    /// call are not registered with the dispatcher.
    pub async fn start_message_processing(&self) {
        let inbound_processor =
            InboundProcessor::new(self.bus.clone(), self.session_manager.clone());
        tokio::spawn(inbound_processor.run());

        let dispatcher = OutboundDispatcher::new(self.bus.clone());
        for (name, channel) in self.channel_manager.channels().await {
            dispatcher.register_channel(&name, channel).await;
        }
        tokio::spawn(async move {
            tracing::info!("Outbound dispatcher started");
            dispatcher.run().await;
        });
    }
}
pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
let config = Config::load_default()?;
let timezone = config.time.parse_timezone()?;
// Initialize logging
logging::init_logging(timezone);
tracing::info!("Starting PicoBot Gateway");
let state = Arc::new(GatewayState::from_config(config)?);
// Get provider config for channels
let provider_config = state.config.get_provider_config("default")?;
// Initialize and start channels
state
.channel_manager
.init(&state.config, provider_config.clone())
.await?;
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + outbound dispatcher)
state.start_message_processing().await;
let (scheduler_shutdown_tx, scheduler_shutdown_rx) = tokio::sync::watch::channel(false);
if state.config.scheduler.enabled {
let scheduler = Scheduler::new(
state.bus.clone(),
state.config.scheduler.clone(),
timezone,
state.session_manager.store(),
AgentTaskExecutor::new(state.session_manager.clone()),
SchedulerMaintenanceService::new(state.session_manager.clone()),
);
tokio::spawn(async move {
scheduler.run(scheduler_shutdown_rx).await;
});
}
// CLI args override config file values
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
let bind_port = port.unwrap_or(state.config.gateway.port);
let app = Router::new()
.route("/health", routing::get(http::health))
.route("/ws", routing::get(ws::ws_handler))
.with_state(state.clone());
let addr = format!("{}:{}", bind_host, bind_port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!(address = %addr, "Gateway listening");
// Graceful shutdown using oneshot channel
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let channel_manager = state.channel_manager.clone();
// Spawn ctrl_c handler
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Shutdown signal received");
let _ = scheduler_shutdown_tx.send(true);
let _ = channel_manager.stop_all().await;
let _ = shutdown_tx.send(());
});
// Serve with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(async {
shutdown_rx.await.ok();
})
.await?;
Ok(())
}