PicoBot/src/gateway/runtime.rs

219 lines
7.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::mcp::{McpClientManager, McpConfig};
use crate::skills::SkillRuntime;
use crate::storage::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SessionStore, SkillEventRepository,
};
use crate::tools::{
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
SessionMessageSender, SubAgentRuntimeConfig, ToolRegistry,
};
use crate::tools::task::repository::TaskRepository;
use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService;
use super::memory_maintenance_coordinator::MemoryMaintenanceCoordinator;
use super::provider_config_service::ProviderConfigService;
use super::scheduled_agent_task_service::ScheduledAgentTaskService;
use super::session::{SessionManager, SessionManagerServices};
use super::session_factory::SessionFactory;
use super::session_lifecycle::SessionLifecycleService;
use super::session_message_service::SessionMessageService;
pub(crate) fn build_session_manager(
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
session_ttl_hours: Option<u64>,
mcp_config: McpConfig,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
build_session_manager_with_sender(
agent_prompt_reinject_every,
show_tool_results,
default_timezone,
provider_config,
provider_configs,
skills,
Arc::new(NoopSessionMessageSender),
disabled_tools,
task_config,
maintenance_config,
session_ttl_hours,
mcp_config,
)
}
pub(crate) fn build_session_manager_with_sender(
agent_prompt_reinject_every: u64,
show_tool_results: bool,
default_timezone: String,
provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
session_message_sender: Arc<dyn SessionMessageSender>,
disabled_tools: HashSet<String>,
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
session_ttl_hours: Option<u64>,
mcp_config: McpConfig,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
let store = Arc::new(
SessionStore::new()
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
);
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
let provider_configs = ProviderConfigService::new(
provider_config.clone(),
provider_configs,
maintenance_config,
);
if let Err(err) =
store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload())
{
tracing::warn!(error = %err, "Failed to record skill discovery event");
}
let memories: Arc<dyn MemoryRepository> = store.clone();
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
let conversations: Arc<dyn ConversationRepository> = store.clone();
// 创建 ToolRegistryFactory
let factory = ToolRegistryFactory::new(
skills.clone(),
memories,
scheduler_jobs,
skill_events.clone(),
session_message_sender.clone(),
known_agents,
default_timezone,
disabled_tools,
task_config.clone(),
);
// 创建 MCP Client Manager如果启用
let mcp_manager = if mcp_config.has_enabled_servers() {
let manager = Arc::new(McpClientManager::new());
// 在 tokio runtime 中连接 MCP servers
// 使用 block_in_place 允许在同步上下文中执行异步代码
let servers = mcp_config.enabled_servers();
let servers_clone: Vec<_> = servers.into_iter().cloned().collect();
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
tracing::info!("Connecting to MCP servers...");
if let Err(e) = manager.connect_all(&servers_clone).await {
tracing::error!(error = %e, "Failed to connect to some MCP servers");
}
})
});
Some(manager)
} else {
None
};
// 将 MCP manager 添加到 factory
let factory = if let Some(ref manager) = mcp_manager {
factory.with_mcp_manager(manager.clone())
} else {
factory
};
// 创建 SubAgentRuntime如果 task 工具启用)
let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
let task_repository = Arc::new(InMemoryTaskRepository::new());
let subagent_tools = Arc::new(factory.build_subagent_tools());
let runtime_config = SubAgentRuntimeConfig {
allowed_tools: task_config.allowed_tools.iter().cloned().collect(),
max_execution_secs: task_config.max_execution_secs,
explore_max_execution_secs: task_config.explore_max_execution_secs,
explore_max_tool_calls: 20,
ttl_hours: task_config.ttl_hours,
skills_index: skills.system_index_prompt(),
};
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
runtime_config,
task_repository.clone(),
conversations.clone(),
subagent_tools,
provider_config.clone(),
));
(factory.with_subagent_runtime(subagent_runtime), task_repository)
} else {
// 如果 task 工具未启用,创建一个空的内存仓库
(factory, Arc::new(InMemoryTaskRepository::new()))
};
let mut tools = factory.build();
// 注册 MCP tools如果有 MCP manager
if let Some(manager) = &mcp_manager {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
if let Err(e) = crate::mcp::register_mcp_tools(manager.clone(), &mut tools).await {
tracing::error!(error = %e, "Failed to register MCP tools");
}
})
});
}
let tools = Arc::new(tools);
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
let agent_factory = AgentFactory::new(
tools.clone(),
skills.clone(),
agent_prompt_reinject_every as usize,
prompt_repository.clone(),
);
let session_factory = SessionFactory::new(
provider_config.clone(),
skills.clone(),
agent_factory,
conversations,
skill_events,
store.clone(),
);
let lifecycle = SessionLifecycleService::new(session_factory, session_ttl_hours);
let cli_sessions = CliSessionService::new(store.clone());
let messages = SessionMessageService::new(lifecycle.clone(), show_tool_results);
let scheduled_tasks = ScheduledAgentTaskService::new(
lifecycle.clone(),
provider_configs.clone(),
show_tool_results,
);
let memory_maintenance =
MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone());
Ok((SessionManager::from_services(SessionManagerServices {
tools: tools as Arc<ToolRegistry>,
skills,
store,
show_tool_results,
lifecycle,
cli_sessions,
messages,
scheduled_tasks,
memory_maintenance,
task_repository: task_repository.clone(),
}), task_repository))
}