feat: 重构会话管理逻辑,添加多个服务以优化会话和任务调度
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
acc8f63da0
commit
e5e2b37246
@ -27,19 +27,6 @@ pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn select_provider_config(
|
||||
default_provider_config: &LLMProviderConfig,
|
||||
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||
agent_name: Option<&str>,
|
||||
) -> Result<LLMProviderConfig, AgentError> {
|
||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
None | Some("default") => Ok(default_provider_config.clone()),
|
||||
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
||||
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AgentExecutionService {
|
||||
show_tool_results: bool,
|
||||
}
|
||||
@ -325,50 +312,6 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: name.to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_uses_named_agent_override() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::from([(
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]);
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||
assert_eq!(selected.name, "planner-provider");
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::new();
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
assert_eq!(selected.model_id, "default-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
|
||||
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
|
||||
|
||||
46
src/gateway/memory_maintenance_coordinator.rs
Normal file
46
src/gateway/memory_maintenance_coordinator.rs
Normal file
@ -0,0 +1,46 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::storage::SessionStore;
|
||||
|
||||
use super::memory_maintenance::{
|
||||
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
||||
};
|
||||
use super::provider_config_service::ProviderConfigService;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct MemoryMaintenanceCoordinator {
|
||||
store: Arc<SessionStore>,
|
||||
provider_configs: ProviderConfigService,
|
||||
}
|
||||
|
||||
impl MemoryMaintenanceCoordinator {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, provider_configs: ProviderConfigService) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_configs,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
pub(crate) async fn summarize_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
self.service()?.summarize_for_scope(scope_key).await
|
||||
}
|
||||
|
||||
pub(crate) async fn run_for_all_scopes(
|
||||
&self,
|
||||
updated_since: Option<i64>,
|
||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||
self.service()?.run_for_all_scopes(updated_since).await
|
||||
}
|
||||
|
||||
fn service(&self) -> Result<MemoryMaintenanceService, AgentError> {
|
||||
Ok(MemoryMaintenanceService::new(
|
||||
self.store.clone(),
|
||||
self.provider_configs.default_provider_config(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@ -6,13 +6,17 @@ pub mod compaction;
|
||||
pub mod execution;
|
||||
pub mod http;
|
||||
pub mod memory_maintenance;
|
||||
pub mod memory_maintenance_coordinator;
|
||||
pub mod message_prepare;
|
||||
pub mod processor;
|
||||
pub mod prompt;
|
||||
pub mod prompt_injector;
|
||||
pub mod provider_config_service;
|
||||
pub mod scheduled_agent_task_service;
|
||||
pub mod session;
|
||||
pub mod session_factory;
|
||||
pub mod session_lifecycle;
|
||||
pub mod session_message_service;
|
||||
pub mod session_pool;
|
||||
pub mod tool_registry_factory;
|
||||
pub mod ws;
|
||||
|
||||
90
src/gateway/provider_config_service.rs
Normal file
90
src/gateway/provider_config_service.rs
Normal file
@ -0,0 +1,90 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ProviderConfigService {
|
||||
default_provider_config: LLMProviderConfig,
|
||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||
}
|
||||
|
||||
impl ProviderConfigService {
|
||||
pub(crate) fn new(
|
||||
default_provider_config: LLMProviderConfig,
|
||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||
) -> Self {
|
||||
Self {
|
||||
default_provider_config,
|
||||
provider_configs: Arc::new(provider_configs),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn select(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
|
||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
None | Some("default") => Ok(self.default_provider_config.clone()),
|
||||
Some(agent_name) => self
|
||||
.provider_configs
|
||||
.get(agent_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
|
||||
self.default_provider_config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: name.to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_uses_named_agent_override() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let service = ProviderConfigService::new(
|
||||
default_provider,
|
||||
HashMap::from([(
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]),
|
||||
);
|
||||
|
||||
let selected = service.select(Some("planner")).unwrap();
|
||||
assert_eq!(selected.name, "planner-provider");
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let service = ProviderConfigService::new(default_provider, HashMap::new());
|
||||
|
||||
let selected = service.select(Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
assert_eq!(selected.model_id, "default-model");
|
||||
}
|
||||
}
|
||||
57
src/gateway/scheduled_agent_task_service.rs
Normal file
57
src/gateway/scheduled_agent_task_service.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use crate::agent::AgentError;
|
||||
use crate::bus::OutboundMessage;
|
||||
use crate::scheduler::ScheduledAgentTaskOptions;
|
||||
|
||||
use super::execution::{AgentExecutionService, ScheduledExecutionRequest};
|
||||
use super::provider_config_service::ProviderConfigService;
|
||||
use super::session_lifecycle::SessionLifecycleService;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ScheduledAgentTaskService {
|
||||
lifecycle: SessionLifecycleService,
|
||||
provider_configs: ProviderConfigService,
|
||||
show_tool_results: bool,
|
||||
}
|
||||
|
||||
impl ScheduledAgentTaskService {
|
||||
pub(crate) fn new(
|
||||
lifecycle: SessionLifecycleService,
|
||||
provider_configs: ProviderConfigService,
|
||||
show_tool_results: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
lifecycle,
|
||||
provider_configs,
|
||||
show_tool_results,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
chat_id: &str,
|
||||
prompt: &str,
|
||||
options: ScheduledAgentTaskOptions,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
let session = self.lifecycle.active_session(channel_name).await?;
|
||||
let sender_id = options
|
||||
.sender_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "scheduler".to_string());
|
||||
let provider_config = self.provider_configs.select(options.agent.as_deref())?;
|
||||
|
||||
AgentExecutionService::new(self.show_tool_results)
|
||||
.prepare_and_execute_scheduled_task(ScheduledExecutionRequest {
|
||||
session,
|
||||
channel_name,
|
||||
chat_id,
|
||||
prompt,
|
||||
sender_id: &sender_id,
|
||||
provider_config,
|
||||
fresh_session: options.fresh_session,
|
||||
system_prompt: options.system_prompt.as_deref(),
|
||||
metadata: &options.metadata,
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@ -16,22 +16,21 @@ use uuid::Uuid;
|
||||
|
||||
use super::agent_factory::{AgentBuildRequest, AgentFactory};
|
||||
use super::cli_session::CliSessionService;
|
||||
use super::execution::{
|
||||
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
|
||||
select_provider_config, should_display_message_to_user,
|
||||
};
|
||||
use super::execution::should_display_message_to_user;
|
||||
#[cfg(test)]
|
||||
use super::memory_maintenance::{
|
||||
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
|
||||
combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
|
||||
strip_json_code_fence,
|
||||
};
|
||||
use super::memory_maintenance::{
|
||||
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
|
||||
};
|
||||
use super::memory_maintenance::{MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult};
|
||||
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
|
||||
use super::prompt_injector::PromptInjector;
|
||||
use super::provider_config_service::ProviderConfigService;
|
||||
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
|
||||
use super::session_factory::SessionFactory;
|
||||
use super::session_lifecycle::SessionLifecycleService;
|
||||
use super::session_message_service::SessionMessageService;
|
||||
use super::tool_registry_factory::ToolRegistryFactory;
|
||||
|
||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
@ -449,14 +448,15 @@ impl Session {
|
||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
provider_config: LLMProviderConfig,
|
||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
show_tool_results: bool,
|
||||
lifecycle: SessionLifecycleService,
|
||||
cli_sessions: CliSessionService,
|
||||
messages: SessionMessageService,
|
||||
scheduled_tasks: ScheduledAgentTaskService,
|
||||
memory_maintenance: MemoryMaintenanceCoordinator,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
@ -474,6 +474,8 @@ impl SessionManager {
|
||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||
);
|
||||
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||
let provider_configs =
|
||||
ProviderConfigService::new(provider_config.clone(), provider_configs);
|
||||
|
||||
if let Err(err) =
|
||||
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
|
||||
@ -501,16 +503,25 @@ impl SessionManager {
|
||||
);
|
||||
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
|
||||
let cli_sessions = CliSessionService::new(store.clone());
|
||||
let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
|
||||
let scheduled_tasks = ScheduledAgentTaskService::new(
|
||||
lifecycle.clone(),
|
||||
provider_configs.clone(),
|
||||
show_tool_results,
|
||||
);
|
||||
let memory_maintenance =
|
||||
MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone());
|
||||
|
||||
Ok(Self {
|
||||
provider_config,
|
||||
provider_configs: Arc::new(provider_configs),
|
||||
tools,
|
||||
skills,
|
||||
store,
|
||||
show_tool_results,
|
||||
lifecycle,
|
||||
cli_sessions,
|
||||
messages,
|
||||
scheduled_tasks,
|
||||
memory_maintenance,
|
||||
})
|
||||
}
|
||||
|
||||
@ -539,34 +550,18 @@ impl SessionManager {
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
self.memory_maintenance_service()?
|
||||
.summarize_for_scope(scope_key)
|
||||
.await
|
||||
self.memory_maintenance.summarize_for_scope(scope_key).await
|
||||
}
|
||||
|
||||
pub(crate) async fn run_memory_maintenance_for_all_scopes(
|
||||
&self,
|
||||
updated_since: Option<i64>,
|
||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||
self.memory_maintenance_service()?
|
||||
self.memory_maintenance
|
||||
.run_for_all_scopes(updated_since)
|
||||
.await
|
||||
}
|
||||
|
||||
fn memory_maintenance_service(&self) -> Result<MemoryMaintenanceService, AgentError> {
|
||||
Ok(MemoryMaintenanceService::new(
|
||||
self.store.clone(),
|
||||
self.provider_config_for_agent(None)?,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn provider_config_for_agent(
|
||||
&self,
|
||||
agent_name: Option<&str>,
|
||||
) -> Result<LLMProviderConfig, AgentError> {
|
||||
select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
|
||||
}
|
||||
|
||||
/// 确保 session 存在且未超时,超时则重建
|
||||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||
self.lifecycle.ensure_session(channel_name).await
|
||||
@ -596,43 +591,16 @@ impl SessionManager {
|
||||
media: Vec<crate::bus::MediaItem>,
|
||||
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
content_len = content.len(),
|
||||
media_count = %media.len(),
|
||||
"Routing message to agent"
|
||||
);
|
||||
for (i, m) in media.iter().enumerate() {
|
||||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
|
||||
}
|
||||
}
|
||||
|
||||
let session = self.lifecycle.active_session(channel_name).await?;
|
||||
|
||||
let outbound_messages = AgentExecutionService::new(self.show_tool_results)
|
||||
.prepare_and_execute_message(MessageExecutionRequest {
|
||||
session: session.clone(),
|
||||
self.messages
|
||||
.handle_message(
|
||||
channel_name,
|
||||
sender_id,
|
||||
chat_id,
|
||||
content,
|
||||
media,
|
||||
live_emitter,
|
||||
})
|
||||
.await?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
outbound_count = outbound_messages.len(),
|
||||
"Agent response sequence received"
|
||||
);
|
||||
|
||||
Ok(outbound_messages)
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn run_scheduled_agent_task(
|
||||
@ -642,26 +610,8 @@ impl SessionManager {
|
||||
prompt: &str,
|
||||
options: ScheduledAgentTaskOptions,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
let session = self.lifecycle.active_session(channel_name).await?;
|
||||
|
||||
let sender_id = options
|
||||
.sender_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "scheduler".to_string());
|
||||
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
|
||||
|
||||
AgentExecutionService::new(self.show_tool_results)
|
||||
.prepare_and_execute_scheduled_task(ScheduledExecutionRequest {
|
||||
session: session.clone(),
|
||||
channel_name,
|
||||
chat_id,
|
||||
prompt,
|
||||
sender_id: &sender_id,
|
||||
provider_config,
|
||||
fresh_session: options.fresh_session,
|
||||
system_prompt: options.system_prompt.as_deref(),
|
||||
metadata: &options.metadata,
|
||||
})
|
||||
self.scheduled_tasks
|
||||
.run(channel_name, chat_id, prompt, options)
|
||||
.await
|
||||
}
|
||||
|
||||
|
||||
69
src/gateway/session_message_service.rs
Normal file
69
src/gateway/session_message_service.rs
Normal file
@ -0,0 +1,69 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::{AgentError, EmittedMessageHandler};
|
||||
use crate::bus::{MediaItem, OutboundMessage};
|
||||
|
||||
use super::execution::{AgentExecutionService, MessageExecutionRequest};
|
||||
use super::session_lifecycle::SessionLifecycleService;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionMessageService {
|
||||
lifecycle: SessionLifecycleService,
|
||||
show_tool_results: bool,
|
||||
}
|
||||
|
||||
impl SessionMessageService {
|
||||
pub(crate) fn new(lifecycle: SessionLifecycleService, show_tool_results: bool) -> Self {
|
||||
Self {
|
||||
lifecycle,
|
||||
show_tool_results,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
media: Vec<MediaItem>,
|
||||
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
content_len = content.len(),
|
||||
media_count = %media.len(),
|
||||
"Routing message to agent"
|
||||
);
|
||||
for (i, m) in media.iter().enumerate() {
|
||||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
|
||||
}
|
||||
}
|
||||
|
||||
let session = self.lifecycle.active_session(channel_name).await?;
|
||||
let outbound_messages = AgentExecutionService::new(self.show_tool_results)
|
||||
.prepare_and_execute_message(MessageExecutionRequest {
|
||||
session,
|
||||
channel_name,
|
||||
sender_id,
|
||||
chat_id,
|
||||
content,
|
||||
media,
|
||||
live_emitter,
|
||||
})
|
||||
.await?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
outbound_count = outbound_messages.len(),
|
||||
"Agent response sequence received"
|
||||
);
|
||||
|
||||
Ok(outbound_messages)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user