PicoBot/src/gateway/processor.rs

304 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::command::adapter::InputAdapter;
use crate::command::adapters::channel::ChannelInputAdapter;
use crate::command::handler::CommandRouter;
use crate::command::handlers::get_current::GetCurrentSessionCommandHandler;
use crate::command::handlers::help::HelpCommandHandler;
use crate::command::handlers::list_sessions::ListSessionsCommandHandler;
use crate::command::handlers::load_session::LoadSessionCommandHandler;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::save_topic::SaveTopicCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::switch_session::SwitchSessionCommandHandler;
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::providers::{create_provider, ProviderRuntimeConfig};
use crate::skills::SkillPromptProvider;
use crate::storage::persistent_session_id;
use crate::topic_description::generate_topic_description;
use super::session::{BusToolCallEmitter, SessionManager};
/// Consumes inbound messages from the bus and processes each one on its own task.
///
/// The type is `Clone` because `run` hands a copy of the processor to every task it spawns.
#[derive(Clone)]
pub struct InboundProcessor {
/// Message bus: inbound messages are consumed from it and outbound ones are published to it.
bus: Arc<MessageBus>,
/// Handles non-command messages and gives access to the store, skills and task repository.
session_manager: SessionManager,
/// Limits how many messages are processed at once; each spawned task holds one permit.
semaphore: Arc<Semaphore>,
/// LLM provider settings, used for the prompt provider and for topic-description generation.
provider_config: LLMProviderConfig,
/// Router holding the command handlers registered in `new`.
command_router: Arc<CommandRouter>,
}
impl InboundProcessor {
/// Builds a processor and wires up its command router.
///
/// Handlers are registered in this order: session, list_sessions, switch_session,
/// get_current, load_session, save_session, save_topic and finally help. The help
/// handler goes last so that it receives the metadata of every other command.
pub fn new(
    bus: Arc<MessageBus>,
    session_manager: SessionManager,
    semaphore: Arc<Semaphore>,
    provider_config: LLMProviderConfig,
) -> Self {
    let mut router = CommandRouter::new();
    let store = session_manager.store();

    // Session handler.
    router.register(Box::new(
        SessionCommandHandler::new(store.clone()).with_session_manager(session_manager.clone()),
    ));

    // list_sessions handler.
    let list_handler = ListSessionsCommandHandler::new(store.clone());
    router.register(Box::new(list_handler));

    // switch_session handler.
    router.register(Box::new(
        SwitchSessionCommandHandler::new(store.clone())
            .with_session_manager(session_manager.clone()),
    ));

    // System prompt provider shared by get_current, save_session and save_topic.
    let skills = session_manager.skills();
    let prompt_repository = session_manager.store().clone();
    // The leading 0 is passed because the reinject logic is not needed here.
    let agent_prompts = AgentPromptProvider::new(0, provider_config.clone(), prompt_repository);
    let skill_prompts = SkillPromptProvider::new(skills);
    let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> =
        Arc::new(CompositeSystemPromptProvider::new(vec![
            Box::new(agent_prompts),
            Box::new(skill_prompts),
        ]));

    // get_current handler.
    let current_handler = GetCurrentSessionCommandHandler::new(store.clone())
        .with_session_manager(session_manager.clone())
        .with_system_prompt_provider(Arc::clone(&system_prompt_provider));
    router.register(Box::new(current_handler));

    // load_session handler.
    let load_handler = LoadSessionCommandHandler::new(store.clone());
    router.register(Box::new(load_handler));

    // save_session handler.
    let save_session_handler = SaveSessionCommandHandler::new(
        store.clone(),
        session_manager.task_repository(),
        Arc::clone(&system_prompt_provider),
    );
    router.register(Box::new(save_session_handler));

    // save_topic handler; it takes ownership of the last provider handle.
    let save_topic_handler = SaveTopicCommandHandler::new(
        store.clone(),
        session_manager.task_repository(),
        system_prompt_provider,
    )
    .with_session_manager(session_manager.clone());
    router.register(Box::new(save_topic_handler));

    // help handler: registered last so it sees the metadata of all commands above.
    let metadata = router.metadata_arc();
    router.register(Box::new(HelpCommandHandler::new(metadata)));

    Self {
        bus,
        session_manager,
        semaphore,
        provider_config,
        command_router: Arc::new(router),
    }
}
pub async fn run(self) {
let max_concurrent = self.semaphore.available_permits();
tracing::info!(
max_concurrent_requests = max_concurrent,
"Inbound processor started"
);
loop {
// 1. 消费消息
let inbound = self.bus.consume_inbound().await;
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
sender = %inbound.sender_id,
content_len = %inbound.content.len(),
media_count = %inbound.media.len(),
"Processing inbound message"
);
}
// 2. 获取 semaphore permit控制并发
let permit = match self.semaphore.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
tracing::error!("Semaphore closed, stopping inbound processor");
break;
}
};
// 3. 克隆 processor 用于新任务
let processor = self.clone();
// 4. 独立任务处理(包含 permit任务完成自动释放
tokio::spawn(async move {
let _permit = permit; // 持有 permit 直到任务完成
if let Err(e) = processor.process_one(inbound).await {
tracing::error!(error = %e, "Message processing failed");
}
});
}
}
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
// 计算正确的 session_id根据 channel_name 和 chat_id
let session_id = persistent_session_id(&inbound.channel, &inbound.chat_id);
// 获取当前话题(封装了 session 创建逻辑)
let current_topic = self.session_manager
.get_current_topic(&inbound.channel, &inbound.chat_id)
.await?;
// 使用 ChannelInputAdapter 尝试解析命令
let adapter = ChannelInputAdapter::new();
let ctx = crate::command::context::AdapterContext::new(&inbound.channel)
.with_session_id(&session_id);
if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) {
// 使用命令路由器处理
let mut cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel, &inbound.channel)
.with_session_id(&session_id)
.with_chat_id(&inbound.chat_id);
// 只在有话题时才设置 topic_id
if let Some(ref topic_id) = current_topic {
cmd_ctx = cmd_ctx.with_topic_id(topic_id.as_str());
}
let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await;
// 发送响应给用户
if response.success {
// 提取响应消息
// chat_id 保持为 inbound.chat_id飞书 open_id
// session_id 放入 metadata 用于会话管理
for msg in &response.messages {
if let Err(error) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
inbound.chat_id.clone(),
response.metadata.get("session_id").cloned(),
msg.content.clone(),
None,
inbound.forwarded_metadata.clone(),
))
.await
{
tracing::error!(error = %error, "Failed to publish command response");
}
}
} else if let Some(error) = response.error {
if let Err(e) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
inbound.chat_id.clone(),
response.metadata.get("session_id").cloned(),
format!("Error [{}]: {}", error.code, error.message),
None,
inbound.forwarded_metadata.clone(),
))
.await
{
tracing::error!(error = %e, "Failed to publish error response");
}
}
return Ok(());
}
// 普通消息进入 AgentLoop
let live_emitter = Arc::new(BusToolCallEmitter::new(
self.bus.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
self.session_manager.show_tool_results(),
));
match self
.session_manager
.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
)
.await
{
Ok(outbound_messages) => {
for mut outbound in outbound_messages {
outbound.metadata.extend(inbound.forwarded_metadata.clone());
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, "Failed to publish outbound");
}
}
// 异步生成 topic 描述(仅第一条消息后触发一次)
if let Some(ref topic_id) = current_topic {
let store = self.session_manager.store();
if let Ok(Some(topic)) = store.get_topic(topic_id) {
if topic.description.is_none() || topic.description.as_ref().map(|d| d.is_empty()).unwrap_or(true) {
let provider_config = self.provider_config.clone();
let topic_id_clone = topic_id.clone();
let first_message = inbound.content.clone();
let store_clone = store.clone();
tokio::spawn(async move {
let runtime_config: ProviderRuntimeConfig = provider_config.into();
if let Ok(provider) = create_provider(runtime_config) {
match generate_topic_description(provider.as_ref(), &first_message).await {
Ok(description) => {
if let Err(e) = store_clone.update_topic_description(&topic_id_clone, &description) {
tracing::error!(error = %e, topic_id = %topic_id_clone, "Failed to update topic description");
} else {
tracing::info!(topic_id = %topic_id_clone, description = %description, "Topic description generated");
}
}
Err(e) => {
tracing::error!(error = %e, topic_id = %topic_id_clone, "Failed to generate topic description");
}
}
}
});
}
}
}
}
Err(error) => {
tracing::error!(error = %error, "Failed to handle message");
let mut metadata = inbound.forwarded_metadata.clone();
metadata.insert("error_kind".to_string(), "agent_execution".to_string());
if let Err(publish_error) = self
.bus
.publish_outbound(OutboundMessage::error_notification(
inbound.channel,
inbound.chat_id,
None, // session_id
error.to_string(),
None,
metadata,
))
.await
{
tracing::error!(error = %publish_error, "Failed to publish execution error outbound");
}
}
}
Ok(())
}
}