feat: 添加 MCP (Model Context Protocol) 支持,包含客户端管理器和工具适配器
This commit is contained in:
parent
f68e915b04
commit
cbb384a4e6
@ -37,3 +37,11 @@ rusqlite = { version = "0.32", features = ["bundled"] }
|
|||||||
rustls = { version = "0.23", features = ["ring"] }
|
rustls = { version = "0.23", features = ["ring"] }
|
||||||
wechatbot = { path = "vendor/wechatbot" }
|
wechatbot = { path = "vendor/wechatbot" }
|
||||||
encoding_rs = "0.8"
|
encoding_rs = "0.8"
|
||||||
|
# MCP (Model Context Protocol) support
|
||||||
|
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
||||||
|
"client",
|
||||||
|
"transport-child-process",
|
||||||
|
"transport-streamable-http-client-reqwest",
|
||||||
|
"reqwest",
|
||||||
|
] }
|
||||||
|
schemars = "1.0"
|
||||||
|
|||||||
@ -76,6 +76,7 @@ impl InitWizard {
|
|||||||
skills: crate::config::SkillsConfig::default(),
|
skills: crate::config::SkillsConfig::default(),
|
||||||
tools: crate::config::ToolsConfig::default(),
|
tools: crate::config::ToolsConfig::default(),
|
||||||
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
||||||
|
mcp: crate::mcp::McpConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -826,6 +827,7 @@ impl InitWizard {
|
|||||||
skills: existing.skills.clone(),
|
skills: existing.skills.clone(),
|
||||||
tools: existing.tools.clone(),
|
tools: existing.tools.clone(),
|
||||||
memory_maintenance: existing.memory_maintenance.clone(),
|
memory_maintenance: existing.memory_maintenance.clone(),
|
||||||
|
mcp: existing.mcp.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,8 @@ pub struct Config {
|
|||||||
pub tools: ToolsConfig,
|
pub tools: ToolsConfig,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub memory_maintenance: MemoryMaintenanceConfig,
|
pub memory_maintenance: MemoryMaintenanceConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub mcp: crate::mcp::McpConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
|||||||
@ -84,6 +84,7 @@ impl GatewayState {
|
|||||||
config.tools.task.clone(),
|
config.tools.task.clone(),
|
||||||
config.memory_maintenance.clone(),
|
config.memory_maintenance.clone(),
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
|
config.mcp.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
|||||||
@ -4,6 +4,7 @@ use std::sync::Arc;
|
|||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
|
||||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||||
|
use crate::mcp::{McpClientManager, McpConfig};
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{
|
use crate::storage::{
|
||||||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||||||
@ -36,6 +37,7 @@ pub(crate) fn build_session_manager(
|
|||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
|
mcp_config: McpConfig,
|
||||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||||
build_session_manager_with_sender(
|
build_session_manager_with_sender(
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
@ -49,6 +51,7 @@ pub(crate) fn build_session_manager(
|
|||||||
task_config,
|
task_config,
|
||||||
maintenance_config,
|
maintenance_config,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
|
mcp_config,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,6 +67,7 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
maintenance_config: MemoryMaintenanceConfig,
|
maintenance_config: MemoryMaintenanceConfig,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
|
mcp_config: McpConfig,
|
||||||
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
|
||||||
let store = Arc::new(
|
let store = Arc::new(
|
||||||
SessionStore::new()
|
SessionStore::new()
|
||||||
@ -100,6 +104,36 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
task_config.clone(),
|
task_config.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// 创建 MCP Client Manager(如果启用)
|
||||||
|
let mcp_manager = if mcp_config.has_enabled_servers() {
|
||||||
|
let manager = Arc::new(McpClientManager::new());
|
||||||
|
|
||||||
|
// 在 tokio runtime 中连接 MCP servers
|
||||||
|
// 使用 block_in_place 允许在同步上下文中执行异步代码
|
||||||
|
let servers = mcp_config.enabled_servers();
|
||||||
|
let servers_clone: Vec<_> = servers.into_iter().cloned().collect();
|
||||||
|
|
||||||
|
tokio::task::block_in_place(|| {
|
||||||
|
tokio::runtime::Handle::current().block_on(async {
|
||||||
|
tracing::info!("Connecting to MCP servers...");
|
||||||
|
if let Err(e) = manager.connect_all(&servers_clone).await {
|
||||||
|
tracing::error!(error = %e, "Failed to connect to some MCP servers");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
Some(manager)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// 将 MCP manager 添加到 factory
|
||||||
|
let factory = if let Some(ref manager) = mcp_manager {
|
||||||
|
factory.with_mcp_manager(manager.clone())
|
||||||
|
} else {
|
||||||
|
factory
|
||||||
|
};
|
||||||
|
|
||||||
// 创建 SubAgentRuntime(如果 task 工具启用)
|
// 创建 SubAgentRuntime(如果 task 工具启用)
|
||||||
let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
|
let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
|
||||||
let task_repository = Arc::new(InMemoryTaskRepository::new());
|
let task_repository = Arc::new(InMemoryTaskRepository::new());
|
||||||
@ -128,7 +162,20 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
(factory, Arc::new(InMemoryTaskRepository::new()))
|
(factory, Arc::new(InMemoryTaskRepository::new()))
|
||||||
};
|
};
|
||||||
|
|
||||||
let tools = Arc::new(factory.build());
|
let mut tools = factory.build();
|
||||||
|
|
||||||
|
// 注册 MCP tools(如果有 MCP manager)
|
||||||
|
if let Some(manager) = &mcp_manager {
|
||||||
|
tokio::task::block_in_place(|| {
|
||||||
|
tokio::runtime::Handle::current().block_on(async {
|
||||||
|
if let Err(e) = crate::mcp::register_mcp_tools(manager.clone(), &mut tools).await {
|
||||||
|
tracing::error!(error = %e, "Failed to register MCP tools");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let tools = Arc::new(tools);
|
||||||
|
|
||||||
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
||||||
let agent_factory = AgentFactory::new(
|
let agent_factory = AgentFactory::new(
|
||||||
|
|||||||
@ -495,6 +495,7 @@ impl SessionManager {
|
|||||||
task_config: crate::config::TaskConfig,
|
task_config: crate::config::TaskConfig,
|
||||||
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
maintenance_config: crate::config::MemoryMaintenanceConfig,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
|
mcp_config: crate::mcp::McpConfig,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
super::runtime::build_session_manager(
|
super::runtime::build_session_manager(
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
@ -507,6 +508,7 @@ impl SessionManager {
|
|||||||
task_config,
|
task_config,
|
||||||
maintenance_config,
|
maintenance_config,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
|
mcp_config,
|
||||||
)
|
)
|
||||||
.map(|(session_manager, _)| session_manager)
|
.map(|(session_manager, _)| session_manager)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,7 @@ use std::collections::HashSet;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::config::TaskConfig;
|
use crate::config::TaskConfig;
|
||||||
|
use crate::mcp::McpClientManager;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
@ -23,6 +24,7 @@ pub(crate) struct ToolRegistryFactory {
|
|||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
task_config: TaskConfig,
|
||||||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||||||
|
mcp_manager: Option<Arc<McpClientManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolRegistryFactory {
|
impl ToolRegistryFactory {
|
||||||
@ -48,6 +50,7 @@ impl ToolRegistryFactory {
|
|||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
task_config,
|
||||||
subagent_runtime: None,
|
subagent_runtime: None,
|
||||||
|
mcp_manager: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,6 +62,14 @@ impl ToolRegistryFactory {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn with_mcp_manager(
|
||||||
|
mut self,
|
||||||
|
manager: Arc<McpClientManager>,
|
||||||
|
) -> Self {
|
||||||
|
self.mcp_manager = Some(manager);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
fn is_enabled(&self, tool_name: &str) -> bool {
|
fn is_enabled(&self, tool_name: &str) -> bool {
|
||||||
!self.disabled_tools.contains(tool_name)
|
!self.disabled_tools.contains(tool_name)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,6 +9,7 @@ pub mod config;
|
|||||||
pub mod domain;
|
pub mod domain;
|
||||||
pub mod gateway;
|
pub mod gateway;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
|
pub mod mcp;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
pub mod platform;
|
pub mod platform;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
|
|||||||
229
src/mcp/client.rs
Normal file
229
src/mcp/client.rs
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
//! MCP Client Manager - manages connections to MCP servers
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use rmcp::{
|
||||||
|
model::{CallToolRequestParams, CallToolResult, ServerInfo, Tool},
|
||||||
|
RoleClient, ServiceExt,
|
||||||
|
service::RunningService,
|
||||||
|
transport::TokioChildProcess,
|
||||||
|
};
|
||||||
|
use tokio::process::Command;
|
||||||
|
|
||||||
|
use crate::mcp::config::{McpServerConfig, McpTransportConfig};
|
||||||
|
|
||||||
|
/// Type alias for the MCP client service
|
||||||
|
pub type McpClient = RunningService<RoleClient, ()>;
|
||||||
|
|
||||||
|
/// Information about a connected MCP server
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct McpServerInfo {
|
||||||
|
/// Server name
|
||||||
|
pub name: String,
|
||||||
|
/// Server information from MCP protocol
|
||||||
|
pub info: Option<ServerInfo>,
|
||||||
|
/// Available tools
|
||||||
|
pub tools: Vec<Tool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manager for MCP client connections
|
||||||
|
pub struct McpClientManager {
|
||||||
|
/// Connected clients keyed by server name
|
||||||
|
clients: RwLock<HashMap<String, Arc<McpClient>>>,
|
||||||
|
/// Server information cache
|
||||||
|
server_info: RwLock<HashMap<String, McpServerInfo>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClientManager {
|
||||||
|
/// Create a new manager
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
clients: RwLock::new(HashMap::new()),
|
||||||
|
server_info: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect to all configured servers
|
||||||
|
pub async fn connect_all(&self, servers: &[McpServerConfig]) -> anyhow::Result<()> {
|
||||||
|
for server in servers {
|
||||||
|
if !server.enabled {
|
||||||
|
tracing::info!(name = %server.name, "Skipping disabled MCP server");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.connect_server(server).await {
|
||||||
|
Ok(info) => {
|
||||||
|
tracing::info!(
|
||||||
|
name = %server.name,
|
||||||
|
tools_count = info.tools.len(),
|
||||||
|
"Connected to MCP server"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
name = %server.name,
|
||||||
|
error = %e,
|
||||||
|
"Failed to connect to MCP server"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect to a single MCP server
|
||||||
|
pub async fn connect_server(&self, config: &McpServerConfig) -> anyhow::Result<McpServerInfo> {
|
||||||
|
tracing::info!(name = %config.name, "Connecting to MCP server");
|
||||||
|
|
||||||
|
let client = match &config.transport {
|
||||||
|
McpTransportConfig::Stdio { command, args, env } => {
|
||||||
|
self.connect_stdio(command, args, env).await?
|
||||||
|
}
|
||||||
|
McpTransportConfig::Http { url, headers: _ } => {
|
||||||
|
// HTTP transport requires additional setup
|
||||||
|
// For now, we'll return an error for HTTP transport
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"HTTP transport for MCP server '{}' is not yet implemented. URL: {}",
|
||||||
|
config.name,
|
||||||
|
url
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get server info (returns Option<&ServerInfo>)
|
||||||
|
let info = client.peer_info().cloned();
|
||||||
|
|
||||||
|
// List available tools
|
||||||
|
let tools = client.list_all_tools().await?;
|
||||||
|
|
||||||
|
let server_info = McpServerInfo {
|
||||||
|
name: config.name.clone(),
|
||||||
|
info,
|
||||||
|
tools,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Store the client and info
|
||||||
|
{
|
||||||
|
let mut clients = self.clients.write().await;
|
||||||
|
clients.insert(config.name.clone(), Arc::new(client));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut info_map = self.server_info.write().await;
|
||||||
|
info_map.insert(config.name.clone(), server_info.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(server_info)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect via stdio transport
|
||||||
|
async fn connect_stdio(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
args: &[String],
|
||||||
|
env: &HashMap<String, String>,
|
||||||
|
) -> anyhow::Result<McpClient> {
|
||||||
|
let mut cmd = Command::new(command);
|
||||||
|
cmd.args(args);
|
||||||
|
|
||||||
|
// Set environment variables
|
||||||
|
for (key, value) in env {
|
||||||
|
cmd.env(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let transport = TokioChildProcess::new(cmd)?;
|
||||||
|
|
||||||
|
// Use default client handler (empty tuple)
|
||||||
|
let client = ().serve(transport).await?;
|
||||||
|
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a client by server name
|
||||||
|
pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
|
||||||
|
let clients = self.clients.read().await;
|
||||||
|
clients.get(name).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get server info by name
|
||||||
|
pub async fn get_server_info(&self, name: &str) -> Option<McpServerInfo> {
|
||||||
|
let info_map = self.server_info.read().await;
|
||||||
|
info_map.get(name).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all connected server names
|
||||||
|
pub async fn connected_servers(&self) -> Vec<String> {
|
||||||
|
let clients = self.clients.read().await;
|
||||||
|
clients.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all tools from all connected servers
|
||||||
|
pub async fn all_tools(&self) -> Vec<(String, Tool)> {
|
||||||
|
let info_map = self.server_info.read().await;
|
||||||
|
info_map
|
||||||
|
.values()
|
||||||
|
.flat_map(|info| {
|
||||||
|
info.tools.iter().map(|tool| (info.name.clone(), tool.clone()))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call a tool on a specific server
|
||||||
|
pub async fn call_tool(
|
||||||
|
&self,
|
||||||
|
server_name: impl Into<String>,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
args: serde_json::Value,
|
||||||
|
) -> anyhow::Result<CallToolResult> {
|
||||||
|
let server_name = server_name.into();
|
||||||
|
let tool_name = tool_name.into();
|
||||||
|
|
||||||
|
let client = self
|
||||||
|
.get_client(&server_name)
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' not connected", server_name))?;
|
||||||
|
|
||||||
|
// Convert Value to JsonObject if it's an object
|
||||||
|
let arguments = if args.is_object() {
|
||||||
|
args.as_object().unwrap().clone()
|
||||||
|
} else {
|
||||||
|
// If not an object, wrap it or use empty object
|
||||||
|
serde_json::Map::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create params with owned String (converted to Cow<'static, str>)
|
||||||
|
let params = CallToolRequestParams::new(tool_name).with_arguments(arguments);
|
||||||
|
|
||||||
|
let result = client.call_tool(params).await?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disconnect from a server
|
||||||
|
pub async fn disconnect(&self, name: impl Into<String>) -> anyhow::Result<()> {
|
||||||
|
let name = name.into();
|
||||||
|
let mut clients = self.clients.write().await;
|
||||||
|
if clients.remove(&name).is_some() {
|
||||||
|
tracing::info!(name = %name, "Disconnected MCP server");
|
||||||
|
}
|
||||||
|
self.server_info.write().await.remove(&name);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disconnect from all servers
|
||||||
|
pub async fn disconnect_all(&self) -> anyhow::Result<()> {
|
||||||
|
let mut clients = self.clients.write().await;
|
||||||
|
for (name, _client) in clients.drain() {
|
||||||
|
tracing::info!(name = %name, "Disconnected MCP server");
|
||||||
|
}
|
||||||
|
self.server_info.write().await.clear();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for McpClientManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
173
src/mcp/config.rs
Normal file
173
src/mcp/config.rs
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
//! MCP Server configuration structures
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// MCP integration configuration
|
||||||
|
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||||
|
pub struct McpConfig {
|
||||||
|
/// Whether MCP integration is enabled
|
||||||
|
#[serde(default)]
|
||||||
|
pub enabled: bool,
|
||||||
|
|
||||||
|
/// List of MCP servers to connect
|
||||||
|
#[serde(default)]
|
||||||
|
pub servers: Vec<McpServerConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for a single MCP server
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct McpServerConfig {
|
||||||
|
/// Unique name for this server (used in tool naming)
|
||||||
|
pub name: String,
|
||||||
|
|
||||||
|
/// Transport configuration
|
||||||
|
pub transport: McpTransportConfig,
|
||||||
|
|
||||||
|
/// Whether this server is enabled
|
||||||
|
#[serde(default = "default_server_enabled")]
|
||||||
|
pub enabled: bool,
|
||||||
|
|
||||||
|
/// Optional description for the server
|
||||||
|
#[serde(default)]
|
||||||
|
pub description: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_server_enabled() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transport configuration for connecting to MCP servers
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum McpTransportConfig {
|
||||||
|
/// Stdio transport: spawn a child process
|
||||||
|
Stdio {
|
||||||
|
/// Command to execute (e.g., "npx", "cargo")
|
||||||
|
command: String,
|
||||||
|
/// Arguments to pass to the command
|
||||||
|
#[serde(default)]
|
||||||
|
args: Vec<String>,
|
||||||
|
/// Optional environment variables to set
|
||||||
|
#[serde(default)]
|
||||||
|
env: HashMap<String, String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// HTTP transport: connect to a remote server
|
||||||
|
Http {
|
||||||
|
/// URL of the MCP server endpoint
|
||||||
|
url: String,
|
||||||
|
/// Optional headers to include in requests
|
||||||
|
#[serde(default)]
|
||||||
|
headers: HashMap<String, String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerConfig {
|
||||||
|
/// Create a stdio server config
|
||||||
|
pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
name: name.into(),
|
||||||
|
transport: McpTransportConfig::Stdio {
|
||||||
|
command: command.into(),
|
||||||
|
args,
|
||||||
|
env: HashMap::new(),
|
||||||
|
},
|
||||||
|
enabled: true,
|
||||||
|
description: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an HTTP server config
|
||||||
|
pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
name: name.into(),
|
||||||
|
transport: McpTransportConfig::Http {
|
||||||
|
url: url.into(),
|
||||||
|
headers: HashMap::new(),
|
||||||
|
},
|
||||||
|
enabled: true,
|
||||||
|
description: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpConfig {
|
||||||
|
/// Get enabled servers
|
||||||
|
pub fn enabled_servers(&self) -> Vec<&McpServerConfig> {
|
||||||
|
self.servers.iter().filter(|s| s.enabled).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if there are any enabled servers
|
||||||
|
pub fn has_enabled_servers(&self) -> bool {
|
||||||
|
self.enabled && self.servers.iter().any(|s| s.enabled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stdio_config_creation() {
|
||||||
|
let config = McpServerConfig::stdio(
|
||||||
|
"filesystem",
|
||||||
|
"npx",
|
||||||
|
vec!["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(config.name, "filesystem");
|
||||||
|
assert!(config.enabled);
|
||||||
|
assert!(matches!(config.transport, McpTransportConfig::Stdio { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_http_config_creation() {
|
||||||
|
let config = McpServerConfig::http("custom", "http://localhost:8000/mcp");
|
||||||
|
|
||||||
|
assert_eq!(config.name, "custom");
|
||||||
|
assert!(config.enabled);
|
||||||
|
assert!(matches!(config.transport, McpTransportConfig::Http { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_deserialization() {
|
||||||
|
let json = r#"{
|
||||||
|
"enabled": true,
|
||||||
|
"servers": [
|
||||||
|
{
|
||||||
|
"name": "filesystem",
|
||||||
|
"transport": {
|
||||||
|
"type": "stdio",
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "http-server",
|
||||||
|
"enabled": false,
|
||||||
|
"transport": {
|
||||||
|
"type": "http",
|
||||||
|
"url": "http://localhost:8000/mcp",
|
||||||
|
"headers": {
|
||||||
|
"Authorization": "Bearer token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let config: McpConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(config.enabled);
|
||||||
|
assert_eq!(config.servers.len(), 2);
|
||||||
|
assert_eq!(config.enabled_servers().len(), 1);
|
||||||
|
|
||||||
|
let fs_server = &config.servers[0];
|
||||||
|
assert_eq!(fs_server.name, "filesystem");
|
||||||
|
assert!(fs_server.enabled);
|
||||||
|
|
||||||
|
let http_server = &config.servers[1];
|
||||||
|
assert_eq!(http_server.name, "http-server");
|
||||||
|
assert!(!http_server.enabled);
|
||||||
|
}
|
||||||
|
}
|
||||||
12
src/mcp/mod.rs
Normal file
12
src/mcp/mod.rs
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
//! MCP (Model Context Protocol) integration module
|
||||||
|
//!
|
||||||
|
//! This module provides MCP client functionality to connect to external MCP servers
|
||||||
|
//! and expose their tools through PicoBot's Tool system.
|
||||||
|
|
||||||
|
pub mod config;
|
||||||
|
pub mod client;
|
||||||
|
pub mod tool_adapter;
|
||||||
|
|
||||||
|
pub use config::{McpConfig, McpServerConfig, McpTransportConfig};
|
||||||
|
pub use client::{McpClientManager, McpClient, McpServerInfo};
|
||||||
|
pub use tool_adapter::{McpToolWrapper, register_mcp_tools};
|
||||||
186
src/mcp/tool_adapter.rs
Normal file
186
src/mcp/tool_adapter.rs
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
//! MCP Tool Adapter - wraps MCP tools as PicoBot tools
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use rmcp::model::Tool;
|
||||||
|
|
||||||
|
use crate::mcp::client::McpClientManager;
|
||||||
|
use crate::tools::traits::{Tool as PicoBotTool, ToolResult};
|
||||||
|
|
||||||
|
/// Wrapper that adapts an MCP tool to PicoBot's Tool trait
|
||||||
|
pub struct McpToolWrapper {
|
||||||
|
/// The MCP client manager
|
||||||
|
manager: Arc<McpClientManager>,
|
||||||
|
/// The server name this tool belongs to
|
||||||
|
server_name: String,
|
||||||
|
/// The original tool name on the MCP server
|
||||||
|
tool_name: String,
|
||||||
|
/// The full tool name with namespace (mcp_{server}_{tool})
|
||||||
|
full_name: String,
|
||||||
|
/// Tool information from MCP server
|
||||||
|
tool_info: Tool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpToolWrapper {
|
||||||
|
/// Create a new tool wrapper
|
||||||
|
pub fn new(
|
||||||
|
manager: Arc<McpClientManager>,
|
||||||
|
server_name: String,
|
||||||
|
tool_info: Tool,
|
||||||
|
) -> Self {
|
||||||
|
let tool_name = tool_info.name.clone().into_owned();
|
||||||
|
let full_name = format!("mcp_{}_{}", server_name, tool_name);
|
||||||
|
Self {
|
||||||
|
manager,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
full_name,
|
||||||
|
tool_info,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the server name
|
||||||
|
pub fn server_name(&self) -> &str {
|
||||||
|
&self.server_name
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the original tool name
|
||||||
|
pub fn original_name(&self) -> &str {
|
||||||
|
&self.tool_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PicoBotTool for McpToolWrapper {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.full_name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
self.tool_info.description.as_deref().unwrap_or("MCP tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
// Convert Arc<JsonObject> to serde_json::Value
|
||||||
|
let schema = (*self.tool_info.input_schema).clone();
|
||||||
|
serde_json::Value::Object(schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
tracing::debug!(
|
||||||
|
server = %self.server_name,
|
||||||
|
tool = %self.tool_name,
|
||||||
|
"Calling MCP tool"
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.manager
|
||||||
|
.call_tool(&self.server_name, &self.tool_name, args)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Convert MCP CallToolResult to PicoBot ToolResult
|
||||||
|
let output = extract_text_content(&result);
|
||||||
|
let is_error = result.is_error.unwrap_or(false);
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: !is_error,
|
||||||
|
output,
|
||||||
|
error: if is_error {
|
||||||
|
Some("MCP tool returned error".to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_only(&self) -> bool {
|
||||||
|
// MCP tools may or may not be read-only; we default to false
|
||||||
|
// This could be enhanced if MCP servers provide this info via annotations
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract text content from MCP CallToolResult
|
||||||
|
fn extract_text_content(result: &rmcp::model::CallToolResult) -> String {
|
||||||
|
let mut text_parts = Vec::new();
|
||||||
|
|
||||||
|
for content in &result.content {
|
||||||
|
if let Some(text) = content.as_text() {
|
||||||
|
text_parts.push(text.text.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if text_parts.is_empty() {
|
||||||
|
// No text content found, try to serialize the whole result
|
||||||
|
serde_json::to_string_pretty(&result).unwrap_or_else(|_| "Empty result".to_string())
|
||||||
|
} else {
|
||||||
|
text_parts.join("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register all MCP tools from connected servers into a tool registry
|
||||||
|
pub async fn register_mcp_tools(
|
||||||
|
manager: Arc<McpClientManager>,
|
||||||
|
registry: &mut crate::tools::registry::ToolRegistry,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let all_tools = manager.all_tools().await;
|
||||||
|
|
||||||
|
for (server_name, tool_info) in all_tools {
|
||||||
|
let wrapper = McpToolWrapper::new(
|
||||||
|
manager.clone(),
|
||||||
|
server_name.clone(),
|
||||||
|
tool_info,
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
name = %wrapper.name(),
|
||||||
|
server = %server_name,
|
||||||
|
"Registering MCP tool"
|
||||||
|
);
|
||||||
|
|
||||||
|
registry.register(wrapper);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rmcp::model::{CallToolResult, Content};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_text_content_from_text() {
|
||||||
|
let result = CallToolResult::success(vec![
|
||||||
|
Content::text("Hello"),
|
||||||
|
Content::text("World"),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let text = extract_text_content(&result);
|
||||||
|
assert_eq!(text, "Hello\nWorld");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_text_content_empty() {
|
||||||
|
let result = CallToolResult::success(vec![]);
|
||||||
|
let text = extract_text_content(&result);
|
||||||
|
assert!(text.contains("Empty result"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mcp_tool_wrapper_name() {
|
||||||
|
let manager = Arc::new(McpClientManager::new());
|
||||||
|
let tool_info = Tool {
|
||||||
|
name: "echo".into(),
|
||||||
|
description: Some("Echo tool".into()),
|
||||||
|
input_schema: serde_json::json!({"type": "object"}).as_object().unwrap().clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let wrapper = McpToolWrapper::new(manager, "filesystem".to_string(), tool_info);
|
||||||
|
assert_eq!(wrapper.name(), "mcp_filesystem_echo");
|
||||||
|
assert_eq!(wrapper.original_name(), "echo");
|
||||||
|
assert_eq!(wrapper.server_name(), "filesystem");
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user