212 lines
6.9 KiB
Rust
212 lines
6.9 KiB
Rust
pub mod agent_factory;
|
|
pub mod agent_prompt_provider;
|
|
pub mod agent_task_executor;
|
|
pub mod cli_session;
|
|
pub mod command;
|
|
pub mod compaction;
|
|
pub mod execution;
|
|
pub mod http;
|
|
pub mod memory_maintenance;
|
|
pub mod memory_maintenance_coordinator;
|
|
pub mod message_prepare;
|
|
pub mod outbound_dispatcher;
|
|
pub mod processor;
|
|
pub mod prompt;
|
|
pub mod provider_config_service;
|
|
pub mod runtime;
|
|
pub mod scheduled_agent_task_service;
|
|
pub mod session;
|
|
pub mod session_factory;
|
|
pub mod session_history;
|
|
pub mod session_lifecycle;
|
|
pub mod session_message_sender;
|
|
pub mod session_message_service;
|
|
pub mod session_pool;
|
|
pub mod tool_registry_factory;
|
|
pub mod ws;
|
|
|
|
use axum::{Router, routing};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::net::TcpListener;
|
|
use tokio::sync::Semaphore;
|
|
|
|
use crate::bus::MessageBus;
|
|
use crate::channels::ChannelManager;
|
|
use crate::config::Config;
|
|
use crate::config::LLMProviderConfig;
|
|
use crate::logging;
|
|
use crate::scheduler::Scheduler;
|
|
use crate::skills::SkillRuntime;
|
|
use crate::tools::task::repository::TaskRepository;
|
|
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
|
|
use outbound_dispatcher::OutboundDispatcher;
|
|
use processor::InboundProcessor;
|
|
use runtime::build_session_manager_with_sender;
|
|
use session_message_sender::BusSessionMessageSender;
|
|
use session::SessionManager;
|
|
|
|
pub struct GatewayState {
|
|
pub config: Config,
|
|
pub session_manager: SessionManager,
|
|
pub channel_manager: ChannelManager,
|
|
pub bus: Arc<MessageBus>,
|
|
pub task_repository: Arc<dyn TaskRepository>,
|
|
}
|
|
|
|
impl GatewayState {
|
|
pub fn from_config(config: Config) -> Result<Self, Box<dyn std::error::Error>> {
|
|
// Get provider config for SessionManager
|
|
let provider_config = config.get_provider_config("default")?;
|
|
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
|
|
for agent_name in config.agents.keys() {
|
|
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
|
}
|
|
|
|
// Chat history TTL from config (default 4 hours)
|
|
let chat_history_ttl_hours = config.gateway.chat_history_ttl_hours;
|
|
let agent_prompt_reinject_every = config.gateway.agent_prompt_reinject_every;
|
|
let show_tool_results = config.gateway.show_tool_results;
|
|
|
|
let session_ttl_hours = config.gateway.session_ttl_hours;
|
|
|
|
let skills = Arc::new(SkillRuntime::from_config(config.skills.clone()));
|
|
let channel_manager = ChannelManager::new();
|
|
let bus = channel_manager.bus();
|
|
|
|
let (session_manager, task_repository) = build_session_manager_with_sender(
|
|
agent_prompt_reinject_every,
|
|
show_tool_results,
|
|
config.time.timezone.clone(),
|
|
provider_config,
|
|
provider_configs,
|
|
skills,
|
|
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
|
std::collections::HashSet::new(),
|
|
config.tools.task.clone(),
|
|
chat_history_ttl_hours,
|
|
session_ttl_hours,
|
|
)?;
|
|
|
|
Ok(Self {
|
|
config,
|
|
session_manager,
|
|
channel_manager,
|
|
bus,
|
|
task_repository,
|
|
})
|
|
}
|
|
|
|
/// Start the message processing loops
|
|
pub async fn start_message_processing(&self) {
|
|
let bus_for_outbound = self.bus.clone();
|
|
|
|
// Create semaphore for controlling concurrent requests
|
|
let max_concurrent = self.config.gateway.max_concurrent_requests;
|
|
let semaphore = Arc::new(Semaphore::new(max_concurrent));
|
|
|
|
// Spawn inbound processor with semaphore-controlled concurrency
|
|
let provider_config = match self.config.get_provider_config("default") {
|
|
Ok(config) => config,
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to get provider config");
|
|
return;
|
|
}
|
|
};
|
|
let inbound_processor =
|
|
InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore, provider_config);
|
|
tokio::spawn(inbound_processor.run());
|
|
|
|
// Spawn outbound dispatcher
|
|
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
|
|
let channel_manager = self.channel_manager.clone();
|
|
|
|
for (name, channel) in channel_manager.channels().await {
|
|
dispatcher.register_channel(&name, channel).await;
|
|
}
|
|
|
|
tokio::spawn(async move {
|
|
tracing::info!("Outbound dispatcher started");
|
|
dispatcher.run().await;
|
|
});
|
|
}
|
|
}
|
|
|
|
pub async fn run(
|
|
host: Option<String>,
|
|
port: Option<u16>,
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|
let config = Config::load_default()?;
|
|
let timezone = config.time.parse_timezone()?;
|
|
|
|
// Initialize logging
|
|
logging::init_logging(timezone);
|
|
tracing::info!("Starting PicoBot Gateway");
|
|
|
|
let state = Arc::new(GatewayState::from_config(config)?);
|
|
|
|
// Get provider config for channels
|
|
let provider_config = state.config.get_provider_config("default")?;
|
|
|
|
// Initialize and start channels
|
|
state
|
|
.channel_manager
|
|
.init(&state.config, provider_config.clone())
|
|
.await?;
|
|
state.channel_manager.start_all().await?;
|
|
|
|
// Start message processing (inbound processor + outbound dispatcher)
|
|
state.start_message_processing().await;
|
|
|
|
let (scheduler_shutdown_tx, scheduler_shutdown_rx) = tokio::sync::watch::channel(false);
|
|
if state.config.scheduler.enabled {
|
|
let scheduler = Scheduler::new(
|
|
state.bus.clone(),
|
|
state.config.scheduler.clone(),
|
|
timezone,
|
|
state.session_manager.store(),
|
|
AgentTaskExecutor::new(state.session_manager.clone()),
|
|
SchedulerMaintenanceService::new(state.session_manager.clone()),
|
|
);
|
|
|
|
tokio::spawn(async move {
|
|
scheduler.run(scheduler_shutdown_rx).await;
|
|
});
|
|
}
|
|
|
|
// CLI args override config file values
|
|
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
|
let bind_port = port.unwrap_or(state.config.gateway.port);
|
|
|
|
let app = Router::new()
|
|
.route("/health", routing::get(http::health))
|
|
.route("/ws", routing::get(ws::ws_handler))
|
|
.with_state(state.clone());
|
|
|
|
let addr = format!("{}:{}", bind_host, bind_port);
|
|
let listener = TcpListener::bind(&addr).await?;
|
|
tracing::info!(address = %addr, "Gateway listening");
|
|
|
|
// Graceful shutdown using oneshot channel
|
|
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
|
let channel_manager = state.channel_manager.clone();
|
|
|
|
// Spawn ctrl_c handler
|
|
tokio::spawn(async move {
|
|
tokio::signal::ctrl_c().await.ok();
|
|
tracing::info!("Shutdown signal received");
|
|
let _ = scheduler_shutdown_tx.send(true);
|
|
let _ = channel_manager.stop_all().await;
|
|
let _ = shutdown_tx.send(());
|
|
});
|
|
|
|
// Serve with graceful shutdown
|
|
axum::serve(listener, app)
|
|
.with_graceful_shutdown(async {
|
|
shutdown_rx.await.ok();
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|