PicoBot/src/gateway/processor.rs

233 lines
8.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::command::adapter::InputAdapter;
use crate::command::adapters::channel::ChannelInputAdapter;
use crate::command::handler::CommandRouter;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::session_query::SessionQueryCommandHandler;
use crate::command::Command;
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::skills::SkillPromptProvider;
use super::session::{BusToolCallEmitter, SessionManager};
/// Consumes inbound messages from the bus and processes each one on its own Tokio task.
///
/// Messages that parse as chat commands are dispatched to `command_router`; all other
/// messages are passed to `SessionManager::handle_message`. The number of messages being
/// processed at once is bounded by `semaphore`.
#[derive(Clone)]
pub struct InboundProcessor {
/// Bus that inbound messages are consumed from and replies/notifications are published to.
bus: Arc<MessageBus>,
/// Runs non-command messages (`handle_message`) and supplies the dependencies of the
/// built-in command handlers (CLI session table, store, skills).
session_manager: SessionManager,
/// Bounds concurrent message processing; each processing task holds one permit until it
/// finishes.
semaphore: Arc<Semaphore>,
/// Copy of the LLM provider configuration passed to `new`; not read anywhere in the visible
/// code of this file (hence the leading underscore).
_provider_config: LLMProviderConfig,
/// Router for parsed chat commands; its handlers are registered in `new`.
command_router: Arc<CommandRouter>,
}
impl InboundProcessor {
pub fn new(
bus: Arc<MessageBus>,
session_manager: SessionManager,
semaphore: Arc<Semaphore>,
provider_config: LLMProviderConfig,
) -> Self {
// 创建命令路由器并注册处理器
let mut command_router = CommandRouter::new();
// 注册 Session 处理器
let cli_sessions = session_manager.cli_sessions();
command_router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone())));
// 注册 session_query 处理器
command_router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions)));
// 注册 save_session 处理器
let store = session_manager.store();
let skills = session_manager.skills();
let prompt_repository = session_manager.store().clone();
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
0, // save_session 不需要 reinject 逻辑
provider_config.clone(),
prompt_repository,
)),
Box::new(SkillPromptProvider::new(skills)),
]));
command_router.register(Box::new(SaveSessionCommandHandler::new(
store,
system_prompt_provider,
)));
Self {
bus,
session_manager,
semaphore,
_provider_config: provider_config,
command_router: Arc::new(command_router),
}
}
pub async fn run(self) {
let max_concurrent = self.semaphore.available_permits();
tracing::info!(
max_concurrent_requests = max_concurrent,
"Inbound processor started"
);
loop {
// 1. 消费消息
let inbound = self.bus.consume_inbound().await;
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
sender = %inbound.sender_id,
content_len = %inbound.content.len(),
media_count = %inbound.media.len(),
"Processing inbound message"
);
}
// 2. 获取 semaphore permit控制并发
let permit = match self.semaphore.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
tracing::error!("Semaphore closed, stopping inbound processor");
break;
}
};
// 3. 克隆 processor 用于新任务
let processor = self.clone();
// 4. 独立任务处理(包含 permit任务完成自动释放
tokio::spawn(async move {
let _permit = permit; // 持有 permit 直到任务完成
if let Err(e) = processor.process_one(inbound).await {
tracing::error!(error = %e, "Message processing failed");
}
});
}
}
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
// 使用 ChannelInputAdapter 尝试解析命令
let adapter = ChannelInputAdapter::new();
let ctx = crate::command::context::AdapterContext::new(&inbound.channel)
.with_session_id(&inbound.chat_id);
if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) {
// 使用命令路由器处理
let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel, &inbound.channel)
.with_session_id(&inbound.chat_id);
// 记录是否是创建会话命令(用于后续自动切换)
let is_create_session = matches!(cmd, Command::CreateSession { .. });
let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await;
// 发送响应给用户
if response.success {
// 如果是创建会话,更新 chat_id 到新会话
let target_chat_id = if let Some(session_id) = response.metadata.get("session_id") {
if is_create_session {
// 自动切换到新会话
session_id.clone()
} else {
inbound.chat_id.clone()
}
} else {
inbound.chat_id.clone()
};
// 提取响应消息
for msg in &response.messages {
if let Err(error) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
target_chat_id.clone(),
msg.content.clone(),
None,
inbound.forwarded_metadata.clone(),
))
.await
{
tracing::error!(error = %error, "Failed to publish command response");
}
}
} else if let Some(error) = response.error {
if let Err(e) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
inbound.chat_id.clone(),
format!("Error [{}]: {}", error.code, error.message),
None,
inbound.forwarded_metadata.clone(),
))
.await
{
tracing::error!(error = %e, "Failed to publish error response");
}
}
return Ok(());
}
// 普通消息进入 AgentLoop
let live_emitter = Arc::new(BusToolCallEmitter::new(
self.bus.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
self.session_manager.show_tool_results(),
));
match self
.session_manager
.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
)
.await
{
Ok(outbound_messages) => {
for mut outbound in outbound_messages {
outbound.metadata.extend(inbound.forwarded_metadata.clone());
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, "Failed to publish outbound");
}
}
}
Err(error) => {
tracing::error!(error = %error, "Failed to handle message");
let mut metadata = inbound.forwarded_metadata.clone();
metadata.insert("error_kind".to_string(), "agent_execution".to_string());
if let Err(publish_error) = self
.bus
.publish_outbound(OutboundMessage::error_notification(
inbound.channel,
inbound.chat_id,
error.to_string(),
None,
metadata,
))
.await
{
tracing::error!(error = %publish_error, "Failed to publish execution error outbound");
}
}
}
Ok(())
}
}