feat: 更新 save_session 处理器,使用组合系统提示词提供者;移除 LLMProviderConfig 依赖
This commit is contained in:
parent
bad36aa412
commit
23b7497b12
@ -4,8 +4,6 @@ use crate::command::context::CommandContext;
|
||||
use crate::command::handler::{CommandHandler, InChatCommandHandler};
|
||||
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||||
use crate::command::Command;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider;
|
||||
use crate::storage::{SessionRecord, SessionStore};
|
||||
use crate::agent::AgentError;
|
||||
use async_trait::async_trait;
|
||||
@ -29,7 +27,7 @@ pub async fn save_session_to_file(
|
||||
filepath: Option<String>,
|
||||
include_all: bool,
|
||||
store: &SessionStore,
|
||||
provider_config: &LLMProviderConfig,
|
||||
system_prompt_provider: &dyn SystemPromptProvider,
|
||||
) -> Result<PathBuf, String> {
|
||||
// 获取会话记录
|
||||
let record = store
|
||||
@ -51,8 +49,8 @@ pub async fn save_session_to_file(
|
||||
// 计算用户消息数(用于系统提示词构建)
|
||||
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||||
|
||||
// 构建系统提示词
|
||||
let system_prompt = build_system_prompt(provider_config, &record, user_message_count);
|
||||
// 构建系统提示词(使用外部传入的提供者)
|
||||
let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count);
|
||||
|
||||
// 生成 Markdown 内容
|
||||
let markdown = generate_markdown(&record, &system_prompt, &messages);
|
||||
@ -80,7 +78,7 @@ pub async fn save_session_to_file(
|
||||
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
|
||||
pub struct SaveSessionCommandHandler {
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||||
}
|
||||
|
||||
impl SaveSessionCommandHandler {
|
||||
@ -88,11 +86,11 @@ impl SaveSessionCommandHandler {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `store` - 会话存储
|
||||
/// * `provider_config` - LLM 提供者配置(用于构建系统提示词)
|
||||
pub fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
/// * `system_prompt_provider` - 系统提示词提供者(负责构建完整的系统提示词)
|
||||
pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
system_prompt_provider,
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,7 +139,7 @@ async fn handle_save_session(
|
||||
filepath,
|
||||
include_all,
|
||||
&*handler.store,
|
||||
&handler.provider_config,
|
||||
&*handler.system_prompt_provider,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| CommandError::new("SAVE_ERROR", e))?;
|
||||
@ -170,11 +168,10 @@ async fn handle_save_session(
|
||||
|
||||
/// 构建系统提示词
|
||||
fn build_system_prompt(
|
||||
provider_config: &LLMProviderConfig,
|
||||
provider: &dyn SystemPromptProvider,
|
||||
record: &SessionRecord,
|
||||
user_message_count: usize,
|
||||
) -> Option<SystemPrompt> {
|
||||
let provider = SimpleAgentPromptProvider::new(provider_config.clone());
|
||||
let context = SystemPromptContext {
|
||||
session_id: Some(record.id.clone()),
|
||||
chat_id: record.chat_id.clone(),
|
||||
@ -373,15 +370,15 @@ pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> Pat
|
||||
/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令
|
||||
pub struct SaveSessionInChatHandler {
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
system_prompt_provider: Arc<dyn SystemPromptProvider>,
|
||||
}
|
||||
|
||||
impl SaveSessionInChatHandler {
|
||||
/// 创建新的 InChat 保存会话命令处理器
|
||||
pub fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
system_prompt_provider,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -420,7 +417,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
|
||||
filepath,
|
||||
include_all,
|
||||
&*self.store,
|
||||
&self.provider_config,
|
||||
&*self.system_prompt_provider,
|
||||
)
|
||||
.await;
|
||||
|
||||
@ -444,27 +441,6 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::{SessionRecord, SessionStore};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn test_config() -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_record(id: &str, title: &str) -> SessionRecord {
|
||||
SessionRecord {
|
||||
@ -547,10 +523,23 @@ mod tests {
|
||||
#[test]
|
||||
fn test_can_handle() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let handler = SaveSessionCommandHandler::new(store, test_config());
|
||||
let provider = Arc::new(TestSystemPromptProvider);
|
||||
let handler = SaveSessionCommandHandler::new(store, provider);
|
||||
|
||||
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false }));
|
||||
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true }));
|
||||
assert!(!handler.can_handle(&Command::CreateSession { title: None }));
|
||||
}
|
||||
|
||||
/// 测试用的系统提示词提供者
|
||||
struct TestSystemPromptProvider;
|
||||
|
||||
impl SystemPromptProvider for TestSystemPromptProvider {
|
||||
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
|
||||
Some(SystemPrompt {
|
||||
content: "Test system prompt".to_string(),
|
||||
context: Some("test".to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,11 +2,13 @@ use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||||
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
|
||||
use crate::command::handler::InChatCommandRouter;
|
||||
use crate::command::Command;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||
use crate::skills::SkillPromptProvider;
|
||||
|
||||
use super::session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
@ -31,9 +33,19 @@ impl InboundProcessor {
|
||||
|
||||
// 注册 save_session 处理器
|
||||
let store = session_manager.store();
|
||||
let skills = session_manager.skills();
|
||||
let prompt_repository = session_manager.store().clone();
|
||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||
Box::new(AgentPromptProvider::new(
|
||||
0, // save_session 不需要 reinject 逻辑
|
||||
provider_config.clone(),
|
||||
prompt_repository,
|
||||
)),
|
||||
Box::new(SkillPromptProvider::new(skills)),
|
||||
]));
|
||||
command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new(
|
||||
store,
|
||||
provider_config.clone(),
|
||||
system_prompt_provider,
|
||||
)));
|
||||
|
||||
Self {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use super::GatewayState;
|
||||
use crate::agent::AgentError;
|
||||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||||
use crate::bus::InboundMessage;
|
||||
use crate::command::adapter::OutputAdapter;
|
||||
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
||||
@ -7,7 +7,9 @@ use crate::command::context::CommandContext;
|
||||
use crate::command::handler::CommandRouter;
|
||||
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||
use crate::command::handlers::session::SessionCommandHandler;
|
||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||
use crate::skills::SkillPromptProvider;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
use axum::response::Response;
|
||||
@ -354,11 +356,23 @@ async fn handle_inbound(
|
||||
|
||||
// 获取所需依赖
|
||||
let store = state.session_manager.store();
|
||||
let skills = state.session_manager.skills();
|
||||
let provider_config = state.config.get_provider_config("default")
|
||||
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||||
let prompt_repository = state.session_manager.store().clone();
|
||||
|
||||
// 构建组合系统提示词提供者(与运行时一致)
|
||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||
Box::new(AgentPromptProvider::new(
|
||||
0, // save_session 不需要 reinject 逻辑
|
||||
provider_config.clone(),
|
||||
prompt_repository,
|
||||
)),
|
||||
Box::new(SkillPromptProvider::new(skills)),
|
||||
]));
|
||||
|
||||
// 构建处理器
|
||||
let handler = SaveSessionCommandHandler::new(store, provider_config);
|
||||
let handler = SaveSessionCommandHandler::new(store, system_prompt_provider);
|
||||
let router = {
|
||||
let mut r = CommandRouter::new();
|
||||
r.register(Box::new(handler));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user