feat: 实现 InChat 命令处理器和路由器,支持聊天中直接输入的命令;添加保存会话功能

This commit is contained in:
ooodc 2026-05-13 22:40:41 +08:00
parent d4c15e0478
commit 08172dcf9c
4 changed files with 278 additions and 47 deletions

View File

@ -1,6 +1,8 @@
use crate::bus::InboundMessage;
use crate::command::context::CommandContext; use crate::command::context::CommandContext;
use crate::command::response::{CommandError, CommandResponse}; use crate::command::response::{CommandError, CommandResponse};
use crate::command::Command; use crate::command::Command;
use crate::agent::AgentError;
use async_trait::async_trait; use async_trait::async_trait;
/// 命令处理器 trait /// 命令处理器 trait
@ -28,6 +30,31 @@ pub trait CommandHandler: Send + Sync {
) -> Result<CommandResponse, CommandError>; ) -> Result<CommandResponse, CommandError>;
} }
/// InChat 命令处理器 trait
///
/// 用于处理在聊天中直接输入的命令(如 Feishu/WeChat 等通道)
/// 接收 InboundMessage 而不是 CommandContext
#[async_trait]
pub trait InChatCommandHandler: Send + Sync {
/// 是否可以处理此命令
fn can_handle(&self, cmd: &Command) -> bool;
/// 执行命令
///
/// # Arguments
/// * `cmd` - 要执行的命令
/// * `inbound` - 入站消息(包含通道信息)
///
/// # Returns
/// * `Ok(())` - 命令执行成功
/// * `Err(AgentError)` - 命令执行失败
async fn handle(
&self,
cmd: Command,
inbound: &InboundMessage,
) -> Result<(), AgentError>;
}
/// 命令路由器 /// 命令路由器
/// ///
/// 负责将命令分发到合适的处理器 /// 负责将命令分发到合适的处理器
@ -102,6 +129,63 @@ impl Default for CommandRouter {
} }
} }
/// InChat 命令路由器
///
/// 负责将在聊天中输入的命令分发到合适的处理器
pub struct InChatCommandRouter {
handlers: Vec<Box<dyn InChatCommandHandler>>,
}
impl InChatCommandRouter {
/// 创建新的 InChat 命令路由器
pub fn new() -> Self {
Self {
handlers: Vec::new(),
}
}
/// 注册 InChat 命令处理器
///
/// # Arguments
/// * `handler` - 要注册的处理器
pub fn register(&mut self, handler: Box<dyn InChatCommandHandler>) {
self.handlers.push(handler);
}
/// 分发命令到合适的处理器
///
/// # Arguments
/// * `cmd` - 要执行的命令
/// * `inbound` - 入站消息
///
/// # Returns
/// * `Ok(true)` - 命令被处理
/// * `Ok(false)` - 没有合适的处理器
/// * `Err(AgentError)` - 执行失败
pub async fn dispatch(
&self,
cmd: Command,
inbound: &InboundMessage,
) -> Result<bool, AgentError> {
// 查找能处理此命令的处理器
for handler in &self.handlers {
if handler.can_handle(&cmd) {
handler.handle(cmd, inbound).await?;
return Ok(true);
}
}
// 没有找到合适的处理器
Ok(false)
}
}
impl Default for InChatCommandRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,16 +1,73 @@
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
use crate::bus::InboundMessage;
use crate::command::context::CommandContext; use crate::command::context::CommandContext;
use crate::command::handler::CommandHandler; use crate::command::handler::{CommandHandler, InChatCommandHandler};
use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command; use crate::command::Command;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider; use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider;
use crate::storage::{SessionRecord, SessionStore}; use crate::storage::{SessionRecord, SessionStore};
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{Local, TimeZone}; use chrono::{Local, TimeZone};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
/// 保存会话到文件(公共函数,可被命令处理器和其他模块复用)
///
/// # Arguments
/// * `session_id` - 会话ID
/// * `filepath` - 可选的文件路径
/// * `store` - 会话存储
/// * `provider_config` - LLM提供者配置
///
/// # Returns
/// 返回保存的文件路径
pub async fn save_session_to_file(
session_id: &str,
filepath: Option<String>,
store: &SessionStore,
provider_config: &LLMProviderConfig,
) -> Result<PathBuf, String> {
// 获取会话记录
let record = store
.get_session(session_id)
.map_err(|e| format!("Failed to get session: {}", e))?
.ok_or_else(|| "Session not found".to_string())?;
// 获取所有消息(包括历史)
let messages = store
.load_all_messages(session_id)
.map_err(|e| format!("Failed to load messages: {}", e))?;
// 计算用户消息数(用于系统提示词构建)
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
// 构建系统提示词
let system_prompt = build_system_prompt(provider_config, &record, user_message_count);
// 生成 Markdown 内容
let markdown = generate_markdown(&record, &system_prompt, &messages);
// 确定输出路径
let output_path = resolve_filepath(filepath, &record);
// 创建父目录
if let Some(parent) = output_path.parent() {
if !parent.as_os_str().is_empty() && !parent.exists() {
std::fs::create_dir_all(parent)
.map_err(|e| format!("Failed to create directory: {}", e))?;
}
}
// 写入文件
std::fs::write(&output_path, markdown)
.map_err(|e| format!("Failed to write file: {}", e))?;
Ok(output_path)
}
/// 保存会话命令处理器 /// 保存会话命令处理器
/// ///
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 /// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
@ -70,50 +127,22 @@ async fn handle_save_session(
.as_deref() .as_deref()
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?; .ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?;
// 获取会话记录 // 调用公共函数
let record = handler let output_path = save_session_to_file(
.store session_id,
.get_session(session_id) filepath,
.map_err(|e| CommandError::new("SESSION_ERROR", e.to_string()))? &*handler.store,
.ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", "Session not found".to_string()))?; &handler.provider_config,
)
.await
.map_err(|e| CommandError::new("SAVE_ERROR", e))?;
// 获取所有消息(包括历史) // 获取消息数量用于返回
let messages = handler let message_count = handler
.store .store
.load_all_messages(session_id) .load_all_messages(session_id)
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?; .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?
.len();
// 计算用户消息数(用于系统提示词构建)
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
// 构建系统提示词
let system_prompt = build_system_prompt(&handler.provider_config, &record, user_message_count);
// 生成 Markdown 内容
let markdown = generate_markdown(&record, &system_prompt, &messages);
// 确定输出路径
let output_path = resolve_filepath(filepath, &record);
// 创建父目录
if let Some(parent) = output_path.parent() {
if !parent.as_os_str().is_empty() && !parent.exists() {
std::fs::create_dir_all(parent).map_err(|e| {
CommandError::new(
"CREATE_DIR_ERROR",
format!("Failed to create directory: {}", e),
)
})?;
}
}
// 写入文件
std::fs::write(&output_path, markdown).map_err(|e| {
CommandError::new(
"WRITE_FILE_ERROR",
format!("Failed to write file: {}", e),
)
})?;
Ok(CommandResponse::success(ctx.request_id) Ok(CommandResponse::success(ctx.request_id)
.with_message( .with_message(
@ -121,7 +150,7 @@ async fn handle_save_session(
&format!("Session saved to: {}", output_path.display()), &format!("Session saved to: {}", output_path.display()),
) )
.with_metadata("filepath", output_path.to_string_lossy().as_ref()) .with_metadata("filepath", output_path.to_string_lossy().as_ref())
.with_metadata("message_count", &messages.len().to_string())) .with_metadata("message_count", &message_count.to_string()))
} }
/// 构建系统提示词 /// 构建系统提示词
@ -140,7 +169,7 @@ fn build_system_prompt(
} }
/// 生成 Markdown 内容 /// 生成 Markdown 内容
fn generate_markdown( pub fn generate_markdown(
record: &SessionRecord, record: &SessionRecord,
system_prompt: &Option<SystemPrompt>, system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage], messages: &[crate::bus::ChatMessage],
@ -273,7 +302,7 @@ fn escape_yaml_string(s: &str) -> String {
} }
/// 格式化时间戳 /// 格式化时间戳
fn format_timestamp(ts: i64) -> String { pub fn format_timestamp(ts: i64) -> String {
Local Local
.timestamp_millis_opt(ts) .timestamp_millis_opt(ts)
.single() .single()
@ -284,7 +313,7 @@ fn format_timestamp(ts: i64) -> String {
/// 解析文件路径 /// 解析文件路径
/// ///
/// 如果未提供路径,自动生成基于会话标题和时间戳的文件名 /// 如果未提供路径,自动生成基于会话标题和时间戳的文件名
fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf { pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
match filepath { match filepath {
Some(path) => PathBuf::from(path), Some(path) => PathBuf::from(path),
None => { None => {
@ -318,6 +347,66 @@ fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf
} }
} }
/// InChat 保存会话命令处理器
///
/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令
pub struct SaveSessionInChatHandler {
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
}
impl SaveSessionInChatHandler {
/// 创建新的 InChat 保存会话命令处理器
pub fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
Self {
store,
provider_config,
}
}
}
#[async_trait]
impl InChatCommandHandler for SaveSessionInChatHandler {
fn can_handle(&self, cmd: &Command) -> bool {
matches!(cmd, Command::SaveSession { .. })
}
async fn handle(
&self,
cmd: Command,
inbound: &InboundMessage,
) -> Result<(), AgentError> {
let Command::SaveSession { filepath } = cmd else {
return Ok(());
};
// 使用聊天ID作为会话ID
let session_id = &inbound.chat_id;
// 调用公共函数
let result = save_session_to_file(
session_id,
filepath,
&*self.store,
&self.provider_config,
)
.await;
// 结果通过返回 Ok(()) 表示成功
// 实际输出由调用者通过消息总线发送
match result {
Ok(output_path) => {
tracing::info!("Session saved to: {}", output_path.display());
Ok(())
}
Err(error) => {
tracing::error!("Failed to save session: {}", error);
Ok(())
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -99,8 +99,15 @@ impl GatewayState {
let semaphore = Arc::new(Semaphore::new(max_concurrent)); let semaphore = Arc::new(Semaphore::new(max_concurrent));
// Spawn inbound processor with semaphore-controlled concurrency // Spawn inbound processor with semaphore-controlled concurrency
let provider_config = match self.config.get_provider_config("default") {
Ok(config) => config,
Err(e) => {
tracing::error!(error = %e, "Failed to get provider config");
return;
}
};
let inbound_processor = let inbound_processor =
InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore); InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore, provider_config);
tokio::spawn(inbound_processor.run()); tokio::spawn(inbound_processor.run());
// Spawn outbound dispatcher // Spawn outbound dispatcher

View File

@ -4,6 +4,9 @@ use tokio::sync::Semaphore;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::bus::{InboundMessage, MessageBus, OutboundMessage}; use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::command::handler::{InChatCommandHandler, InChatCommandRouter};
use crate::command::Command;
use crate::config::LLMProviderConfig;
use super::session::{BusToolCallEmitter, SessionManager}; use super::session::{BusToolCallEmitter, SessionManager};
@ -12,6 +15,8 @@ pub struct InboundProcessor {
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
session_manager: SessionManager, session_manager: SessionManager,
semaphore: Arc<Semaphore>, semaphore: Arc<Semaphore>,
provider_config: LLMProviderConfig,
command_router: Arc<InChatCommandRouter>,
} }
impl InboundProcessor { impl InboundProcessor {
@ -19,11 +24,24 @@ impl InboundProcessor {
bus: Arc<MessageBus>, bus: Arc<MessageBus>,
session_manager: SessionManager, session_manager: SessionManager,
semaphore: Arc<Semaphore>, semaphore: Arc<Semaphore>,
provider_config: LLMProviderConfig,
) -> Self { ) -> Self {
// 创建命令路由器并注册处理器
let mut command_router = InChatCommandRouter::new();
// 注册 save_session 处理器
let store = session_manager.store();
command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new(
store,
provider_config.clone(),
)));
Self { Self {
bus, bus,
session_manager, session_manager,
semaphore, semaphore,
provider_config,
command_router: Arc::new(command_router),
} }
} }
@ -73,6 +91,16 @@ impl InboundProcessor {
} }
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> { async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
// 尝试解析为命令
if let Some(cmd) = parse_in_chat_command(&inbound.content) {
// 使用命令路由器处理
let handled = self.command_router.dispatch(cmd, &inbound).await?;
if handled {
return Ok(());
}
}
// 普通消息进入 AgentLoop
let live_emitter = Arc::new(BusToolCallEmitter::new( let live_emitter = Arc::new(BusToolCallEmitter::new(
self.bus.clone(), self.bus.clone(),
inbound.channel.clone(), inbound.channel.clone(),
@ -124,3 +152,26 @@ impl InboundProcessor {
Ok(()) Ok(())
} }
} }
/// 解析聊天中的命令
///
/// 支持格式:
/// - `/save [filepath]` - 保存会话
///
/// 返回 Some(Command) 如果是命令
/// 返回 None 如果不是命令
fn parse_in_chat_command(content: &str) -> Option<Command> {
let trimmed = content.trim();
if trimmed.starts_with("/save") {
let path = trimmed[5..].trim();
let filepath = if path.is_empty() {
None
} else {
Some(path.to_string())
};
Some(Command::SaveSession { filepath })
} else {
None
}
}