2026-04-28 12:08:34 +08:00

168 lines
5.2 KiB
Rust

pub mod cli_session;
pub mod execution;
pub mod http;
pub mod memory_maintenance;
pub mod processor;
pub mod prompt;
pub mod session;
pub mod session_factory;
pub mod session_pool;
pub mod ws;
use axum::{Router, routing};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use crate::bus::{MessageBus, OutboundDispatcher};
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
use processor::InboundProcessor;
use session::SessionManager;
/// Shared gateway state: the loaded configuration plus the managers built from it.
///
/// `run` wraps this in an `Arc` and hands it to the axum router as state.
pub struct GatewayState {
    /// Configuration the gateway was started with.
    pub config: Config,
    /// Session manager constructed from `config` in `GatewayState::from_config`.
    pub session_manager: SessionManager,
    /// Manager owning the messaging channels.
    pub channel_manager: ChannelManager,
    /// Message bus obtained from `channel_manager.bus()`.
    pub bus: Arc<MessageBus>,
}
impl GatewayState {
pub fn from_config(config: Config) -> Result<Self, Box<dyn std::error::Error>> {
// Get provider config for SessionManager
let provider_config = config.get_provider_config("default")?;
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
for agent_name in config.agents.keys() {
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
}
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
let show_tool_results = config.gateway.show_tool_results;
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
let session_manager = SessionManager::new(
session_ttl_hours,
agent_prompt_reinject_every,
show_tool_results,
config.time.timezone.clone(),
provider_config,
provider_configs,
skills,
)?;
let channel_manager = ChannelManager::new();
let bus = channel_manager.bus();
Ok(Self {
config,
session_manager,
channel_manager,
bus,
})
}
/// Start the message processing loops
pub async fn start_message_processing(&self) {
let bus_for_outbound = self.bus.clone();
let inbound_processor =
InboundProcessor::new(self.bus.clone(), self.session_manager.clone());
tokio::spawn(inbound_processor.run());
// Spawn outbound dispatcher
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
let channel_manager = self.channel_manager.clone();
for (name, channel) in channel_manager.channels().await {
dispatcher.register_channel(&name, channel).await;
}
tokio::spawn(async move {
tracing::info!("Outbound dispatcher started");
dispatcher.run().await;
});
}
}
pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
let config = Config::load_default()?;
let timezone = config.time.parse_timezone()?;
// Initialize logging
logging::init_logging(timezone);
tracing::info!("Starting PicoBot Gateway");
let state = Arc::new(GatewayState::from_config(config)?);
// Get provider config for channels
let provider_config = state.config.get_provider_config("default")?;
// Initialize and start channels
state
.channel_manager
.init(&state.config, provider_config.clone())
.await?;
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + outbound dispatcher)
state.start_message_processing().await;
let (scheduler_shutdown_tx, scheduler_shutdown_rx) = tokio::sync::watch::channel(false);
if state.config.scheduler.enabled {
let scheduler = Scheduler::new(
state.bus.clone(),
state.config.scheduler.clone(),
timezone,
state.session_manager.store(),
state.session_manager.clone(),
);
tokio::spawn(async move {
scheduler.run(scheduler_shutdown_rx).await;
});
}
// CLI args override config file values
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
let bind_port = port.unwrap_or(state.config.gateway.port);
let app = Router::new()
.route("/health", routing::get(http::health))
.route("/ws", routing::get(ws::ws_handler))
.with_state(state.clone());
let addr = format!("{}:{}", bind_host, bind_port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!(address = %addr, "Gateway listening");
// Graceful shutdown using oneshot channel
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let channel_manager = state.channel_manager.clone();
// Spawn ctrl_c handler
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Shutdown signal received");
let _ = scheduler_shutdown_tx.send(true);
let _ = channel_manager.stop_all().await;
let _ = shutdown_tx.send(());
});
// Serve with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(async {
shutdown_rx.await.ok();
})
.await?;
Ok(())
}