475 lines
19 KiB
Rust
475 lines
19 KiB
Rust
pub mod http;
|
|
pub mod ws;
|
|
|
|
use axum::{Router, routing};
|
|
use std::sync::Arc;
|
|
use tokio::net::TcpListener;
|
|
|
|
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
|
use crate::channels::base::{Channel, ChannelError};
|
|
use crate::channels::{ChannelManager, CliChatChannel};
|
|
use crate::config::{Config, ensure_workspace_dir, expand_path};
|
|
use crate::logging;
|
|
use crate::mcp;
|
|
use crate::memory::MemoryManager;
|
|
use crate::scheduler::Scheduler;
|
|
use crate::session::SessionManager;
|
|
|
|
pub struct GatewayState {
|
|
pub config: Config,
|
|
pub workspace_dir: std::path::PathBuf,
|
|
pub session_manager: Arc<SessionManager>,
|
|
pub channel_manager: ChannelManager,
|
|
pub storage: Arc<crate::storage::Storage>,
|
|
}
|
|
|
|
impl GatewayState {
|
|
pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
|
let config = Config::load_default()?;
|
|
|
|
// Initialize workspace directory: expand path and ensure it exists
|
|
let workspace_path = expand_path(&config.workspace_dir);
|
|
let workspace_path = ensure_workspace_dir(&workspace_path)?;
|
|
|
|
// Switch current working directory to workspace
|
|
std::env::set_current_dir(&workspace_path).map_err(|e| {
|
|
format!(
|
|
"Failed to switch to workspace directory {}: {}",
|
|
workspace_path.display(),
|
|
e
|
|
)
|
|
})?;
|
|
|
|
tracing::info!("Using workspace directory: {}", workspace_path.display());
|
|
|
|
// Release default AGENTS.md and USER.md to ~/.picobot/ if not exist
|
|
ensure_default_config_files();
|
|
|
|
// Get provider config for SessionManager
|
|
let mut provider_config = config.get_provider_config("default")?;
|
|
// Override workspace_dir with the ensured path
|
|
provider_config.workspace_dir = workspace_path.clone();
|
|
|
|
// Initialize Storage
|
|
let db_path = if let Some(ref path) = config.gateway.session_db_path {
|
|
std::path::PathBuf::from(path)
|
|
} else {
|
|
workspace_path.join("picobot.db")
|
|
};
|
|
let storage = Arc::new(
|
|
crate::storage::Storage::new(&db_path)
|
|
.await
|
|
.map_err(|e| format!("failed to initialize session storage: {}", e))?,
|
|
);
|
|
tracing::info!("Session storage: {}", db_path.display());
|
|
|
|
// Resolve consolidation provider/model with fallback to main agent config
|
|
let consolidation_provider = config
|
|
.memory
|
|
.resolve_consolidation_provider(&provider_config.name);
|
|
let consolidation_model = config
|
|
.memory
|
|
.resolve_consolidation_model(&provider_config.model_id);
|
|
let memory_manager = Arc::new(MemoryManager::new(
|
|
storage.clone(),
|
|
consolidation_provider,
|
|
consolidation_model,
|
|
));
|
|
tracing::info!(
|
|
consolidation_provider = %memory_manager.consolidation_provider,
|
|
consolidation_model = %memory_manager.consolidation_model,
|
|
"Memory system initialized"
|
|
);
|
|
|
|
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
|
let bus = MessageBus::new(100);
|
|
|
|
let browser_config = if config.browser.enabled {
|
|
Some(config.browser.clone())
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Create SessionManager with bus injection
|
|
let session_manager = SessionManager::new(
|
|
provider_config.clone(),
|
|
storage.clone(),
|
|
bus.clone(),
|
|
memory_manager,
|
|
browser_config,
|
|
config.gateway.max_concurrent_background_tasks,
|
|
)?;
|
|
let session_manager = Arc::new(session_manager);
|
|
|
|
// Create ChannelManager and init channels
|
|
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
|
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
|
channel_manager
|
|
.init(&config, workspace_path.clone())
|
|
.await
|
|
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
|
|
|
// Register send_message tool with available channel names
|
|
let available_channels = channel_manager.list_channel_names().await;
|
|
let valid_channels = available_channels.clone();
|
|
session_manager.register_outbound_tool(available_channels);
|
|
|
|
// Register chat_manager tool
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::ChatManagerTool::new(
|
|
storage.clone(),
|
|
valid_channels.clone(),
|
|
));
|
|
|
|
// Initialize MCP servers — connect and register discovered tools
|
|
if !config.mcp.servers.is_empty() {
|
|
let mcp_tools = mcp::connect_all(&config.mcp).await;
|
|
for tool_info in mcp_tools {
|
|
let wrapper = mcp::McpToolWrapper::new(
|
|
&tool_info.server_name,
|
|
tool_info.tool_name,
|
|
tool_info.description,
|
|
tool_info.schema,
|
|
tool_info.connection,
|
|
);
|
|
session_manager.tools().register(wrapper);
|
|
}
|
|
}
|
|
|
|
// Initialize scheduler if enabled in config
|
|
let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default();
|
|
if scheduler_config.enabled {
|
|
// Register cron tools
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::cron::CronAddTool::new(
|
|
storage.clone(),
|
|
valid_channels,
|
|
));
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::cron::CronListTool::new(storage.clone()));
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::cron::CronRemoveTool::new(storage.clone()));
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::cron::CronEnableTool::new(storage.clone()));
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::cron::CronDisableTool::new(storage.clone()));
|
|
session_manager
|
|
.tools()
|
|
.register(crate::tools::cron::CronUpdateTool::new(storage.clone()));
|
|
tracing::info!("Cron tools registered");
|
|
}
|
|
|
|
Ok(Self {
|
|
config,
|
|
workspace_dir: workspace_path,
|
|
session_manager: session_manager.clone(),
|
|
channel_manager,
|
|
storage,
|
|
})
|
|
}
|
|
|
|
/// Get a reference to the MessageBus
|
|
pub fn bus(&self) -> Arc<crate::bus::MessageBus> {
|
|
self.channel_manager.bus()
|
|
}
|
|
|
|
/// Get CLI chat channel for WebSocket handling
|
|
pub fn cli_chat_channel(&self) -> Arc<CliChatChannel> {
|
|
self.channel_manager.cli_chat_channel()
|
|
}
|
|
|
|
/// Start the message processing loops
|
|
pub async fn start_message_processing(&self) {
|
|
let bus = self.bus();
|
|
let bus_for_outbound = bus.clone();
|
|
let session_manager = self.session_manager.clone();
|
|
|
|
// Start CLI Chat Channel (it's already registered in ChannelManager)
|
|
let cli_chat_channel = self.cli_chat_channel();
|
|
if let Err(e) = cli_chat_channel.start(bus.clone()).await {
|
|
tracing::error!(error = %e, "Failed to start CLI chat channel");
|
|
}
|
|
|
|
// Spawn unified message processor
|
|
// This handles both inbound AI messages and control messages in one loop
|
|
tokio::spawn(async move {
|
|
tracing::info!("Message processor started");
|
|
|
|
loop {
|
|
tokio::select! {
|
|
// Inbound: AI message flow
|
|
inbound = bus.consume_inbound() => {
|
|
let Some(inbound) = inbound else {
|
|
tracing::warn!("Message processor stopping because inbound bus closed");
|
|
break;
|
|
};
|
|
match session_manager.handle_message(
|
|
&inbound.channel,
|
|
&inbound.sender_id,
|
|
&inbound.chat_id,
|
|
&inbound.content,
|
|
inbound.media,
|
|
).await {
|
|
Ok(crate::session::session::HandleResult::AgentResponse(content)) => {
|
|
let outbound = crate::bus::OutboundMessage {
|
|
channel: inbound.channel.clone(),
|
|
chat_id: inbound.chat_id.clone(),
|
|
content,
|
|
reply_to: None,
|
|
media: vec![],
|
|
metadata: inbound.forwarded_metadata,
|
|
};
|
|
if let Err(e) = bus.publish_outbound(outbound).await {
|
|
tracing::error!(error = %e, "Failed to publish outbound");
|
|
}
|
|
}
|
|
Ok(crate::session::session::HandleResult::CommandOutput(content)) => {
|
|
let outbound = crate::bus::OutboundMessage {
|
|
channel: inbound.channel.clone(),
|
|
chat_id: inbound.chat_id.clone(),
|
|
content,
|
|
reply_to: None,
|
|
media: vec![],
|
|
metadata: inbound.forwarded_metadata,
|
|
};
|
|
if let Err(e) = bus.publish_outbound(outbound).await {
|
|
tracing::error!(error = %e, "Failed to publish outbound");
|
|
}
|
|
}
|
|
Ok(crate::session::session::HandleResult::AgentProcessing) => {
|
|
// Agent is processing in background; response will be
|
|
// sent via bus directly from the spawned task.
|
|
// The select loop remains free to handle subsequent
|
|
// messages (including slash commands).
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to handle message");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Control: session management operations
|
|
msg = bus.consume_control() => {
|
|
let Some(msg) = msg else {
|
|
tracing::warn!("Message processor stopping because control bus closed");
|
|
break;
|
|
};
|
|
Self::handle_control_message(&session_manager, msg).await;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Spawn outbound dispatcher
|
|
let dispatcher = OutboundDispatcher::new(bus_for_outbound, self.channel_manager.clone());
|
|
|
|
tokio::spawn(async move {
|
|
tracing::info!("Outbound dispatcher started");
|
|
dispatcher.run().await;
|
|
});
|
|
|
|
// Spawn scheduler background task if enabled
|
|
let scheduler_config = self.config.gateway.scheduler.clone().unwrap_or_default();
|
|
if scheduler_config.enabled {
|
|
let sched = Arc::new(Scheduler::new(
|
|
self.storage.clone(),
|
|
self.session_manager.clone(),
|
|
scheduler_config,
|
|
));
|
|
tokio::spawn(async move {
|
|
sched.run().await;
|
|
});
|
|
tracing::info!("Scheduler background task spawned");
|
|
}
|
|
}
|
|
|
|
/// Handle control messages (session management operations)
|
|
async fn handle_control_message(session_manager: &SessionManager, msg: ControlMessage) {
|
|
use crate::session::{SessionCommand::*, SessionEvent};
|
|
|
|
let reply_tx = msg.reply_tx;
|
|
let result: Result<SessionEvent, ChannelError> = match msg.op {
|
|
CreateDialog {
|
|
channel,
|
|
chat_id,
|
|
title,
|
|
} => session_manager
|
|
.create_dialog(&channel, &chat_id, title.as_deref())
|
|
.await
|
|
.map(|(session_id, title)| SessionEvent::DialogCreated { session_id, title })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
ListDialogs {
|
|
channel,
|
|
chat_id,
|
|
include_archived,
|
|
} => session_manager
|
|
.list_dialogs(&channel, &chat_id, include_archived)
|
|
.await
|
|
.map(|(dialogs, current_dialog_id)| SessionEvent::DialogList {
|
|
dialogs,
|
|
current_dialog_id,
|
|
})
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
GetCurrentDialog { channel, chat_id } => session_manager
|
|
.get_current_dialog(&channel, &chat_id)
|
|
.await
|
|
.map(|session_id| SessionEvent::CurrentDialog { session_id })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
SwitchDialog {
|
|
channel,
|
|
chat_id,
|
|
dialog_id,
|
|
} => session_manager
|
|
.switch_dialog(&channel, &chat_id, &dialog_id)
|
|
.await
|
|
.map(|session_id| SessionEvent::DialogSwitched { session_id })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
RenameDialog { session_id, title } => session_manager
|
|
.rename_dialog(&session_id, &title)
|
|
.await
|
|
.map(|()| SessionEvent::DialogRenamed { session_id, title })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
ArchiveDialog { session_id } => session_manager
|
|
.archive_dialog(&session_id)
|
|
.await
|
|
.map(|()| SessionEvent::DialogArchived { session_id })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
DeleteDialog { session_id } => session_manager
|
|
.delete_dialog(&session_id)
|
|
.await
|
|
.map(|()| SessionEvent::DialogDeleted { session_id })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
ClearHistory { session_id } => session_manager
|
|
.clear_dialog_history(&session_id)
|
|
.await
|
|
.map(|()| SessionEvent::HistoryCleared { session_id })
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
GetSlashCommands {
|
|
channel: _,
|
|
chat_id: _,
|
|
} => {
|
|
let commands = session_manager.get_slash_commands().to_vec();
|
|
Ok(SessionEvent::SlashCommandsList { commands })
|
|
}
|
|
ExecuteSlashCommand {
|
|
command,
|
|
args,
|
|
channel,
|
|
chat_id,
|
|
current_session_id,
|
|
} => session_manager
|
|
.execute_slash_command(
|
|
&command,
|
|
args.as_deref(),
|
|
&channel,
|
|
&chat_id,
|
|
current_session_id.as_ref(),
|
|
)
|
|
.await
|
|
.map(|(new_id, msg)| SessionEvent::SlashCommandExecuted {
|
|
new_session_id: new_id,
|
|
message: msg,
|
|
})
|
|
.map_err(|e| ChannelError::Other(e.to_string())),
|
|
};
|
|
|
|
let _ = reply_tx.send(result).await;
|
|
}
|
|
}
|
|
|
|
pub async fn run(
|
|
host: Option<String>,
|
|
port: Option<u16>,
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|
// Initialize logging
|
|
logging::init_logging();
|
|
tracing::info!("Starting PicoBot Gateway");
|
|
|
|
let state = Arc::new(GatewayState::new().await?);
|
|
|
|
// Start all channels (init already done in GatewayState::new)
|
|
state.channel_manager.start_all().await?;
|
|
|
|
// Start message processing (inbound processor + control processor + outbound dispatcher)
|
|
state.start_message_processing().await;
|
|
|
|
// CLI args override config file values
|
|
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
|
let bind_port = port.unwrap_or(state.config.gateway.port);
|
|
|
|
let app = Router::new()
|
|
.route("/health", routing::get(http::health))
|
|
.route("/ws", routing::get(ws::ws_handler))
|
|
.with_state(state.clone());
|
|
|
|
let addr = format!("{}:{}", bind_host, bind_port);
|
|
let listener = TcpListener::bind(&addr).await?;
|
|
tracing::info!(address = %addr, "Gateway listening");
|
|
|
|
// Graceful shutdown using oneshot channel
|
|
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
|
let channel_manager = state.channel_manager.clone();
|
|
|
|
// Spawn ctrl_c handler
|
|
tokio::spawn(async move {
|
|
tokio::signal::ctrl_c().await.ok();
|
|
tracing::info!("Shutdown signal received");
|
|
let _ = channel_manager.stop_all().await;
|
|
let _ = shutdown_tx.send(());
|
|
});
|
|
|
|
// Serve with graceful shutdown
|
|
axum::serve(listener, app)
|
|
.with_graceful_shutdown(async {
|
|
shutdown_rx.await.ok();
|
|
})
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Release default AGENTS.md and USER.md templates to ~/.picobot/ if not already present.
|
|
fn ensure_default_config_files() {
|
|
let picobot_dir = dirs::home_dir().unwrap_or_default().join(".picobot");
|
|
if let Err(e) = std::fs::create_dir_all(&picobot_dir) {
|
|
tracing::warn!(dir = %picobot_dir.display(), error = %e, "Failed to create ~/.picobot directory");
|
|
return;
|
|
}
|
|
|
|
let agents_path = picobot_dir.join("AGENTS.md");
|
|
if !agents_path.exists() {
|
|
let content = include_str!("../../resources/templates/AGENTS.md");
|
|
if let Err(e) = std::fs::write(&agents_path, content) {
|
|
tracing::warn!(path = %agents_path.display(), error = %e, "Failed to write AGENTS.md template");
|
|
} else {
|
|
tracing::info!(path = %agents_path.display(), "Released default AGENTS.md template");
|
|
}
|
|
}
|
|
|
|
let user_path = picobot_dir.join("USER.md");
|
|
if !user_path.exists() {
|
|
let content = include_str!("../../resources/templates/USER.md");
|
|
if let Err(e) = std::fs::write(&user_path, content) {
|
|
tracing::warn!(path = %user_path.display(), error = %e, "Failed to write USER.md template");
|
|
} else {
|
|
tracing::info!(path = %user_path.display(), "Released default USER.md template");
|
|
}
|
|
}
|
|
|
|
let config_example_path = picobot_dir.join("config.example.json");
|
|
if !config_example_path.exists() {
|
|
let content = include_str!("../../resources/templates/config.example.json");
|
|
if let Err(e) = std::fs::write(&config_example_path, content) {
|
|
tracing::warn!(path = %config_example_path.display(), error = %e, "Failed to write config.example.json template");
|
|
} else {
|
|
tracing::info!(path = %config_example_path.display(), "Released config.example.json template");
|
|
}
|
|
}
|
|
}
|