211 lines
7.7 KiB
Rust
211 lines
7.7 KiB
Rust
pub mod http;
|
|
pub mod session;
|
|
pub mod ws;
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use axum::{routing, Router};
|
|
use tokio::net::TcpListener;
|
|
|
|
use crate::bus::{MessageBus, OutboundDispatcher};
|
|
use crate::channels::ChannelManager;
|
|
use crate::config::Config;
|
|
use crate::config::LLMProviderConfig;
|
|
use crate::logging;
|
|
use crate::scheduler::Scheduler;
|
|
use crate::skills::SkillRuntime;
|
|
use session::{BusToolCallEmitter, SessionManager};
|
|
|
|
pub struct GatewayState {
|
|
pub config: Config,
|
|
pub session_manager: SessionManager,
|
|
pub channel_manager: ChannelManager,
|
|
pub bus: Arc<MessageBus>,
|
|
}
|
|
|
|
impl GatewayState {
|
|
pub fn from_config(config: Config) -> Result<Self, Box<dyn std::error::Error>> {
|
|
// Get provider config for SessionManager
|
|
let provider_config = config.get_provider_config("default")?;
|
|
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
|
|
for agent_name in config.agents.keys() {
|
|
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
|
}
|
|
|
|
// Session TTL from config (default 4 hours)
|
|
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
|
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
|
let show_tool_results = config.gateway.show_tool_results;
|
|
|
|
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
|
|
|
let session_manager = SessionManager::new(
|
|
session_ttl_hours,
|
|
agent_prompt_reinject_every,
|
|
show_tool_results,
|
|
config.time.timezone.clone(),
|
|
provider_config,
|
|
provider_configs,
|
|
skills,
|
|
)?;
|
|
let channel_manager = ChannelManager::new();
|
|
let bus = channel_manager.bus();
|
|
|
|
Ok(Self {
|
|
config,
|
|
session_manager,
|
|
channel_manager,
|
|
bus,
|
|
})
|
|
}
|
|
|
|
/// Start the message processing loops
|
|
pub async fn start_message_processing(&self) {
|
|
let bus_for_inbound = self.bus.clone();
|
|
let bus_for_outbound = self.bus.clone();
|
|
let session_manager = self.session_manager.clone();
|
|
|
|
// Spawn inbound message processor
|
|
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
|
|
tokio::spawn(async move {
|
|
tracing::info!("Inbound processor started");
|
|
loop {
|
|
let inbound = bus_for_inbound.consume_inbound().await;
|
|
#[cfg(debug_assertions)]
|
|
{
|
|
tracing::debug!(
|
|
channel = %inbound.channel,
|
|
chat_id = %inbound.chat_id,
|
|
sender = %inbound.sender_id,
|
|
content = %inbound.content,
|
|
media_count = %inbound.media.len(),
|
|
"Processing inbound message"
|
|
);
|
|
if !inbound.media.is_empty() {
|
|
for (i, m) in inbound.media.iter().enumerate() {
|
|
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Process via session manager
|
|
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
|
bus_for_inbound.clone(),
|
|
inbound.channel.clone(),
|
|
inbound.chat_id.clone(),
|
|
inbound.forwarded_metadata.clone(),
|
|
session_manager.show_tool_results(),
|
|
));
|
|
match session_manager.handle_message(
|
|
&inbound.channel,
|
|
&inbound.sender_id,
|
|
&inbound.chat_id,
|
|
&inbound.content,
|
|
inbound.media,
|
|
Some(live_emitter),
|
|
).await {
|
|
Ok(outbound_messages) => {
|
|
// Forward channel-specific metadata from inbound to outbound.
|
|
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
|
|
// without gateway needing channel-specific code.
|
|
for mut outbound in outbound_messages {
|
|
outbound.metadata.extend(inbound.forwarded_metadata.clone());
|
|
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
|
tracing::error!(error = %e, "Failed to publish outbound");
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to handle message");
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Spawn outbound dispatcher
|
|
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
|
|
let channel_manager = self.channel_manager.clone();
|
|
|
|
// Register channels with dispatcher
|
|
if let Some(channel) = channel_manager.get_channel("feishu").await {
|
|
dispatcher.register_channel("feishu", channel).await;
|
|
}
|
|
|
|
tokio::spawn(async move {
|
|
tracing::info!("Outbound dispatcher started");
|
|
dispatcher.run().await;
|
|
});
|
|
}
|
|
}
|
|
|
|
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
|
let config = Config::load_default()?;
|
|
let timezone = config.time.parse_timezone()?;
|
|
|
|
// Initialize logging
|
|
logging::init_logging(timezone);
|
|
tracing::info!("Starting PicoBot Gateway");
|
|
|
|
let state = Arc::new(GatewayState::from_config(config)?);
|
|
|
|
// Get provider config for channels
|
|
let provider_config = state.config.get_provider_config("default")?;
|
|
|
|
// Initialize and start channels
|
|
state.channel_manager.init(&state.config, provider_config.clone()).await?;
|
|
state.channel_manager.start_all().await?;
|
|
|
|
// Start message processing (inbound processor + outbound dispatcher)
|
|
state.start_message_processing().await;
|
|
|
|
let (scheduler_shutdown_tx, scheduler_shutdown_rx) = tokio::sync::watch::channel(false);
|
|
if state.config.scheduler.enabled {
|
|
let scheduler = Scheduler::new(
|
|
state.bus.clone(),
|
|
state.config.scheduler.clone(),
|
|
timezone,
|
|
state.session_manager.store(),
|
|
state.session_manager.clone(),
|
|
);
|
|
|
|
tokio::spawn(async move {
|
|
scheduler.run(scheduler_shutdown_rx).await;
|
|
});
|
|
}
|
|
|
|
// CLI args override config file values
|
|
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
|
let bind_port = port.unwrap_or(state.config.gateway.port);
|
|
|
|
let app = Router::new()
|
|
.route("/health", routing::get(http::health))
|
|
.route("/ws", routing::get(ws::ws_handler))
|
|
.with_state(state.clone());
|
|
|
|
let addr = format!("{}:{}", bind_host, bind_port);
|
|
let listener = TcpListener::bind(&addr).await?;
|
|
tracing::info!(address = %addr, "Gateway listening");
|
|
|
|
// Graceful shutdown using oneshot channel
|
|
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
|
let channel_manager = state.channel_manager.clone();
|
|
|
|
// Spawn ctrl_c handler
|
|
tokio::spawn(async move {
|
|
tokio::signal::ctrl_c().await.ok();
|
|
tracing::info!("Shutdown signal received");
|
|
let _ = scheduler_shutdown_tx.send(true);
|
|
let _ = channel_manager.stop_all().await;
|
|
let _ = shutdown_tx.send(());
|
|
});
|
|
|
|
// Serve with graceful shutdown
|
|
axum::serve(listener, app)
|
|
.with_graceful_shutdown(async {
|
|
shutdown_rx.await.ok();
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|