From 23b7497b12b1a1c5e52b8c6cfc297b3d895dd077 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Thu, 14 May 2026 16:07:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=20save=5Fsession=20?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=99=A8=EF=BC=8C=E4=BD=BF=E7=94=A8=E7=BB=84?= =?UTF-8?q?=E5=90=88=E7=B3=BB=E7=BB=9F=E6=8F=90=E7=A4=BA=E8=AF=8D=E6=8F=90?= =?UTF-8?q?=E4=BE=9B=E8=80=85=EF=BC=9B=E7=A7=BB=E9=99=A4=20LLMProviderConf?= =?UTF-8?q?ig=20=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/command/handlers/save_session.rs | 65 ++++++++++++---------------- src/gateway/processor.rs | 16 ++++++- src/gateway/ws.rs | 18 +++++++- 3 files changed, 57 insertions(+), 42 deletions(-) diff --git a/src/command/handlers/save_session.rs b/src/command/handlers/save_session.rs index fcd4b19..75c2c88 100644 --- a/src/command/handlers/save_session.rs +++ b/src/command/handlers/save_session.rs @@ -4,8 +4,6 @@ use crate::command::context::CommandContext; use crate::command::handler::{CommandHandler, InChatCommandHandler}; use crate::command::response::{CommandError, CommandResponse, MessageKind}; use crate::command::Command; -use crate::config::LLMProviderConfig; -use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider; use crate::storage::{SessionRecord, SessionStore}; use crate::agent::AgentError; use async_trait::async_trait; @@ -29,7 +27,7 @@ pub async fn save_session_to_file( filepath: Option, include_all: bool, store: &SessionStore, - provider_config: &LLMProviderConfig, + system_prompt_provider: &dyn SystemPromptProvider, ) -> Result { // 获取会话记录 let record = store @@ -51,8 +49,8 @@ pub async fn save_session_to_file( // 计算用户消息数(用于系统提示词构建) let user_message_count = messages.iter().filter(|m| m.role == "user").count(); - // 构建系统提示词 - let system_prompt = build_system_prompt(provider_config, &record, user_message_count); + // 构建系统提示词(使用外部传入的提供者) + let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count); // 生成 Markdown 内容 let markdown = generate_markdown(&record, &system_prompt, &messages); @@ -80,7 +78,7 @@ pub async fn save_session_to_file( /// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件 pub struct SaveSessionCommandHandler { store: Arc, - provider_config: LLMProviderConfig, + system_prompt_provider: Arc, } impl SaveSessionCommandHandler { @@ -88,11 +86,11 @@ impl SaveSessionCommandHandler { /// /// # Arguments /// * `store` - 会话存储 - /// * `provider_config` - LLM 提供者配置(用于构建系统提示词) - pub fn new(store: Arc, provider_config: LLMProviderConfig) -> Self { + /// * `system_prompt_provider` - 系统提示词提供者(负责构建完整的系统提示词) + pub fn new(store: Arc, system_prompt_provider: Arc) -> Self { Self { store, - provider_config, + system_prompt_provider, } } @@ -141,7 +139,7 @@ async fn handle_save_session( filepath, include_all, &*handler.store, - &handler.provider_config, + &*handler.system_prompt_provider, ) .await .map_err(|e| CommandError::new("SAVE_ERROR", e))?; @@ -170,11 +168,10 @@ async fn handle_save_session( /// 构建系统提示词 fn build_system_prompt( - provider_config: &LLMProviderConfig, + provider: &dyn SystemPromptProvider, record: &SessionRecord, user_message_count: usize, ) -> Option { - let provider = SimpleAgentPromptProvider::new(provider_config.clone()); let context = SystemPromptContext { session_id: Some(record.id.clone()), chat_id: record.chat_id.clone(), @@ -373,15 +370,15 @@ pub fn resolve_filepath(filepath: Option, record: &SessionRecord) -> Pat /// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令 pub struct SaveSessionInChatHandler { store: Arc, - provider_config: LLMProviderConfig, + system_prompt_provider: Arc, } impl SaveSessionInChatHandler { /// 创建新的 InChat 保存会话命令处理器 - pub fn new(store: Arc, provider_config: LLMProviderConfig) -> Self { + pub fn new(store: Arc, system_prompt_provider: Arc) -> Self { Self { store, - provider_config, + system_prompt_provider, } } } @@ -420,7 +417,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler { filepath, include_all, &*self.store, - &self.provider_config, + &*self.system_prompt_provider, ) .await; @@ -444,27 +441,6 @@ impl InChatCommandHandler for SaveSessionInChatHandler { mod tests { use super::*; use crate::storage::{SessionRecord, SessionStore}; - use std::collections::HashMap; - - fn test_config() -> LLMProviderConfig { - LLMProviderConfig { - provider_type: "openai".to_string(), - name: "test".to_string(), - base_url: "http://localhost".to_string(), - api_key: "test-key".to_string(), - extra_headers: HashMap::new(), - llm_timeout_secs: 120, - memory_maintenance_timeout_secs: 600, - model_id: "test-model".to_string(), - temperature: Some(0.0), - max_tokens: Some(32), - context_window_tokens: None, - tool_result_max_chars: 20_000, - context_tool_result_trim_chars: 20_000, - model_extra: HashMap::new(), - max_tool_iterations: 1, - } - } fn create_test_record(id: &str, title: &str) -> SessionRecord { SessionRecord { @@ -547,10 +523,23 @@ mod tests { #[test] fn test_can_handle() { let store = Arc::new(SessionStore::in_memory().unwrap()); - let handler = SaveSessionCommandHandler::new(store, test_config()); + let provider = Arc::new(TestSystemPromptProvider); + let handler = SaveSessionCommandHandler::new(store, provider); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false })); assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true })); assert!(!handler.can_handle(&Command::CreateSession { title: None })); } + + /// 测试用的系统提示词提供者 + struct TestSystemPromptProvider; + + impl SystemPromptProvider for TestSystemPromptProvider { + fn build(&self, _context: &SystemPromptContext) -> Option { + Some(SystemPrompt { + content: "Test system prompt".to_string(), + context: Some("test".to_string()), + }) + } + } } diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 84cc7ba..5c21705 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -2,11 +2,13 @@ use std::sync::Arc; use tokio::sync::Semaphore; -use crate::agent::AgentError; +use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::{InboundMessage, MessageBus, OutboundMessage}; use crate::command::handler::InChatCommandRouter; use crate::command::Command; use crate::config::LLMProviderConfig; +use crate::gateway::agent_prompt_provider::AgentPromptProvider; +use crate::skills::SkillPromptProvider; use super::session::{BusToolCallEmitter, SessionManager}; @@ -31,9 +33,19 @@ impl InboundProcessor { // 注册 save_session 处理器 let store = session_manager.store(); + let skills = session_manager.skills(); + let prompt_repository = session_manager.store().clone(); + let system_prompt_provider: Arc = Arc::new(CompositeSystemPromptProvider::new(vec![ + Box::new(AgentPromptProvider::new( + 0, // save_session 不需要 reinject 逻辑 + provider_config.clone(), + prompt_repository, + )), + Box::new(SkillPromptProvider::new(skills)), + ])); command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new( store, - provider_config.clone(), + system_prompt_provider, ))); Self { diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 08b07bd..ee7249b 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,5 +1,5 @@ use super::GatewayState; -use crate::agent::AgentError; +use crate::agent::{AgentError, CompositeSystemPromptProvider}; use crate::bus::InboundMessage; use crate::command::adapter::OutputAdapter; use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter}; @@ -7,7 +7,9 @@ use crate::command::context::CommandContext; use crate::command::handler::CommandRouter; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; +use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; +use crate::skills::SkillPromptProvider; use axum::extract::State; use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; @@ -354,11 +356,23 @@ async fn handle_inbound( // 获取所需依赖 let store = state.session_manager.store(); + let skills = state.session_manager.skills(); let provider_config = state.config.get_provider_config("default") .map_err(|e| AgentError::Other(e.to_string()))?; + let prompt_repository = state.session_manager.store().clone(); + + // 构建组合系统提示词提供者(与运行时一致) + let system_prompt_provider: Arc = Arc::new(CompositeSystemPromptProvider::new(vec![ + Box::new(AgentPromptProvider::new( + 0, // save_session 不需要 reinject 逻辑 + provider_config.clone(), + prompt_repository, + )), + Box::new(SkillPromptProvider::new(skills)), + ])); // 构建处理器 - let handler = SaveSessionCommandHandler::new(store, provider_config); + let handler = SaveSessionCommandHandler::new(store, system_prompt_provider); let router = { let mut r = CommandRouter::new(); r.register(Box::new(handler));