feat: 重构会话管理逻辑,添加多个服务以优化会话和任务调度

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
ooodc 2026-04-28 14:43:46 +08:00
parent acc8f63da0
commit e5e2b37246
7 changed files with 296 additions and 137 deletions

View File

@ -27,19 +27,6 @@ pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>)
}
}
pub(crate) fn select_provider_config(
default_provider_config: &LLMProviderConfig,
provider_configs: &HashMap<String, LLMProviderConfig>,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(default_provider_config.clone()),
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) struct AgentExecutionService {
show_tool_results: bool,
}
@ -325,50 +312,6 @@ mod tests {
use super::*;
use crate::bus::ChatMessage;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_provider_config_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]);
let selected =
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_provider_config_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::new();
let selected =
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
#[test]
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));

View File

@ -0,0 +1,46 @@
use std::sync::Arc;
use crate::agent::AgentError;
use crate::storage::SessionStore;
use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
};
use super::provider_config_service::ProviderConfigService;
#[derive(Clone)]
pub(crate) struct MemoryMaintenanceCoordinator {
store: Arc<SessionStore>,
provider_configs: ProviderConfigService,
}
impl MemoryMaintenanceCoordinator {
pub(crate) fn new(store: Arc<SessionStore>, provider_configs: ProviderConfigService) -> Self {
Self {
store,
provider_configs,
}
}
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) async fn summarize_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
self.service()?.summarize_for_scope(scope_key).await
}
pub(crate) async fn run_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
self.service()?.run_for_all_scopes(updated_since).await
}
fn service(&self) -> Result<MemoryMaintenanceService, AgentError> {
Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_configs.default_provider_config(),
))
}
}

View File

@ -6,13 +6,17 @@ pub mod compaction;
pub mod execution;
pub mod http;
pub mod memory_maintenance;
pub mod memory_maintenance_coordinator;
pub mod message_prepare;
pub mod processor;
pub mod prompt;
pub mod prompt_injector;
pub mod provider_config_service;
pub mod scheduled_agent_task_service;
pub mod session;
pub mod session_factory;
pub mod session_lifecycle;
pub mod session_message_service;
pub mod session_pool;
pub mod tool_registry_factory;
pub mod ws;

View File

@ -0,0 +1,90 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
#[derive(Clone)]
pub(crate) struct ProviderConfigService {
default_provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
}
impl ProviderConfigService {
pub(crate) fn new(
default_provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
) -> Self {
Self {
default_provider_config,
provider_configs: Arc::new(provider_configs),
}
}
pub(crate) fn select(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(self.default_provider_config.clone()),
Some(agent_name) => self
.provider_configs
.get(agent_name)
.cloned()
.ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) fn default_provider_config(&self) -> LLMProviderConfig {
self.default_provider_config.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(
default_provider,
HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]),
);
let selected = service.select(Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let service = ProviderConfigService::new(default_provider, HashMap::new());
let selected = service.select(Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
}

View File

@ -0,0 +1,57 @@
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use crate::scheduler::ScheduledAgentTaskOptions;
use super::execution::{AgentExecutionService, ScheduledExecutionRequest};
use super::provider_config_service::ProviderConfigService;
use super::session_lifecycle::SessionLifecycleService;
#[derive(Clone)]
pub(crate) struct ScheduledAgentTaskService {
lifecycle: SessionLifecycleService,
provider_configs: ProviderConfigService,
show_tool_results: bool,
}
impl ScheduledAgentTaskService {
pub(crate) fn new(
lifecycle: SessionLifecycleService,
provider_configs: ProviderConfigService,
show_tool_results: bool,
) -> Self {
Self {
lifecycle,
provider_configs,
show_tool_results,
}
}
pub(crate) async fn run(
&self,
channel_name: &str,
chat_id: &str,
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
let session = self.lifecycle.active_session(channel_name).await?;
let sender_id = options
.sender_id
.clone()
.unwrap_or_else(|| "scheduler".to_string());
let provider_config = self.provider_configs.select(options.agent.as_deref())?;
AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_scheduled_task(ScheduledExecutionRequest {
session,
channel_name,
chat_id,
prompt,
sender_id: &sender_id,
provider_config,
fresh_session: options.fresh_session,
system_prompt: options.system_prompt.as_deref(),
metadata: &options.metadata,
})
.await
}
}

View File

@ -16,22 +16,21 @@ use uuid::Uuid;
use super::agent_factory::{AgentBuildRequest, AgentFactory};
use super::cli_session::CliSessionService;
use super::execution::{
AgentExecutionService, MessageExecutionRequest, ScheduledExecutionRequest,
select_provider_config, should_display_message_to_user,
};
use super::execution::should_display_message_to_user;
#[cfg(test)]
use super::memory_maintenance::{
MemoryMaintenanceMerge, apply_memory_maintenance_output, build_memory_maintenance_plan,
combine_managed_memory_markdown, extract_json_object, is_recoverable_maintenance_llm_error,
strip_json_code_fence,
};
use super::memory_maintenance::{
MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult, MemoryMaintenanceService,
};
use super::memory_maintenance::{MemoryMaintenanceModelOutput, MemoryMaintenanceScopeResult};
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::prompt_injector::PromptInjector;
use super::provider_config_service::ProviderConfigService;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session_factory::SessionFactory;
use super::session_lifecycle::SessionLifecycleService;
use super::session_message_service::SessionMessageService;
use super::tool_registry_factory::ToolRegistryFactory;
fn preview_text(content: &str, max_chars: usize) -> String {
@ -449,14 +448,15 @@ impl Session {
/// SessionManager 管理所有 Session按 channel_name 路由
#[derive(Clone)]
pub struct SessionManager {
provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
show_tool_results: bool,
lifecycle: SessionLifecycleService,
cli_sessions: CliSessionService,
messages: SessionMessageService,
scheduled_tasks: ScheduledAgentTaskService,
memory_maintenance: MemoryMaintenanceCoordinator,
}
impl SessionManager {
@ -474,6 +474,8 @@ impl SessionManager {
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
);
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
let provider_configs =
ProviderConfigService::new(provider_config.clone(), provider_configs);
if let Err(err) =
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
@ -501,16 +503,25 @@ impl SessionManager {
);
let lifecycle = SessionLifecycleService::new(session_ttl_hours, session_factory);
let cli_sessions = CliSessionService::new(store.clone());
let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
let scheduled_tasks = ScheduledAgentTaskService::new(
lifecycle.clone(),
provider_configs.clone(),
show_tool_results,
);
let memory_maintenance =
MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone());
Ok(Self {
provider_config,
provider_configs: Arc::new(provider_configs),
tools,
skills,
store,
show_tool_results,
lifecycle,
cli_sessions,
messages,
scheduled_tasks,
memory_maintenance,
})
}
@ -539,34 +550,18 @@ impl SessionManager {
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
self.memory_maintenance_service()?
.summarize_for_scope(scope_key)
.await
self.memory_maintenance.summarize_for_scope(scope_key).await
}
pub(crate) async fn run_memory_maintenance_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
self.memory_maintenance_service()?
self.memory_maintenance
.run_for_all_scopes(updated_since)
.await
}
fn memory_maintenance_service(&self) -> Result<MemoryMaintenanceService, AgentError> {
Ok(MemoryMaintenanceService::new(
self.store.clone(),
self.provider_config_for_agent(None)?,
))
}
pub fn provider_config_for_agent(
&self,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
}
/// 确保 session 存在且未超时,超时则重建
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
self.lifecycle.ensure_session(channel_name).await
@ -596,43 +591,16 @@ impl SessionManager {
media: Vec<crate::bus::MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content.len(),
media_count = %media.len(),
"Routing message to agent"
);
for (i, m) in media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
}
}
let session = self.lifecycle.active_session(channel_name).await?;
let outbound_messages = AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_message(MessageExecutionRequest {
session: session.clone(),
self.messages
.handle_message(
channel_name,
sender_id,
chat_id,
content,
media,
live_emitter,
})
.await?;
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
outbound_count = outbound_messages.len(),
"Agent response sequence received"
);
Ok(outbound_messages)
)
.await
}
pub async fn run_scheduled_agent_task(
@ -642,26 +610,8 @@ impl SessionManager {
prompt: &str,
options: ScheduledAgentTaskOptions,
) -> Result<Vec<OutboundMessage>, AgentError> {
let session = self.lifecycle.active_session(channel_name).await?;
let sender_id = options
.sender_id
.clone()
.unwrap_or_else(|| "scheduler".to_string());
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_scheduled_task(ScheduledExecutionRequest {
session: session.clone(),
channel_name,
chat_id,
prompt,
sender_id: &sender_id,
provider_config,
fresh_session: options.fresh_session,
system_prompt: options.system_prompt.as_deref(),
metadata: &options.metadata,
})
self.scheduled_tasks
.run(channel_name, chat_id, prompt, options)
.await
}

View File

@ -0,0 +1,69 @@
use std::sync::Arc;
use crate::agent::{AgentError, EmittedMessageHandler};
use crate::bus::{MediaItem, OutboundMessage};
use super::execution::{AgentExecutionService, MessageExecutionRequest};
use super::session_lifecycle::SessionLifecycleService;
#[derive(Clone)]
pub(crate) struct SessionMessageService {
lifecycle: SessionLifecycleService,
show_tool_results: bool,
}
impl SessionMessageService {
pub(crate) fn new(lifecycle: SessionLifecycleService, show_tool_results: bool) -> Self {
Self {
lifecycle,
show_tool_results,
}
}
pub(crate) async fn handle_message(
&self,
channel_name: &str,
sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
content_len = content.len(),
media_count = %media.len(),
"Routing message to agent"
);
for (i, m) in media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
}
}
let session = self.lifecycle.active_session(channel_name).await?;
let outbound_messages = AgentExecutionService::new(self.show_tool_results)
.prepare_and_execute_message(MessageExecutionRequest {
session,
channel_name,
sender_id,
chat_id,
content,
media,
live_emitter,
})
.await?;
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
outbound_count = outbound_messages.len(),
"Agent response sequence received"
);
Ok(outbound_messages)
}
}