2026-06-16 22:56:01 +08:00

475 lines
19 KiB
Rust

pub mod http;
pub mod ws;
use axum::{Router, routing};
use std::sync::Arc;
use tokio::net::TcpListener;
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::config::{Config, ensure_workspace_dir, expand_path};
use crate::logging;
use crate::mcp;
use crate::memory::MemoryManager;
use crate::scheduler::Scheduler;
use crate::session::SessionManager;
/// Shared state of the running gateway.
///
/// Built once by `GatewayState::new`, wrapped in an `Arc` by `run`, and handed to the
/// axum routes via `with_state`.
///
/// NOTE: field order is also drop order; keep it stable unless that is intended.
pub struct GatewayState {
    /// Configuration loaded by `Config::load_default()`.
    pub config: Config,
    /// Expanded and verified workspace directory. `new` also makes it the process's
    /// current working directory.
    pub workspace_dir: std::path::PathBuf,
    /// Session manager; gateway tools are registered through `session_manager.tools()`.
    pub session_manager: Arc<SessionManager>,
    /// Channel manager; the message bus and the CLI chat channel are reached through it.
    pub channel_manager: ChannelManager,
    /// Session storage, opened at `gateway.session_db_path` or `<workspace>/picobot.db`.
    pub storage: Arc<crate::storage::Storage>,
}
impl GatewayState {
    /// Build the gateway state from the default configuration.
    ///
    /// In order, this:
    /// 1. loads `Config::load_default()`, expands and ensures the workspace directory and
    ///    makes it the current working directory of the whole process;
    /// 2. releases the default `~/.picobot/` templates if they are missing;
    /// 3. opens session storage (`gateway.session_db_path`, or `<workspace>/picobot.db`) and
    ///    builds the `MemoryManager`, whose consolidation provider/model are resolved against
    ///    the default provider's name/model;
    /// 4. creates one `MessageBus` shared by the `SessionManager` and the `ChannelManager`,
    ///    and initialises the configured channels (they are not started here);
    /// 5. registers tools: `send_message`, `chat_manager`, every tool discovered on the
    ///    configured MCP servers and, when the scheduler is enabled, the cron tools.
    ///
    /// # Errors
    /// Fails if loading the config, preparing or switching to the workspace, resolving the
    /// default provider, opening storage, building the `SessionManager` or initialising
    /// channels fails.
    pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
        let config = Config::load_default()?;

        // Expand the configured workspace path and make sure the directory exists.
        let workspace_path = expand_path(&config.workspace_dir);
        let workspace_path = ensure_workspace_dir(&workspace_path)?;

        // Process-wide side effect: relative paths resolved after this point are relative
        // to the workspace.
        std::env::set_current_dir(&workspace_path).map_err(|e| {
            format!(
                "Failed to switch to workspace directory {}: {}",
                workspace_path.display(),
                e
            )
        })?;
        tracing::info!("Using workspace directory: {}", workspace_path.display());

        // Release default AGENTS.md / USER.md / config.example.json to ~/.picobot/ if absent.
        ensure_default_config_files();

        let mut provider_config = config.get_provider_config("default")?;
        // Use the expanded, verified workspace path rather than the raw config value.
        provider_config.workspace_dir = workspace_path.clone();

        let db_path = match config.gateway.session_db_path {
            Some(ref path) => std::path::PathBuf::from(path),
            None => workspace_path.join("picobot.db"),
        };
        let storage = Arc::new(
            crate::storage::Storage::new(&db_path)
                .await
                .map_err(|e| format!("failed to initialize session storage: {}", e))?,
        );
        tracing::info!("Session storage: {}", db_path.display());

        // Consolidation provider/model fall back to the main agent's configuration.
        let consolidation_provider = config
            .memory
            .resolve_consolidation_provider(&provider_config.name);
        let consolidation_model = config
            .memory
            .resolve_consolidation_model(&provider_config.model_id);
        let memory_manager = Arc::new(MemoryManager::new(
            storage.clone(),
            consolidation_provider,
            consolidation_model,
        ));
        tracing::info!(
            consolidation_provider = %memory_manager.consolidation_provider,
            consolidation_model = %memory_manager.consolidation_model,
            "Memory system initialized"
        );

        // One bus shared by SessionManager (producer of replies) and ChannelManager.
        let bus = MessageBus::new(100);
        let browser_config = config.browser.enabled.then(|| config.browser.clone());

        let session_manager = Arc::new(SessionManager::new(
            provider_config.clone(),
            storage.clone(),
            bus.clone(),
            memory_manager,
            browser_config,
            config.gateway.max_concurrent_background_tasks,
        )?);

        let cli_chat_channel = Arc::new(CliChatChannel::new());
        let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
        channel_manager
            .init(&config, workspace_path.clone())
            .await
            .map_err(|e| format!("Failed to init channels: {}", e))?;

        // send_message gets the channel list by value; keep a copy for the other tools.
        let available_channels = channel_manager.list_channel_names().await;
        let valid_channels = available_channels.clone();
        session_manager.register_outbound_tool(available_channels);

        session_manager
            .tools()
            .register(crate::tools::ChatManagerTool::new(
                storage.clone(),
                valid_channels.clone(),
            ));

        // Connect to MCP servers and expose every discovered tool through a wrapper.
        if !config.mcp.servers.is_empty() {
            for tool_info in mcp::connect_all(&config.mcp).await {
                let wrapper = mcp::McpToolWrapper::new(
                    &tool_info.server_name,
                    tool_info.tool_name,
                    tool_info.description,
                    tool_info.schema,
                    tool_info.connection,
                );
                session_manager.tools().register(wrapper);
            }
        }

        // Cron tools are only useful when the scheduler that executes jobs is enabled.
        let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
        if scheduler_config.enabled {
            session_manager
                .tools()
                .register(crate::tools::cron::CronAddTool::new(
                    storage.clone(),
                    valid_channels,
                ));
            session_manager
                .tools()
                .register(crate::tools::cron::CronListTool::new(storage.clone()));
            session_manager
                .tools()
                .register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
            session_manager
                .tools()
                .register(crate::tools::cron::CronEnableTool::new(storage.clone()));
            session_manager
                .tools()
                .register(crate::tools::cron::CronDisableTool::new(storage.clone()));
            session_manager
                .tools()
                .register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
            tracing::info!("Cron tools registered");
        }

        Ok(Self {
            config,
            workspace_dir: workspace_path,
            session_manager,
            channel_manager,
            storage,
        })
    }

    /// Return the shared `MessageBus` (obtained from the channel manager).
    pub fn bus(&self) -> Arc<crate::bus::MessageBus> {
        self.channel_manager.bus()
    }

    /// Return the CLI chat channel used by the WebSocket handler.
    pub fn cli_chat_channel(&self) -> Arc<CliChatChannel> {
        self.channel_manager.cli_chat_channel()
    }

    /// Start the CLI chat channel and spawn the gateway's background tasks.
    ///
    /// - A failure to start the CLI chat channel is logged, not returned.
    /// - One processor task `select!`s over the inbound bus (routed to
    ///   `SessionManager::handle_message`) and the control bus (routed to
    ///   `handle_control_message`); it stops when either bus reports closed.
    /// - The `OutboundDispatcher` runs in its own task.
    /// - The `Scheduler` is spawned when `gateway.scheduler.enabled` is set.
    ///
    /// All tasks are detached; this returns as soon as they are spawned.
    pub async fn start_message_processing(&self) {
        use crate::session::session::HandleResult;

        let bus = self.bus();
        let bus_for_outbound = bus.clone();
        let session_manager = self.session_manager.clone();

        // The CLI chat channel is already registered in the ChannelManager; start it here.
        let cli_chat_channel = self.cli_chat_channel();
        if let Err(e) = cli_chat_channel.start(bus.clone()).await {
            tracing::error!(error = %e, "Failed to start CLI chat channel");
        }

        // Unified processor: inbound AI messages and control messages share one loop.
        tokio::spawn(async move {
            tracing::info!("Message processor started");
            loop {
                tokio::select! {
                    inbound = bus.consume_inbound() => {
                        let Some(inbound) = inbound else {
                            tracing::warn!("Message processor stopping because inbound bus closed");
                            break;
                        };
                        // Bind the result first so the borrows of `inbound` end here and its
                        // fields can be moved (not cloned) into the reply below.
                        let result = session_manager
                            .handle_message(
                                &inbound.channel,
                                &inbound.sender_id,
                                &inbound.chat_id,
                                &inbound.content,
                                inbound.media,
                            )
                            .await;
                        match result {
                            // Both variants are an immediate reply to the originating chat.
                            Ok(HandleResult::AgentResponse(content)
                                | HandleResult::CommandOutput(content)) => {
                                let outbound = crate::bus::OutboundMessage {
                                    channel: inbound.channel,
                                    chat_id: inbound.chat_id,
                                    content,
                                    reply_to: None,
                                    media: vec![],
                                    metadata: inbound.forwarded_metadata,
                                };
                                if let Err(e) = bus.publish_outbound(outbound).await {
                                    tracing::error!(error = %e, "Failed to publish outbound");
                                }
                            }
                            Ok(HandleResult::AgentProcessing) => {
                                // Agent is processing in background; the response is sent
                                // via the bus directly from the spawned task, so this loop
                                // stays free for subsequent messages (incl. slash commands).
                            }
                            Err(e) => {
                                tracing::error!(error = %e, "Failed to handle message");
                            }
                        }
                    }
                    msg = bus.consume_control() => {
                        let Some(msg) = msg else {
                            tracing::warn!("Message processor stopping because control bus closed");
                            break;
                        };
                        Self::handle_control_message(&session_manager, msg).await;
                    }
                }
            }
        });

        let dispatcher = OutboundDispatcher::new(bus_for_outbound, self.channel_manager.clone());
        tokio::spawn(async move {
            tracing::info!("Outbound dispatcher started");
            dispatcher.run().await;
        });

        let scheduler_config = self.config.gateway.scheduler.clone().unwrap_or_default();
        if scheduler_config.enabled {
            let sched = Arc::new(Scheduler::new(
                self.storage.clone(),
                self.session_manager.clone(),
                scheduler_config,
            ));
            tokio::spawn(async move {
                sched.run().await;
            });
            tracing::info!("Scheduler background task spawned");
        }
    }

    /// Execute one session-management command and send its outcome on `msg.reply_tx`.
    ///
    /// Errors from the session layer are converted to `ChannelError::Other` carrying the
    /// error's string form.
    async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
        use crate::session::{SessionCommand::*, SessionEvent};

        // Uniform conversion of session-layer errors into the channel error type.
        fn to_channel_error<E: ToString>(e: E) -> ChannelError {
            ChannelError::Other(e.to_string())
        }

        let reply_tx = msg.reply_tx;
        let result: Result<SessionEvent, ChannelError> = match msg.op {
            CreateDialog {
                channel,
                chat_id,
                title,
            } => session_manager
                .create_dialog(&channel, &chat_id, title.as_deref())
                .await
                .map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
                .map_err(to_channel_error),
            ListDialogs {
                channel,
                chat_id,
                include_archived,
            } => session_manager
                .list_dialogs(&channel, &chat_id, include_archived)
                .await
                .map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
                    dialogs,
                    current_dialog_id,
                })
                .map_err(to_channel_error),
            GetCurrentDialog { channel, chat_id } => session_manager
                .get_current_dialog(&channel, &chat_id)
                .await
                .map(|session_id| SessionEvent::CurrentDialog { session_id })
                .map_err(to_channel_error),
            SwitchDialog {
                channel,
                chat_id,
                dialog_id,
            } => session_manager
                .switch_dialog(&channel, &chat_id, &dialog_id)
                .await
                .map(|session_id| SessionEvent::DialogSwitched { session_id })
                .map_err(to_channel_error),
            RenameDialog { session_id, title } => session_manager
                .rename_dialog(&session_id, &title)
                .await
                .map(|()| SessionEvent::DialogRenamed { session_id, title })
                .map_err(to_channel_error),
            ArchiveDialog { session_id } => session_manager
                .archive_dialog(&session_id)
                .await
                .map(|()| SessionEvent::DialogArchived { session_id })
                .map_err(to_channel_error),
            DeleteDialog { session_id } => session_manager
                .delete_dialog(&session_id)
                .await
                .map(|()| SessionEvent::DialogDeleted { session_id })
                .map_err(to_channel_error),
            ClearHistory { session_id } => session_manager
                .clear_dialog_history(&session_id)
                .await
                .map(|()| SessionEvent::HistoryCleared { session_id })
                .map_err(to_channel_error),
            // The slash-command list does not depend on the requesting chat.
            GetSlashCommands { .. } => Ok(SessionEvent::SlashCommandsList {
                commands: session_manager.get_slash_commands().to_vec(),
            }),
            ExecuteSlashCommand {
                command,
                args,
                channel,
                chat_id,
                current_session_id,
            } => session_manager
                .execute_slash_command(
                    &command,
                    args.as_deref(),
                    &channel,
                    &chat_id,
                    current_session_id.as_ref(),
                )
                .await
                .map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
                    new_session_id: new_id,
                    message: msg,
                })
                .map_err(to_channel_error),
        };

        // The requester may no longer be waiting; that is not fatal for the gateway, but
        // leave a trace instead of discarding the failure silently.
        if reply_tx.send(result).await.is_err() {
            tracing::debug!("Control reply dropped: requester is no longer listening");
        }
    }
}
pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
logging::init_logging();
tracing::info!("Starting PicoBot Gateway");
let state = Arc::new(GatewayState::new().await?);
// Start all channels (init already done in GatewayState::new)
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + control processor + outbound dispatcher)
state.start_message_processing().await;
// CLI args override config file values
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
let bind_port = port.unwrap_or(state.config.gateway.port);
let app = Router::new()
.route("/health", routing::get(http::health))
.route("/ws", routing::get(ws::ws_handler))
.with_state(state.clone());
let addr = format!("{}:{}", bind_host, bind_port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!(address = %addr, "Gateway listening");
// Graceful shutdown using oneshot channel
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let channel_manager = state.channel_manager.clone();
// Spawn ctrl_c handler
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Shutdown signal received");
let _ = channel_manager.stop_all().await;
let _ = shutdown_tx.send(());
});
// Serve with graceful shutdown
axum::serve(listener, app)
.with_graceful_shutdown(async {
shutdown_rx.await.ok();
})
.await?;
Ok(())
}
/// Release default AGENTS.md and USER.md templates to ~/.picobot/ if not already present.
fn ensure_default_config_files() {
let picobot_dir = dirs::home_dir().unwrap_or_default().join(".picobot");
if let Err(e) = std::fs::create_dir_all(&picobot_dir) {
tracing::warn!(dir = %picobot_dir.display(), error = %e, "Failed to create ~/.picobot directory");
return;
}
let agents_path = picobot_dir.join("AGENTS.md");
if !agents_path.exists() {
let content = include_str!("../../resources/templates/AGENTS.md");
if let Err(e) = std::fs::write(&agents_path, content) {
tracing::warn!(path = %agents_path.display(), error = %e, "Failed to write AGENTS.md template");
} else {
tracing::info!(path = %agents_path.display(), "Released default AGENTS.md template");
}
}
let user_path = picobot_dir.join("USER.md");
if !user_path.exists() {
let content = include_str!("../../resources/templates/USER.md");
if let Err(e) = std::fs::write(&user_path, content) {
tracing::warn!(path = %user_path.display(), error = %e, "Failed to write USER.md template");
} else {
tracing::info!(path = %user_path.display(), "Released default USER.md template");
}
}
let config_example_path = picobot_dir.join("config.example.json");
if !config_example_path.exists() {
let content = include_str!("../../resources/templates/config.example.json");
if let Err(e) = std::fs::write(&config_example_path, content) {
tracing::warn!(path = %config_example_path.display(), error = %e, "Failed to write config.example.json template");
} else {
tracing::info!(path = %config_example_path.display(), "Released config.example.json template");
}
}
}