feat: 实现 InChat 命令处理器和路由器,支持聊天中直接输入的命令;添加保存会话功能
This commit is contained in:
parent
d4c15e0478
commit
08172dcf9c
@ -1,6 +1,8 @@
|
||||
use crate::bus::InboundMessage;
|
||||
use crate::command::context::CommandContext;
|
||||
use crate::command::response::{CommandError, CommandResponse};
|
||||
use crate::command::Command;
|
||||
use crate::agent::AgentError;
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// 命令处理器 trait
|
||||
@ -28,6 +30,31 @@ pub trait CommandHandler: Send + Sync {
|
||||
) -> Result<CommandResponse, CommandError>;
|
||||
}
|
||||
|
||||
/// InChat 命令处理器 trait
|
||||
///
|
||||
/// 用于处理在聊天中直接输入的命令(如 Feishu/WeChat 等通道)
|
||||
/// 接收 InboundMessage 而不是 CommandContext
|
||||
#[async_trait]
|
||||
pub trait InChatCommandHandler: Send + Sync {
|
||||
/// 是否可以处理此命令
|
||||
fn can_handle(&self, cmd: &Command) -> bool;
|
||||
|
||||
/// 执行命令
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cmd` - 要执行的命令
|
||||
/// * `inbound` - 入站消息(包含通道信息)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(())` - 命令执行成功
|
||||
/// * `Err(AgentError)` - 命令执行失败
|
||||
async fn handle(
|
||||
&self,
|
||||
cmd: Command,
|
||||
inbound: &InboundMessage,
|
||||
) -> Result<(), AgentError>;
|
||||
}
|
||||
|
||||
/// 命令路由器
|
||||
///
|
||||
/// 负责将命令分发到合适的处理器
|
||||
@ -102,6 +129,63 @@ impl Default for CommandRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// InChat 命令路由器
|
||||
///
|
||||
/// 负责将在聊天中输入的命令分发到合适的处理器
|
||||
pub struct InChatCommandRouter {
|
||||
handlers: Vec<Box<dyn InChatCommandHandler>>,
|
||||
}
|
||||
|
||||
impl InChatCommandRouter {
|
||||
/// 创建新的 InChat 命令路由器
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册 InChat 命令处理器
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handler` - 要注册的处理器
|
||||
pub fn register(&mut self, handler: Box<dyn InChatCommandHandler>) {
|
||||
self.handlers.push(handler);
|
||||
}
|
||||
|
||||
/// 分发命令到合适的处理器
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cmd` - 要执行的命令
|
||||
/// * `inbound` - 入站消息
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(true)` - 命令被处理
|
||||
/// * `Ok(false)` - 没有合适的处理器
|
||||
/// * `Err(AgentError)` - 执行失败
|
||||
pub async fn dispatch(
|
||||
&self,
|
||||
cmd: Command,
|
||||
inbound: &InboundMessage,
|
||||
) -> Result<bool, AgentError> {
|
||||
// 查找能处理此命令的处理器
|
||||
for handler in &self.handlers {
|
||||
if handler.can_handle(&cmd) {
|
||||
handler.handle(cmd, inbound).await?;
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
// 没有找到合适的处理器
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InChatCommandRouter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -1,16 +1,73 @@
|
||||
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||||
use crate::bus::InboundMessage;
|
||||
use crate::command::context::CommandContext;
|
||||
use crate::command::handler::CommandHandler;
|
||||
use crate::command::handler::{CommandHandler, InChatCommandHandler};
|
||||
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||||
use crate::command::Command;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider;
|
||||
use crate::storage::{SessionRecord, SessionStore};
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::OutboundMessage;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Local, TimeZone};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// 保存会话到文件(公共函数,可被命令处理器和其他模块复用)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `session_id` - 会话ID
|
||||
/// * `filepath` - 可选的文件路径
|
||||
/// * `store` - 会话存储
|
||||
/// * `provider_config` - LLM提供者配置
|
||||
///
|
||||
/// # Returns
|
||||
/// 返回保存的文件路径
|
||||
pub async fn save_session_to_file(
|
||||
session_id: &str,
|
||||
filepath: Option<String>,
|
||||
store: &SessionStore,
|
||||
provider_config: &LLMProviderConfig,
|
||||
) -> Result<PathBuf, String> {
|
||||
// 获取会话记录
|
||||
let record = store
|
||||
.get_session(session_id)
|
||||
.map_err(|e| format!("Failed to get session: {}", e))?
|
||||
.ok_or_else(|| "Session not found".to_string())?;
|
||||
|
||||
// 获取所有消息(包括历史)
|
||||
let messages = store
|
||||
.load_all_messages(session_id)
|
||||
.map_err(|e| format!("Failed to load messages: {}", e))?;
|
||||
|
||||
// 计算用户消息数(用于系统提示词构建)
|
||||
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||||
|
||||
// 构建系统提示词
|
||||
let system_prompt = build_system_prompt(provider_config, &record, user_message_count);
|
||||
|
||||
// 生成 Markdown 内容
|
||||
let markdown = generate_markdown(&record, &system_prompt, &messages);
|
||||
|
||||
// 确定输出路径
|
||||
let output_path = resolve_filepath(filepath, &record);
|
||||
|
||||
// 创建父目录
|
||||
if let Some(parent) = output_path.parent() {
|
||||
if !parent.as_os_str().is_empty() && !parent.exists() {
|
||||
std::fs::create_dir_all(parent)
|
||||
.map_err(|e| format!("Failed to create directory: {}", e))?;
|
||||
}
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
std::fs::write(&output_path, markdown)
|
||||
.map_err(|e| format!("Failed to write file: {}", e))?;
|
||||
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
/// 保存会话命令处理器
|
||||
///
|
||||
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
|
||||
@ -70,50 +127,22 @@ async fn handle_save_session(
|
||||
.as_deref()
|
||||
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?;
|
||||
|
||||
// 获取会话记录
|
||||
let record = handler
|
||||
.store
|
||||
.get_session(session_id)
|
||||
.map_err(|e| CommandError::new("SESSION_ERROR", e.to_string()))?
|
||||
.ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", "Session not found".to_string()))?;
|
||||
// 调用公共函数
|
||||
let output_path = save_session_to_file(
|
||||
session_id,
|
||||
filepath,
|
||||
&*handler.store,
|
||||
&handler.provider_config,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| CommandError::new("SAVE_ERROR", e))?;
|
||||
|
||||
// 获取所有消息(包括历史)
|
||||
let messages = handler
|
||||
// 获取消息数量用于返回
|
||||
let message_count = handler
|
||||
.store
|
||||
.load_all_messages(session_id)
|
||||
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?;
|
||||
|
||||
// 计算用户消息数(用于系统提示词构建)
|
||||
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||||
|
||||
// 构建系统提示词
|
||||
let system_prompt = build_system_prompt(&handler.provider_config, &record, user_message_count);
|
||||
|
||||
// 生成 Markdown 内容
|
||||
let markdown = generate_markdown(&record, &system_prompt, &messages);
|
||||
|
||||
// 确定输出路径
|
||||
let output_path = resolve_filepath(filepath, &record);
|
||||
|
||||
// 创建父目录
|
||||
if let Some(parent) = output_path.parent() {
|
||||
if !parent.as_os_str().is_empty() && !parent.exists() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
CommandError::new(
|
||||
"CREATE_DIR_ERROR",
|
||||
format!("Failed to create directory: {}", e),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
std::fs::write(&output_path, markdown).map_err(|e| {
|
||||
CommandError::new(
|
||||
"WRITE_FILE_ERROR",
|
||||
format!("Failed to write file: {}", e),
|
||||
)
|
||||
})?;
|
||||
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?
|
||||
.len();
|
||||
|
||||
Ok(CommandResponse::success(ctx.request_id)
|
||||
.with_message(
|
||||
@ -121,7 +150,7 @@ async fn handle_save_session(
|
||||
&format!("Session saved to: {}", output_path.display()),
|
||||
)
|
||||
.with_metadata("filepath", output_path.to_string_lossy().as_ref())
|
||||
.with_metadata("message_count", &messages.len().to_string()))
|
||||
.with_metadata("message_count", &message_count.to_string()))
|
||||
}
|
||||
|
||||
/// 构建系统提示词
|
||||
@ -140,7 +169,7 @@ fn build_system_prompt(
|
||||
}
|
||||
|
||||
/// 生成 Markdown 内容
|
||||
fn generate_markdown(
|
||||
pub fn generate_markdown(
|
||||
record: &SessionRecord,
|
||||
system_prompt: &Option<SystemPrompt>,
|
||||
messages: &[crate::bus::ChatMessage],
|
||||
@ -273,7 +302,7 @@ fn escape_yaml_string(s: &str) -> String {
|
||||
}
|
||||
|
||||
/// 格式化时间戳
|
||||
fn format_timestamp(ts: i64) -> String {
|
||||
pub fn format_timestamp(ts: i64) -> String {
|
||||
Local
|
||||
.timestamp_millis_opt(ts)
|
||||
.single()
|
||||
@ -284,7 +313,7 @@ fn format_timestamp(ts: i64) -> String {
|
||||
/// 解析文件路径
|
||||
///
|
||||
/// 如果未提供路径,自动生成基于会话标题和时间戳的文件名
|
||||
fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
|
||||
pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
|
||||
match filepath {
|
||||
Some(path) => PathBuf::from(path),
|
||||
None => {
|
||||
@ -318,6 +347,66 @@ fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf
|
||||
}
|
||||
}
|
||||
|
||||
/// InChat 保存会话命令处理器
|
||||
///
|
||||
/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令
|
||||
pub struct SaveSessionInChatHandler {
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
impl SaveSessionInChatHandler {
|
||||
/// 创建新的 InChat 保存会话命令处理器
|
||||
pub fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl InChatCommandHandler for SaveSessionInChatHandler {
|
||||
fn can_handle(&self, cmd: &Command) -> bool {
|
||||
matches!(cmd, Command::SaveSession { .. })
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
cmd: Command,
|
||||
inbound: &InboundMessage,
|
||||
) -> Result<(), AgentError> {
|
||||
let Command::SaveSession { filepath } = cmd else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// 使用聊天ID作为会话ID
|
||||
let session_id = &inbound.chat_id;
|
||||
|
||||
// 调用公共函数
|
||||
let result = save_session_to_file(
|
||||
session_id,
|
||||
filepath,
|
||||
&*self.store,
|
||||
&self.provider_config,
|
||||
)
|
||||
.await;
|
||||
|
||||
// 结果通过返回 Ok(()) 表示成功
|
||||
// 实际输出由调用者通过消息总线发送
|
||||
match result {
|
||||
Ok(output_path) => {
|
||||
tracing::info!("Session saved to: {}", output_path.display());
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!("Failed to save session: {}", error);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -99,8 +99,15 @@ impl GatewayState {
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent));
|
||||
|
||||
// Spawn inbound processor with semaphore-controlled concurrency
|
||||
let provider_config = match self.config.get_provider_config("default") {
|
||||
Ok(config) => config,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to get provider config");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let inbound_processor =
|
||||
InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore);
|
||||
InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore, provider_config);
|
||||
tokio::spawn(inbound_processor.run());
|
||||
|
||||
// Spawn outbound dispatcher
|
||||
|
||||
@ -4,6 +4,9 @@ use tokio::sync::Semaphore;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
|
||||
use crate::command::handler::{InChatCommandHandler, InChatCommandRouter};
|
||||
use crate::command::Command;
|
||||
use crate::config::LLMProviderConfig;
|
||||
|
||||
use super::session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
@ -12,6 +15,8 @@ pub struct InboundProcessor {
|
||||
bus: Arc<MessageBus>,
|
||||
session_manager: SessionManager,
|
||||
semaphore: Arc<Semaphore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
command_router: Arc<InChatCommandRouter>,
|
||||
}
|
||||
|
||||
impl InboundProcessor {
|
||||
@ -19,11 +24,24 @@ impl InboundProcessor {
|
||||
bus: Arc<MessageBus>,
|
||||
session_manager: SessionManager,
|
||||
semaphore: Arc<Semaphore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
) -> Self {
|
||||
// 创建命令路由器并注册处理器
|
||||
let mut command_router = InChatCommandRouter::new();
|
||||
|
||||
// 注册 save_session 处理器
|
||||
let store = session_manager.store();
|
||||
command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new(
|
||||
store,
|
||||
provider_config.clone(),
|
||||
)));
|
||||
|
||||
Self {
|
||||
bus,
|
||||
session_manager,
|
||||
semaphore,
|
||||
provider_config,
|
||||
command_router: Arc::new(command_router),
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,6 +91,16 @@ impl InboundProcessor {
|
||||
}
|
||||
|
||||
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
|
||||
// 尝试解析为命令
|
||||
if let Some(cmd) = parse_in_chat_command(&inbound.content) {
|
||||
// 使用命令路由器处理
|
||||
let handled = self.command_router.dispatch(cmd, &inbound).await?;
|
||||
if handled {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// 普通消息进入 AgentLoop
|
||||
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
||||
self.bus.clone(),
|
||||
inbound.channel.clone(),
|
||||
@ -124,3 +152,26 @@ impl InboundProcessor {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 解析聊天中的命令
|
||||
///
|
||||
/// 支持格式:
|
||||
/// - `/save [filepath]` - 保存会话
|
||||
///
|
||||
/// 返回 Some(Command) 如果是命令
|
||||
/// 返回 None 如果不是命令
|
||||
fn parse_in_chat_command(content: &str) -> Option<Command> {
|
||||
let trimmed = content.trim();
|
||||
|
||||
if trimmed.starts_with("/save") {
|
||||
let path = trimmed[5..].trim();
|
||||
let filepath = if path.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(path.to_string())
|
||||
};
|
||||
Some(Command::SaveSession { filepath })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user