From 08172dcf9c31e071788e081b2e3eb690c38f989d Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Wed, 13 May 2026 22:40:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20InChat=20=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E5=A4=84=E7=90=86=E5=99=A8=E5=92=8C=E8=B7=AF=E7=94=B1?= =?UTF-8?q?=E5=99=A8=EF=BC=8C=E6=94=AF=E6=8C=81=E8=81=8A=E5=A4=A9=E4=B8=AD?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5=E8=BE=93=E5=85=A5=E7=9A=84=E5=91=BD=E4=BB=A4?= =?UTF-8?q?=EF=BC=9B=E6=B7=BB=E5=8A=A0=E4=BF=9D=E5=AD=98=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/command/handler.rs | 84 +++++++++++++ src/command/handlers/save_session.rs | 181 ++++++++++++++++++++------- src/gateway/mod.rs | 9 +- src/gateway/processor.rs | 51 ++++++++ 4 files changed, 278 insertions(+), 47 deletions(-) diff --git a/src/command/handler.rs b/src/command/handler.rs index 2c568c4..75c48d2 100644 --- a/src/command/handler.rs +++ b/src/command/handler.rs @@ -1,6 +1,8 @@ +use crate::bus::InboundMessage; use crate::command::context::CommandContext; use crate::command::response::{CommandError, CommandResponse}; use crate::command::Command; +use crate::agent::AgentError; use async_trait::async_trait; /// 命令处理器 trait @@ -28,6 +30,31 @@ pub trait CommandHandler: Send + Sync { ) -> Result; } +/// InChat 命令处理器 trait +/// +/// 用于处理在聊天中直接输入的命令(如 Feishu/WeChat 等通道) +/// 接收 InboundMessage 而不是 CommandContext +#[async_trait] +pub trait InChatCommandHandler: Send + Sync { + /// 是否可以处理此命令 + fn can_handle(&self, cmd: &Command) -> bool; + + /// 执行命令 + /// + /// # Arguments + /// * `cmd` - 要执行的命令 + /// * `inbound` - 入站消息(包含通道信息) + /// + /// # Returns + /// * `Ok(())` - 命令执行成功 + /// * `Err(AgentError)` - 命令执行失败 + async fn handle( + &self, + cmd: Command, + inbound: &InboundMessage, + ) -> Result<(), AgentError>; +} + /// 命令路由器 /// /// 负责将命令分发到合适的处理器 @@ -102,6 +129,63 @@ impl Default for CommandRouter { } } +/// InChat 命令路由器 +/// +/// 负责将在聊天中输入的命令分发到合适的处理器 +pub struct InChatCommandRouter { + handlers: Vec>, +} + +impl InChatCommandRouter { + /// 创建新的 InChat 命令路由器 + pub fn new() -> Self { + Self { + handlers: Vec::new(), + } + } + + /// 注册 InChat 命令处理器 + /// + /// # Arguments + /// * `handler` - 要注册的处理器 + pub fn register(&mut self, handler: Box) { + self.handlers.push(handler); + } + + /// 分发命令到合适的处理器 + /// + /// # Arguments + /// * `cmd` - 要执行的命令 + /// * `inbound` - 入站消息 + /// + /// # Returns + /// * `Ok(true)` - 命令被处理 + /// * `Ok(false)` - 没有合适的处理器 + /// * `Err(AgentError)` - 执行失败 + pub async fn dispatch( + &self, + cmd: Command, + inbound: &InboundMessage, + ) -> Result { + // 查找能处理此命令的处理器 + for handler in &self.handlers { + if handler.can_handle(&cmd) { + handler.handle(cmd, inbound).await?; + return Ok(true); + } + } + + // 没有找到合适的处理器 + Ok(false) + } +} + +impl Default for InChatCommandRouter { + fn default() -> Self { + Self::new() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index 5d038a8..f56f1f9 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -1,16 +1,73 @@ use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider}; +use crate::bus::InboundMessage; use crate::command::context::CommandContext; -use crate::command::handler::CommandHandler; +use crate::command::handler::{CommandHandler, InChatCommandHandler}; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider; use crate::storage::{SessionRecord, SessionStore}; +use crate::agent::AgentError; +use crate::bus::OutboundMessage; use async_trait::async_trait; use chrono::{Local, TimeZone}; use std::path::PathBuf; use std::sync::Arc; +/// 保存会话到文件(公共函数,可被命令处理器和其他模块复用) +/// +/// # Arguments +/// * `session_id` - 会话ID +/// * `filepath` - 可选的文件路径 +/// * `store` - 会话存储 +/// * `provider_config` - LLM提供者配置 +/// +/// # Returns +/// 返回保存的文件路径 +pub async fn save_session_to_file( + session_id: &str, + filepath: Option, + store: &SessionStore, + provider_config: &LLMProviderConfig, +) -> Result { + // 获取会话记录 + let record = store + .get_session(session_id) + .map_err(|e| format!("Failed to get session: {}", e))? + .ok_or_else(|| "Session not found".to_string())?; + + // 获取所有消息(包括历史) + let messages = store + .load_all_messages(session_id) + .map_err(|e| format!("Failed to load messages: {}", e))?; + + // 计算用户消息数(用于系统提示词构建) + let user_message_count = messages.iter().filter(|m| m.role == "user").count(); + + // 构建系统提示词 + let system_prompt = build_system_prompt(provider_config, &record, user_message_count); + + // 生成 Markdown 内容 + let markdown = generate_markdown(&record, &system_prompt, &messages); + + // 确定输出路径 + let output_path = resolve_filepath(filepath, &record); + + // 创建父目录 + if let Some(parent) = output_path.parent() { + if !parent.as_os_str().is_empty() && !parent.exists() { + std::fs::create_dir_all(parent) + .map_err(|e| format!("Failed to create directory: {}", e))?; + } + } + + // 写入文件 + std::fs::write(&output_path, markdown) + .map_err(|e| format!("Failed to write file: {}", e))?; + + Ok(output_path) +} + /// 保存会话命令处理器 /// /// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 @@ -70,50 +127,22 @@ async fn handle_save_session( .as_deref() .ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?; - // 获取会话记录 - let record = handler - .store - .get_session(session_id) - .map_err(|e| CommandError::new("SESSION_ERROR", e.to_string()))? - .ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", "Session not found".to_string()))?; + // 调用公共函数 + let output_path = save_session_to_file( + session_id, + filepath, + &*handler.store, + &handler.provider_config, + ) + .await + .map_err(|e| CommandError::new("SAVE_ERROR", e))?; - // 获取所有消息(包括历史) - let messages = handler + // 获取消息数量用于返回 + let message_count = handler .store .load_all_messages(session_id) - .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?; - - // 计算用户消息数(用于系统提示词构建) - let user_message_count = messages.iter().filter(|m| m.role == "user").count(); - - // 构建系统提示词 - let system_prompt = build_system_prompt(&handler.provider_config, &record, user_message_count); - - // 生成 Markdown 内容 - let markdown = generate_markdown(&record, &system_prompt, &messages); - - // 确定输出路径 - let output_path = resolve_filepath(filepath, &record); - - // 创建父目录 - if let Some(parent) = output_path.parent() { - if !parent.as_os_str().is_empty() && !parent.exists() { - std::fs::create_dir_all(parent).map_err(|e| { - CommandError::new( - "CREATE_DIR_ERROR", - format!("Failed to create directory: {}", e), - ) - })?; - } - } - - // 写入文件 - std::fs::write(&output_path, markdown).map_err(|e| { - CommandError::new( - "WRITE_FILE_ERROR", - format!("Failed to write file: {}", e), - ) - })?; + .map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))? + .len(); Ok(CommandResponse::success(ctx.request_id) .with_message( @@ -121,7 +150,7 @@ async fn handle_save_session( &format!("Session saved to: {}", output_path.display()), ) .with_metadata("filepath", output_path.to_string_lossy().as_ref()) - .with_metadata("message_count", &messages.len().to_string())) + .with_metadata("message_count", &message_count.to_string())) } /// 构建系统提示词 @@ -140,7 +169,7 @@ fn build_system_prompt( } /// 生成 Markdown 内容 -fn generate_markdown( +pub fn generate_markdown( record: &SessionRecord, system_prompt: &Option, messages: &[crate::bus::ChatMessage], @@ -273,7 +302,7 @@ fn escape_yaml_string(s: &str) -> String { } /// 格式化时间戳 -fn format_timestamp(ts: i64) -> String { +pub fn format_timestamp(ts: i64) -> String { Local .timestamp_millis_opt(ts) .single() @@ -284,7 +313,7 @@ fn format_timestamp(ts: i64) -> String { /// 解析文件路径 /// /// 如果未提供路径,自动生成基于会话标题和时间戳的文件名 -fn resolve_filepath(filepath: Option, record: &SessionRecord) -> PathBuf { +pub fn resolve_filepath(filepath: Option, record: &SessionRecord) -> PathBuf { match filepath { Some(path) => PathBuf::from(path), None => { @@ -318,6 +347,66 @@ fn resolve_filepath(filepath: Option, record: &SessionRecord) -> PathBuf } } +/// InChat 保存会话命令处理器 +/// +/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令 +pub struct SaveSessionInChatHandler { + store: Arc, + provider_config: LLMProviderConfig, +} + +impl SaveSessionInChatHandler { + /// 创建新的 InChat 保存会话命令处理器 + pub fn new(store: Arc, provider_config: LLMProviderConfig) -> Self { + Self { + store, + provider_config, + } + } +} + +#[async_trait] +impl InChatCommandHandler for SaveSessionInChatHandler { + fn can_handle(&self, cmd: &Command) -> bool { + matches!(cmd, Command::SaveSession { .. }) + } + + async fn handle( + &self, + cmd: Command, + inbound: &InboundMessage, + ) -> Result<(), AgentError> { + let Command::SaveSession { filepath } = cmd else { + return Ok(()); + }; + + // 使用聊天ID作为会话ID + let session_id = &inbound.chat_id; + + // 调用公共函数 + let result = save_session_to_file( + session_id, + filepath, + &*self.store, + &self.provider_config, + ) + .await; + + // 结果通过返回 Ok(()) 表示成功 + // 实际输出由调用者通过消息总线发送 + match result { + Ok(output_path) => { + tracing::info!("Session saved to: {}", output_path.display()); + Ok(()) + } + Err(error) => { + tracing::error!("Failed to save session: {}", error); + Ok(()) + } + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 99a13f5..be137de 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -99,8 +99,15 @@ impl GatewayState { let semaphore = Arc::new(Semaphore::new(max_concurrent)); // Spawn inbound processor with semaphore-controlled concurrency + let provider_config = match self.config.get_provider_config("default") { + Ok(config) => config, + Err(e) => { + tracing::error!(error = %e, "Failed to get provider config"); + return; + } + }; let inbound_processor = - InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore); + InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore, provider_config); tokio::spawn(inbound_processor.run()); // Spawn outbound dispatcher diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index b468194..c2df86d 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -4,6 +4,9 @@ use tokio::sync::Semaphore; use crate::agent::AgentError; use crate::bus::{InboundMessage, MessageBus, OutboundMessage}; +use crate::command::handler::{InChatCommandHandler, InChatCommandRouter}; +use crate::command::Command; +use crate::config::LLMProviderConfig; use super::session::{BusToolCallEmitter, SessionManager}; @@ -12,6 +15,8 @@ pub struct InboundProcessor { bus: Arc, session_manager: SessionManager, semaphore: Arc, + provider_config: LLMProviderConfig, + command_router: Arc, } impl InboundProcessor { @@ -19,11 +24,24 @@ impl InboundProcessor { bus: Arc, session_manager: SessionManager, semaphore: Arc, + provider_config: LLMProviderConfig, ) -> Self { + // 创建命令路由器并注册处理器 + let mut command_router = InChatCommandRouter::new(); + + // 注册 save_session 处理器 + let store = session_manager.store(); + command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new( + store, + provider_config.clone(), + ))); + Self { bus, session_manager, semaphore, + provider_config, + command_router: Arc::new(command_router), } } @@ -73,6 +91,16 @@ impl InboundProcessor { } async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> { + // 尝试解析为命令 + if let Some(cmd) = parse_in_chat_command(&inbound.content) { + // 使用命令路由器处理 + let handled = self.command_router.dispatch(cmd, &inbound).await?; + if handled { + return Ok(()); + } + } + + // 普通消息进入 AgentLoop let live_emitter = Arc::new(BusToolCallEmitter::new( self.bus.clone(), inbound.channel.clone(), @@ -124,3 +152,26 @@ impl InboundProcessor { Ok(()) } } + +/// 解析聊天中的命令 +/// +/// 支持格式: +/// - `/save [filepath]` - 保存会话 +/// +/// 返回 Some(Command) 如果是命令 +/// 返回 None 如果不是命令 +fn parse_in_chat_command(content: &str) -> Option { + let trimmed = content.trim(); + + if trimmed.starts_with("/save") { + let path = trimmed[5..].trim(); + let filepath = if path.is_empty() { + None + } else { + Some(path.to_string()) + }; + Some(Command::SaveSession { filepath }) + } else { + None + } +}