PicoBot/src/gateway/runtime.rs

176 lines
6.3 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::skills::SkillRuntime;
use crate::storage::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SessionStore, SkillEventRepository,
};
use crate::tools::{
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
SessionMessageSender, SubAgentRuntimeConfig, ToolRegistry,
};
use crate::tools::task::repository::TaskRepository;
use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService;
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::provider_config_service::ProviderConfigService;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session::{SessionManager, SessionManagerServices};
use super::session_factory::SessionFactory;
use super::session_lifecycle::SessionLifecycleService;
use super::session_message_service::SessionMessageService;
pub(crate) fn build_session_manager(
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
build_session_manager_with_sender(
agent_prompt_reinject_every,
show_tool_results,
default_timezone,
provider_config,
provider_configs,
skills,
Arc::new(NoopSessionMessageSender),
disabled_tools,
task_config,
maintenance_config,
chat_history_ttl_hours,
session_ttl_hours,
)
}
/// Builds a [`SessionManager`] together with the task repository it uses, wiring in
/// the supplied `session_message_sender`.
///
/// The function opens a [`SessionStore`], records a skill "discovered" event on it,
/// builds the tool registry (optionally with a sub-agent runtime), and then assembles
/// the agent factory, session factory, lifecycle service and the remaining services
/// that make up [`SessionManagerServices`].
///
/// # Parameters
///
/// * `agent_prompt_reinject_every` - forwarded to `AgentFactory::new` as a `usize`;
///   values that do not fit in `usize` saturate to `usize::MAX`.
/// * `show_tool_results` - forwarded to the message service, the scheduled-task
///   service and the session manager.
/// * `default_timezone`, `disabled_tools` - forwarded to `ToolRegistryFactory::new`.
/// * `provider_config` - the provider configuration handed to the provider config
///   service, the sub-agent runtime and the session factory.
/// * `provider_configs` - named provider configurations; the map keys are also passed
///   to the tool registry factory as the set of known agents.
/// * `skills` - shared skill runtime used by the tool registry, agent factory,
///   session factory and session manager.
/// * `session_message_sender` - sender handed to the tool registry factory.
/// * `task_config` - when `task_config.enabled` is true a sub-agent runtime is built
///   from it and attached to the tool registry factory.
/// * `maintenance_config` - forwarded to `ProviderConfigService::new`.
/// * `chat_history_ttl_hours` - forwarded to `SessionFactory::new`.
/// * `session_ttl_hours` - forwarded to `SessionLifecycleService::new`.
///
/// # Errors
///
/// Returns `AgentError::Other` if `SessionStore::new()` fails. A failure to record
/// the skill discovery event is only logged as a warning and does not abort the build.
pub(crate) fn build_session_manager_with_sender(
    agent_prompt_reinject_every: u64,
    show_tool_results: bool,
    default_timezone: String,
    provider_config: LLMProviderConfig,
    provider_configs: HashMap<String, LLMProviderConfig>,
    skills: Arc<SkillRuntime>,
    session_message_sender: Arc<dyn SessionMessageSender>,
    disabled_tools: HashSet<String>,
    task_config: TaskConfig,
    maintenance_config: MemoryMaintenanceConfig,
    chat_history_ttl_hours: Option<u64>,
    session_ttl_hours: Option<u64>,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
    let store = Arc::new(
        SessionStore::new()
            .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
    );

    // Collect the configured names before the map is moved into the service below.
    let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
    let provider_configs = ProviderConfigService::new(
        provider_config.clone(),
        provider_configs,
        maintenance_config,
    );

    // Best effort: a failed discovery record is logged but never blocks startup.
    if let Err(err) =
        store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
    {
        tracing::warn!(error = %err, "Failed to record skill discovery event");
    }

    // The same store instance is exposed through each repository trait it is used as.
    let memories: Arc<dyn MemoryRepository> = store.clone();
    let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
    let skill_events: Arc<dyn SkillEventRepository> = store.clone();
    let conversations: Arc<dyn ConversationRepository> = store.clone();

    // Create the ToolRegistryFactory.
    let factory = ToolRegistryFactory::new(
        skills.clone(),
        memories,
        scheduler_jobs,
        skill_events.clone(),
        session_message_sender,
        known_agents,
        default_timezone,
        disabled_tools,
        task_config.clone(),
    );

    // Create the SubAgentRuntime (only if the task tool is enabled).
    let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
        let task_repository = Arc::new(InMemoryTaskRepository::new());
        let subagent_tools = Arc::new(factory.build_subagent_tools());
        let runtime_config = SubAgentRuntimeConfig {
            allowed_tools: task_config.allowed_tools.iter().cloned().collect(),
            max_execution_secs: task_config.max_execution_secs,
            explore_max_execution_secs: task_config.explore_max_execution_secs,
            // Fixed here rather than read from `task_config`.
            explore_max_tool_calls: 20,
            ttl_hours: task_config.ttl_hours,
            skills_index: skills.system_index_prompt(),
        };
        let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
            runtime_config,
            task_repository.clone(),
            conversations.clone(),
            subagent_tools,
            provider_config.clone(),
        ));
        (factory.with_subagent_runtime(subagent_runtime), task_repository)
    } else {
        // The task tool is disabled: create an empty in-memory repository so the
        // return value and the session manager still have one.
        (factory, Arc::new(InMemoryTaskRepository::new()))
    };

    let tools = Arc::new(factory.build());
    let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();

    // `usize` can be narrower than `u64`; saturate instead of silently truncating the
    // way an `as` cast would on such targets.
    let reinject_every = usize::try_from(agent_prompt_reinject_every).unwrap_or(usize::MAX);
    let agent_factory = AgentFactory::new(
        tools.clone(),
        skills.clone(),
        reinject_every,
        prompt_repository,
    );

    // Last use of `provider_config`, so it is moved rather than cloned.
    let session_factory = SessionFactory::new(
        provider_config,
        skills.clone(),
        agent_factory,
        conversations,
        skill_events,
        store.clone(),
        chat_history_ttl_hours,
    );

    let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);
    let cli_sessions = CliSessionService::new(store.clone());
    let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
    let scheduled_tasks = ScheduledAgentTaskService::new(
        lifecycle.clone(),
        provider_configs.clone(),
        show_tool_results,
    );
    // Last use of the provider config service, so it is moved rather than cloned.
    let memory_maintenance = MemoryMaintenanceCoordinator::new(store.clone(), provider_configs);

    // The task repository is both stored in the manager and returned to the caller.
    Ok((
        SessionManager::from_services(SessionManagerServices {
            tools: tools as Arc<ToolRegistry>,
            skills,
            store,
            show_tool_results,
            lifecycle,
            cli_sessions,
            messages,
            scheduled_tasks,
            memory_maintenance,
            task_repository: task_repository.clone(),
        }),
        task_repository,
    ))
}