211 lines
7.7 KiB
Rust

pub mod http;
pub mod session;
pub mod ws;
use std::collections::HashMap;
use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::bus::{MessageBus, OutboundDispatcher};
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
use session::{BusToolCallEmitter, SessionManager};
/// Shared gateway state: the loaded configuration together with the managers
/// and message bus used by the processing tasks.
///
/// Built by `GatewayState::from_config`; `run` wraps it in an `Arc` and hands
/// it to the axum router as its state.
pub struct GatewayState {
/// Configuration this state was constructed from.
pub config: Config,
/// Handles inbound messages (`handle_message`) in the inbound processor task.
pub session_manager: SessionManager,
/// Channel manager; channels are looked up on it by name (e.g. "feishu").
pub channel_manager: ChannelManager,
/// Message bus obtained from `channel_manager.bus()` at construction time.
pub bus: Arc<MessageBus>,
}
impl GatewayState {
pub fn from_config(config: Config) -> Result<Self, Box<dyn std::error::Error>> {
// Get provider config for SessionManager
let provider_config = config.get_provider_config("default")?;
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
for agent_name in config.agents.keys() {
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
}
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
let show_tool_results = config.gateway.show_tool_results;
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
let session_manager = SessionManager::new(
session_ttl_hours,
agent_prompt_reinject_every,
show_tool_results,
config.time.timezone.clone(),
provider_config,
provider_configs,
skills,
)?;
let channel_manager = ChannelManager::new();
let bus = channel_manager.bus();
Ok(Self {
config,
session_manager,
channel_manager,
bus,
})
}
/// Start the message processing loops
pub async fn start_message_processing(&self) {
let bus_for_inbound = self.bus.clone();
let bus_for_outbound = self.bus.clone();
let session_manager = self.session_manager.clone();
// Spawn inbound message processor
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
tokio::spawn(async move {
tracing::info!("Inbound processor started");
loop {
let inbound = bus_for_inbound.consume_inbound().await;
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
sender = %inbound.sender_id,
content = %inbound.content,
media_count = %inbound.media.len(),
"Processing inbound message"
);
if !inbound.media.is_empty() {
for (i, m) in inbound.media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
}
}
}
// Process via session manager
let live_emitter = Arc::new(BusToolCallEmitter::new(
bus_for_inbound.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
session_manager.show_tool_results(),
));
match session_manager.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
).await {
Ok(outbound_messages) => {
// Forward channel-specific metadata from inbound to outbound.
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
// without gateway needing channel-specific code.
for mut outbound in outbound_messages {
outbound.metadata.extend(inbound.forwarded_metadata.clone());
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
tracing::error!(error = %e, "Failed to publish outbound");
}
}
}
Err(e) => {
tracing::error!(error = %e, "Failed to handle message");
}
}
}
});
// Spawn outbound dispatcher
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
let channel_manager = self.channel_manager.clone();
// Register channels with dispatcher
if let Some(channel) = channel_manager.get_channel("feishu").await {
dispatcher.register_channel("feishu", channel).await;
}
tokio::spawn(async move {
tracing::info!("Outbound dispatcher started");
dispatcher.run().await;
});
}
}
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
let config = Config::load_default()?;
let timezone = config.time.parse_timezone()?;
// Initialize logging
logging::init_logging(timezone);
tracing::info!("Starting PicoBot Gateway");
let state = Arc::new(GatewayState::from_config(config)?);
// Get provider config for channels
let provider_config = state.config.get_provider_config("default")?;
// Initialize and start channels
state.channel_manager.init(&state.config, provider_config.clone()).await?;
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + outbound dispatcher)
state.start_message_processing().await;
let (scheduler_shutdown_tx, scheduler_shutdown_rx) = tokio::sync::watch::channel(false);
if state.config.scheduler.enabled {
let scheduler = Scheduler::new(
state.bus.clone(),
state.config.scheduler.clone(),
timezone,
state.session_manager.store(),
state.session_manager.clone(),
);
tokio::spawn(async move {
scheduler.run(scheduler_shutdown_rx).await;
});
}
// CLI args override config file values
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
let bind_port = port.unwrap_or(state.config.gateway.port);
let app = Router::new()
.route("/health", routing::get(http::health))
.route("/ws", routing::get(ws::ws_handler))
.with_state(state.clone());
let addr = format!("{}:{}", bind_host, bind_port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!(address = %addr, "Gateway listening");
// Graceful shutdown using oneshot channel
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let channel_manager = state.channel_manager.clone();
// Spawn ctrl_c handler
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Shutdown signal received");
let _ = scheduler_shutdown_tx.send(true);
let _ = channel_manager.stop_all().await;
let _ = shutdown_tx.send(());
});
// Serve with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(async {
shutdown_rx.await.ok();
})
.await?;
Ok(())
}