feat: 添加 MCP (Model Context Protocol) 支持,包含客户端管理器和工具适配器

This commit is contained in:
ooodc 2026-05-23 22:52:36 +08:00
parent f68e915b04
commit cbb384a4e6
12 changed files with 675 additions and 1 deletions

View File

@ -37,3 +37,11 @@ rusqlite = { version = "0.32", features = ["bundled"] }
rustls = { version = "0.23", features = ["ring"] }
wechatbot = { path = "vendor/wechatbot" }
encoding_rs = "0.8"
# MCP (Model Context Protocol) support
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
"client",
"transport-child-process",
"transport-streamable-http-client-reqwest",
"reqwest",
] }
schemars = "1.0"

View File

@ -76,6 +76,7 @@ impl InitWizard {
skills: crate::config::SkillsConfig::default(),
tools: crate::config::ToolsConfig::default(),
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
mcp: crate::mcp::McpConfig::default(),
}
}
@ -826,6 +827,7 @@ impl InitWizard {
skills: existing.skills.clone(),
tools: existing.tools.clone(),
memory_maintenance: existing.memory_maintenance.clone(),
mcp: existing.mcp.clone(),
}
}

View File

@ -29,6 +29,8 @@ pub struct Config {
pub tools: ToolsConfig,
#[serde(default)]
pub memory_maintenance: MemoryMaintenanceConfig,
#[serde(default)]
pub mcp: crate::mcp::McpConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]

View File

@ -84,6 +84,7 @@ impl GatewayState {
config.tools.task.clone(),
config.memory_maintenance.clone(),
session_ttl_hours,
config.mcp.clone(),
)?;
Ok(Self {

View File

@ -4,6 +4,7 @@ use std::sync::Arc;
use crate::agent::AgentError;
use crate::config::{LLMProviderConfig, MemoryMaintenanceConfig, TaskConfig};
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
use crate::mcp::{McpClientManager, McpConfig};
use crate::skills::SkillRuntime;
use crate::storage::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
@ -36,6 +37,7 @@ pub(crate) fn build_session_manager(
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
session_ttl_hours: Option<u64>,
mcp_config: McpConfig,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
build_session_manager_with_sender(
agent_prompt_reinject_every,
@ -49,6 +51,7 @@ pub(crate) fn build_session_manager(
task_config,
maintenance_config,
session_ttl_hours,
mcp_config,
)
}
@ -64,6 +67,7 @@ pub(crate) fn build_session_manager_with_sender(
task_config: TaskConfig,
maintenance_config: MemoryMaintenanceConfig,
session_ttl_hours: Option<u64>,
mcp_config: McpConfig,
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
let store = Arc::new(
SessionStore::new()
@ -100,6 +104,36 @@ pub(crate) fn build_session_manager_with_sender(
task_config.clone(),
);
// 创建 MCP Client Manager如果启用
let mcp_manager = if mcp_config.has_enabled_servers() {
let manager = Arc::new(McpClientManager::new());
// 在 tokio runtime 中连接 MCP servers
// 使用 block_in_place 允许在同步上下文中执行异步代码
let servers = mcp_config.enabled_servers();
let servers_clone: Vec<_> = servers.into_iter().cloned().collect();
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
tracing::info!("Connecting to MCP servers...");
if let Err(e) = manager.connect_all(&servers_clone).await {
tracing::error!(error = %e, "Failed to connect to some MCP servers");
}
})
});
Some(manager)
} else {
None
};
// 将 MCP manager 添加到 factory
let factory = if let Some(ref manager) = mcp_manager {
factory.with_mcp_manager(manager.clone())
} else {
factory
};
// 创建 SubAgentRuntime如果 task 工具启用)
let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
let task_repository = Arc::new(InMemoryTaskRepository::new());
@ -128,7 +162,20 @@ pub(crate) fn build_session_manager_with_sender(
(factory, Arc::new(InMemoryTaskRepository::new()))
};
let tools = Arc::new(factory.build());
let mut tools = factory.build();
// 注册 MCP tools如果有 MCP manager
if let Some(manager) = &mcp_manager {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
if let Err(e) = crate::mcp::register_mcp_tools(manager.clone(), &mut tools).await {
tracing::error!(error = %e, "Failed to register MCP tools");
}
})
});
}
let tools = Arc::new(tools);
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
let agent_factory = AgentFactory::new(

View File

@ -495,6 +495,7 @@ impl SessionManager {
task_config: crate::config::TaskConfig,
maintenance_config: crate::config::MemoryMaintenanceConfig,
session_ttl_hours: Option<u64>,
mcp_config: crate::mcp::McpConfig,
) -> Result<Self, AgentError> {
super::runtime::build_session_manager(
agent_prompt_reinject_every,
@ -507,6 +508,7 @@ impl SessionManager {
task_config,
maintenance_config,
session_ttl_hours,
mcp_config,
)
.map(|(session_manager, _)| session_manager)
}

View File

@ -2,6 +2,7 @@ use std::collections::HashSet;
use std::sync::Arc;
use crate::config::TaskConfig;
use crate::mcp::McpClientManager;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
use crate::tools::{
@ -23,6 +24,7 @@ pub(crate) struct ToolRegistryFactory {
disabled_tools: HashSet<String>,
task_config: TaskConfig,
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
mcp_manager: Option<Arc<McpClientManager>>,
}
impl ToolRegistryFactory {
@ -48,6 +50,7 @@ impl ToolRegistryFactory {
disabled_tools,
task_config,
subagent_runtime: None,
mcp_manager: None,
}
}
@ -59,6 +62,14 @@ impl ToolRegistryFactory {
self
}
pub(crate) fn with_mcp_manager(
mut self,
manager: Arc<McpClientManager>,
) -> Self {
self.mcp_manager = Some(manager);
self
}
fn is_enabled(&self, tool_name: &str) -> bool {
!self.disabled_tools.contains(tool_name)
}

View File

@ -9,6 +9,7 @@ pub mod config;
pub mod domain;
pub mod gateway;
pub mod logging;
pub mod mcp;
pub mod observability;
pub mod platform;
pub mod protocol;

229
src/mcp/client.rs Normal file
View File

@ -0,0 +1,229 @@
//! MCP Client Manager - manages connections to MCP servers
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use rmcp::{
model::{CallToolRequestParams, CallToolResult, ServerInfo, Tool},
RoleClient, ServiceExt,
service::RunningService,
transport::TokioChildProcess,
};
use tokio::process::Command;
use crate::mcp::config::{McpServerConfig, McpTransportConfig};
/// Type alias for the MCP client service
pub type McpClient = RunningService<RoleClient, ()>;
/// Information about a connected MCP server
#[derive(Debug, Clone)]
pub struct McpServerInfo {
/// Server name
pub name: String,
/// Server information from MCP protocol
pub info: Option<ServerInfo>,
/// Available tools
pub tools: Vec<Tool>,
}
/// Manager for MCP client connections
pub struct McpClientManager {
/// Connected clients keyed by server name
clients: RwLock<HashMap<String, Arc<McpClient>>>,
/// Server information cache
server_info: RwLock<HashMap<String, McpServerInfo>>,
}
impl McpClientManager {
/// Create a new manager
pub fn new() -> Self {
Self {
clients: RwLock::new(HashMap::new()),
server_info: RwLock::new(HashMap::new()),
}
}
/// Connect to all configured servers
pub async fn connect_all(&self, servers: &[McpServerConfig]) -> anyhow::Result<()> {
for server in servers {
if !server.enabled {
tracing::info!(name = %server.name, "Skipping disabled MCP server");
continue;
}
match self.connect_server(server).await {
Ok(info) => {
tracing::info!(
name = %server.name,
tools_count = info.tools.len(),
"Connected to MCP server"
);
}
Err(e) => {
tracing::error!(
name = %server.name,
error = %e,
"Failed to connect to MCP server"
);
}
}
}
Ok(())
}
/// Connect to a single MCP server
pub async fn connect_server(&self, config: &McpServerConfig) -> anyhow::Result<McpServerInfo> {
tracing::info!(name = %config.name, "Connecting to MCP server");
let client = match &config.transport {
McpTransportConfig::Stdio { command, args, env } => {
self.connect_stdio(command, args, env).await?
}
McpTransportConfig::Http { url, headers: _ } => {
// HTTP transport requires additional setup
// For now, we'll return an error for HTTP transport
return Err(anyhow::anyhow!(
"HTTP transport for MCP server '{}' is not yet implemented. URL: {}",
config.name,
url
));
}
};
// Get server info (returns Option<&ServerInfo>)
let info = client.peer_info().cloned();
// List available tools
let tools = client.list_all_tools().await?;
let server_info = McpServerInfo {
name: config.name.clone(),
info,
tools,
};
// Store the client and info
{
let mut clients = self.clients.write().await;
clients.insert(config.name.clone(), Arc::new(client));
}
{
let mut info_map = self.server_info.write().await;
info_map.insert(config.name.clone(), server_info.clone());
}
Ok(server_info)
}
/// Connect via stdio transport
async fn connect_stdio(
&self,
command: &str,
args: &[String],
env: &HashMap<String, String>,
) -> anyhow::Result<McpClient> {
let mut cmd = Command::new(command);
cmd.args(args);
// Set environment variables
for (key, value) in env {
cmd.env(key, value);
}
let transport = TokioChildProcess::new(cmd)?;
// Use default client handler (empty tuple)
let client = ().serve(transport).await?;
Ok(client)
}
/// Get a client by server name
pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
let clients = self.clients.read().await;
clients.get(name).cloned()
}
/// Get server info by name
pub async fn get_server_info(&self, name: &str) -> Option<McpServerInfo> {
let info_map = self.server_info.read().await;
info_map.get(name).cloned()
}
/// Get all connected server names
pub async fn connected_servers(&self) -> Vec<String> {
let clients = self.clients.read().await;
clients.keys().cloned().collect()
}
/// Get all tools from all connected servers
pub async fn all_tools(&self) -> Vec<(String, Tool)> {
let info_map = self.server_info.read().await;
info_map
.values()
.flat_map(|info| {
info.tools.iter().map(|tool| (info.name.clone(), tool.clone()))
})
.collect()
}
/// Call a tool on a specific server
pub async fn call_tool(
&self,
server_name: impl Into<String>,
tool_name: impl Into<String>,
args: serde_json::Value,
) -> anyhow::Result<CallToolResult> {
let server_name = server_name.into();
let tool_name = tool_name.into();
let client = self
.get_client(&server_name)
.await
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' not connected", server_name))?;
// Convert Value to JsonObject if it's an object
let arguments = if args.is_object() {
args.as_object().unwrap().clone()
} else {
// If not an object, wrap it or use empty object
serde_json::Map::new()
};
// Create params with owned String (converted to Cow<'static, str>)
let params = CallToolRequestParams::new(tool_name).with_arguments(arguments);
let result = client.call_tool(params).await?;
Ok(result)
}
/// Disconnect from a server
pub async fn disconnect(&self, name: impl Into<String>) -> anyhow::Result<()> {
let name = name.into();
let mut clients = self.clients.write().await;
if clients.remove(&name).is_some() {
tracing::info!(name = %name, "Disconnected MCP server");
}
self.server_info.write().await.remove(&name);
Ok(())
}
/// Disconnect from all servers
pub async fn disconnect_all(&self) -> anyhow::Result<()> {
let mut clients = self.clients.write().await;
for (name, _client) in clients.drain() {
tracing::info!(name = %name, "Disconnected MCP server");
}
self.server_info.write().await.clear();
Ok(())
}
}
impl Default for McpClientManager {
fn default() -> Self {
Self::new()
}
}

173
src/mcp/config.rs Normal file
View File

@ -0,0 +1,173 @@
//! MCP Server configuration structures
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// MCP integration configuration
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct McpConfig {
/// Whether MCP integration is enabled
#[serde(default)]
pub enabled: bool,
/// List of MCP servers to connect
#[serde(default)]
pub servers: Vec<McpServerConfig>,
}
/// Configuration for a single MCP server
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
/// Unique name for this server (used in tool naming)
pub name: String,
/// Transport configuration
pub transport: McpTransportConfig,
/// Whether this server is enabled
#[serde(default = "default_server_enabled")]
pub enabled: bool,
/// Optional description for the server
#[serde(default)]
pub description: Option<String>,
}
fn default_server_enabled() -> bool {
true
}
/// Transport configuration for connecting to MCP servers
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpTransportConfig {
/// Stdio transport: spawn a child process
Stdio {
/// Command to execute (e.g., "npx", "cargo")
command: String,
/// Arguments to pass to the command
#[serde(default)]
args: Vec<String>,
/// Optional environment variables to set
#[serde(default)]
env: HashMap<String, String>,
},
/// HTTP transport: connect to a remote server
Http {
/// URL of the MCP server endpoint
url: String,
/// Optional headers to include in requests
#[serde(default)]
headers: HashMap<String, String>,
},
}
impl McpServerConfig {
/// Create a stdio server config
pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
Self {
name: name.into(),
transport: McpTransportConfig::Stdio {
command: command.into(),
args,
env: HashMap::new(),
},
enabled: true,
description: None,
}
}
/// Create an HTTP server config
pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
Self {
name: name.into(),
transport: McpTransportConfig::Http {
url: url.into(),
headers: HashMap::new(),
},
enabled: true,
description: None,
}
}
}
impl McpConfig {
/// Get enabled servers
pub fn enabled_servers(&self) -> Vec<&McpServerConfig> {
self.servers.iter().filter(|s| s.enabled).collect()
}
/// Check if there are any enabled servers
pub fn has_enabled_servers(&self) -> bool {
self.enabled && self.servers.iter().any(|s| s.enabled)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_stdio_config_creation() {
let config = McpServerConfig::stdio(
"filesystem",
"npx",
vec!["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
);
assert_eq!(config.name, "filesystem");
assert!(config.enabled);
assert!(matches!(config.transport, McpTransportConfig::Stdio { .. }));
}
#[test]
fn test_http_config_creation() {
let config = McpServerConfig::http("custom", "http://localhost:8000/mcp");
assert_eq!(config.name, "custom");
assert!(config.enabled);
assert!(matches!(config.transport, McpTransportConfig::Http { .. }));
}
#[test]
fn test_config_deserialization() {
let json = r#"{
"enabled": true,
"servers": [
{
"name": "filesystem",
"transport": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
}
},
{
"name": "http-server",
"enabled": false,
"transport": {
"type": "http",
"url": "http://localhost:8000/mcp",
"headers": {
"Authorization": "Bearer token"
}
}
}
]
}"#;
let config: McpConfig = serde_json::from_str(json).unwrap();
assert!(config.enabled);
assert_eq!(config.servers.len(), 2);
assert_eq!(config.enabled_servers().len(), 1);
let fs_server = &config.servers[0];
assert_eq!(fs_server.name, "filesystem");
assert!(fs_server.enabled);
let http_server = &config.servers[1];
assert_eq!(http_server.name, "http-server");
assert!(!http_server.enabled);
}
}

12
src/mcp/mod.rs Normal file
View File

@ -0,0 +1,12 @@
//! MCP (Model Context Protocol) integration module
//!
//! This module provides MCP client functionality to connect to external MCP servers
//! and expose their tools through PicoBot's Tool system.
pub mod config;
pub mod client;
pub mod tool_adapter;
pub use config::{McpConfig, McpServerConfig, McpTransportConfig};
pub use client::{McpClientManager, McpClient, McpServerInfo};
pub use tool_adapter::{McpToolWrapper, register_mcp_tools};

186
src/mcp/tool_adapter.rs Normal file
View File

@ -0,0 +1,186 @@
//! MCP Tool Adapter - wraps MCP tools as PicoBot tools
use async_trait::async_trait;
use std::sync::Arc;
use rmcp::model::Tool;
use crate::mcp::client::McpClientManager;
use crate::tools::traits::{Tool as PicoBotTool, ToolResult};
/// Wrapper that adapts an MCP tool to PicoBot's Tool trait
pub struct McpToolWrapper {
/// The MCP client manager
manager: Arc<McpClientManager>,
/// The server name this tool belongs to
server_name: String,
/// The original tool name on the MCP server
tool_name: String,
/// The full tool name with namespace (mcp_{server}_{tool})
full_name: String,
/// Tool information from MCP server
tool_info: Tool,
}
impl McpToolWrapper {
/// Create a new tool wrapper
pub fn new(
manager: Arc<McpClientManager>,
server_name: String,
tool_info: Tool,
) -> Self {
let tool_name = tool_info.name.clone().into_owned();
let full_name = format!("mcp_{}_{}", server_name, tool_name);
Self {
manager,
server_name,
tool_name,
full_name,
tool_info,
}
}
/// Get the server name
pub fn server_name(&self) -> &str {
&self.server_name
}
/// Get the original tool name
pub fn original_name(&self) -> &str {
&self.tool_name
}
}
#[async_trait]
impl PicoBotTool for McpToolWrapper {
fn name(&self) -> &str {
&self.full_name
}
fn description(&self) -> &str {
self.tool_info.description.as_deref().unwrap_or("MCP tool")
}
fn parameters_schema(&self) -> serde_json::Value {
// Convert Arc<JsonObject> to serde_json::Value
let schema = (*self.tool_info.input_schema).clone();
serde_json::Value::Object(schema)
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
tracing::debug!(
server = %self.server_name,
tool = %self.tool_name,
"Calling MCP tool"
);
let result = self
.manager
.call_tool(&self.server_name, &self.tool_name, args)
.await?;
// Convert MCP CallToolResult to PicoBot ToolResult
let output = extract_text_content(&result);
let is_error = result.is_error.unwrap_or(false);
Ok(ToolResult {
success: !is_error,
output,
error: if is_error {
Some("MCP tool returned error".to_string())
} else {
None
},
})
}
fn read_only(&self) -> bool {
// MCP tools may or may not be read-only; we default to false
// This could be enhanced if MCP servers provide this info via annotations
false
}
}
/// Extract text content from MCP CallToolResult
fn extract_text_content(result: &rmcp::model::CallToolResult) -> String {
let mut text_parts = Vec::new();
for content in &result.content {
if let Some(text) = content.as_text() {
text_parts.push(text.text.clone());
}
}
if text_parts.is_empty() {
// No text content found, try to serialize the whole result
serde_json::to_string_pretty(&result).unwrap_or_else(|_| "Empty result".to_string())
} else {
text_parts.join("\n")
}
}
/// Register all MCP tools from connected servers into a tool registry
pub async fn register_mcp_tools(
manager: Arc<McpClientManager>,
registry: &mut crate::tools::registry::ToolRegistry,
) -> anyhow::Result<()> {
let all_tools = manager.all_tools().await;
for (server_name, tool_info) in all_tools {
let wrapper = McpToolWrapper::new(
manager.clone(),
server_name.clone(),
tool_info,
);
tracing::info!(
name = %wrapper.name(),
server = %server_name,
"Registering MCP tool"
);
registry.register(wrapper);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use rmcp::model::{CallToolResult, Content};
#[test]
fn test_extract_text_content_from_text() {
let result = CallToolResult::success(vec![
Content::text("Hello"),
Content::text("World"),
]);
let text = extract_text_content(&result);
assert_eq!(text, "Hello\nWorld");
}
#[test]
fn test_extract_text_content_empty() {
let result = CallToolResult::success(vec![]);
let text = extract_text_content(&result);
assert!(text.contains("Empty result"));
}
#[test]
fn test_mcp_tool_wrapper_name() {
let manager = Arc::new(McpClientManager::new());
let tool_info = Tool {
name: "echo".into(),
description: Some("Echo tool".into()),
input_schema: serde_json::json!({"type": "object"}).as_object().unwrap().clone(),
..Default::default()
};
let wrapper = McpToolWrapper::new(manager, "filesystem".to_string(), tool_info);
assert_eq!(wrapper.name(), "mcp_filesystem_echo");
assert_eq!(wrapper.original_name(), "echo");
assert_eq!(wrapper.server_name(), "filesystem");
}
}