diff --git a/Cargo.toml b/Cargo.toml index a435293..dbfe05a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,10 @@ hostname = "0.3" sqlx = { version = "0.8", features = ["sqlite", "macros", "chrono", "runtime-tokio"] } jieba-rs = "0.9" which = "7" +rmcp = { version = "1.6", default-features = false, features = [ + "client", + "transport-child-process", + "transport-streamable-http-client-reqwest", + "which-command", +] } +http = "1" diff --git a/src/config/mod.rs b/src/config/mod.rs index a39e39f..35e34ca 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -53,6 +53,8 @@ pub struct Config { pub memory: MemoryConfig, #[serde(default = "default_workspace_dir")] pub workspace_dir: String, + #[serde(default)] + pub mcp: McpConfig, } fn default_workspace_dir() -> String { @@ -269,6 +271,59 @@ impl MemoryConfig { } } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpConfig { + #[serde(default)] + pub servers: Vec, + #[serde(default = "default_mcp_tool_timeout_secs")] + pub tool_timeout_secs: u64, +} + +impl Default for McpConfig { + fn default() -> Self { + Self { + servers: Vec::new(), + tool_timeout_secs: default_mcp_tool_timeout_secs(), + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpServerConfig { + pub name: String, + #[serde(default = "default_mcp_transport")] + pub transport: McpTransport, + #[serde(default)] + pub command: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub env: HashMap, + #[serde(default)] + pub url: Option, + #[serde(default)] + pub headers: HashMap, + #[serde(default)] + pub tool_timeout_secs: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum McpTransport { + Stdio, + Sse, + #[serde(alias = "streamable-http")] + StreamableHttp, +} + +fn default_mcp_transport() -> McpTransport { + McpTransport::Stdio +} + +fn default_mcp_tool_timeout_secs() -> u64 { + 180 +} + fn default_recall_limit() -> usize { 5 } fn default_idle_consolidation_minutes() -> u64 { 10 } fn default_timeline_retention_days() -> u64 { 90 } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index a3ff44e..cf9354c 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,6 +10,7 @@ use crate::channels::{ChannelManager, CliChatChannel}; use crate::channels::base::{Channel, ChannelError}; use crate::config::{Config, expand_path, ensure_workspace_dir}; use crate::logging; +use crate::mcp; use crate::memory::MemoryManager; use crate::session::SessionManager; use crate::scheduler::Scheduler; @@ -102,6 +103,21 @@ impl GatewayState { crate::tools::ChatManagerTool::new(storage.clone(), valid_channels.clone()), ); + // Initialize MCP servers — connect and register discovered tools + if !config.mcp.servers.is_empty() { + let mcp_tools = mcp::connect_all(&config.mcp).await; + for tool_info in mcp_tools { + let wrapper = mcp::McpToolWrapper::new( + &tool_info.server_name, + tool_info.tool_name, + tool_info.description, + tool_info.schema, + tool_info.connection, + ); + session_manager.tools().register(wrapper); + } + } + // Initialize scheduler if enabled in config let scheduler_config = config.gateway.scheduler.clone().unwrap_or_default(); if scheduler_config.enabled { diff --git a/src/lib.rs b/src/lib.rs index cebf6b2..6e651ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod client; pub mod protocol; pub mod channels; pub mod logging; +pub mod mcp; pub mod memory; pub mod observability; pub mod scheduler; diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs new file mode 100644 index 0000000..1d2df63 --- /dev/null +++ b/src/mcp/mod.rs @@ -0,0 +1,303 @@ +pub mod tool_wrapper; + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use anyhow::Context; +use http::{HeaderName, HeaderValue}; +use rmcp::model::{CallToolRequestParams, RawContent}; +use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +use rmcp::transport::{StreamableHttpClientTransport, TokioChildProcess}; +use rmcp::{Peer, RoleClient, ServiceExt}; +use tokio::process::Command; + +use crate::config::{McpConfig, McpServerConfig, McpTransport}; +use crate::tools::ToolResult; + +pub use tool_wrapper::McpToolWrapper; + +/// Status of a single MCP tool. +#[derive(Debug, Clone)] +pub struct McpToolStatus { + pub name: String, + pub description: String, +} + +/// Status of a single MCP server. +#[derive(Debug, Clone)] +pub struct McpServerStatus { + pub name: String, + pub transport: String, + pub connected: bool, + pub error: Option, + pub tools: Vec, +} + +static MCP_SERVER_STATUS: Mutex> = Mutex::new(Vec::new()); + +pub fn get_mcp_status() -> Vec { + MCP_SERVER_STATUS.lock().unwrap().clone() +} + +fn update_mcp_status(servers: Vec) { + let mut status = MCP_SERVER_STATUS.lock().unwrap(); + *status = servers; +} + +/// A connected MCP server. Holds a clonable Peer handle for tool calls, +/// and keeps the underlying service alive via a background task. +pub struct McpConnection { + #[allow(dead_code)] + pub name: String, + peer: Peer, + /// Keep the service alive. When dropped, the MCP connection is closed. + _service: Option>, +} + +impl McpConnection { + pub async fn call_tool( + &self, + tool_name: &str, + arguments: serde_json::Value, + ) -> anyhow::Result { + let result = self + .peer + .call_tool( + CallToolRequestParams::new(tool_name.to_string()) + .with_arguments(arguments.as_object().cloned().unwrap_or_default()), + ) + .await + .context("MCP tool call failed")?; + + let is_error = result.is_error.unwrap_or(false); + let output = extract_text(&result); + + Ok(ToolResult { + success: !is_error, + output, + error: if is_error { + Some("MCP server returned an error".to_string()) + } else { + None + }, + }) + } +} + +fn extract_text(result: &rmcp::model::CallToolResult) -> String { + let mut parts = Vec::new(); + for content in &result.content { + match &**content { + RawContent::Text(text) => { + parts.push(text.text.clone()); + } + RawContent::Image(image) => { + parts.push(format!( + "[image: {}]", + image.mime_type, + )); + } + RawContent::Resource(resource) => { + match &resource.resource { + rmcp::model::ResourceContents::TextResourceContents { text, .. } => { + parts.push(format!( + "[resource text: {}]", + text.chars().take(200).collect::(), + )); + } + rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => { + parts.push(format!("[resource blob: {}]", uri)); + } + } + } + _ => { + parts.push("[unsupported content]".to_string()); + } + } + } + if parts.is_empty() { + String::new() + } else { + parts.join("\n") + } +} + +pub struct ToolInfo { + pub server_name: String, + pub tool_name: String, + pub description: String, + pub schema: serde_json::Value, + pub connection: Arc, +} + +pub async fn connect_all(config: &McpConfig) -> Vec { + let mut tools = Vec::new(); + let mut server_statuses = Vec::new(); + + for server_config in &config.servers { + let transport_str = match server_config.transport { + McpTransport::Stdio => "stdio", + McpTransport::Sse => "sse", + McpTransport::StreamableHttp => "streamable-http", + }; + + match connect_server(server_config).await { + Ok(connection) => { + let connection = Arc::new(connection); + match list_tools(&connection).await { + Ok(server_tools) => { + tracing::info!( + server = %server_config.name, + count = server_tools.len(), + "MCP server connected" + ); + let tool_statuses: Vec = server_tools + .iter() + .map(|(name, desc, _)| McpToolStatus { + name: name.clone(), + description: desc.clone(), + }) + .collect(); + server_statuses.push(McpServerStatus { + name: server_config.name.clone(), + transport: transport_str.to_string(), + connected: true, + error: None, + tools: tool_statuses, + }); + for (orig_name, desc, schema) in server_tools { + tools.push(ToolInfo { + server_name: server_config.name.clone(), + tool_name: orig_name, + description: desc, + schema, + connection: connection.clone(), + }); + } + } + Err(e) => { + tracing::error!( + server = %server_config.name, + error = %e, + "Failed to list MCP tools" + ); + server_statuses.push(McpServerStatus { + name: server_config.name.clone(), + transport: transport_str.to_string(), + connected: false, + error: Some(e.to_string()), + tools: Vec::new(), + }); + } + } + } + Err(e) => { + tracing::error!( + server = %server_config.name, + error = %e, + "Failed to connect to MCP server" + ); + server_statuses.push(McpServerStatus { + name: server_config.name.clone(), + transport: transport_str.to_string(), + connected: false, + error: Some(e.to_string()), + tools: Vec::new(), + }); + } + } + } + + update_mcp_status(server_statuses); + tools +} + +async fn connect_server(config: &McpServerConfig) -> anyhow::Result { + match config.transport { + McpTransport::Stdio => { + let command = config + .command + .as_ref() + .context("stdio transport requires 'command'")?; + let mut cmd = Command::new(command); + cmd.args(&config.args); + for (k, v) in &config.env { + cmd.env(k, v); + } + + let service = () + .serve( + TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?, + ) + .await + .context("failed to connect to stdio MCP server")?; + + let peer = service.peer().clone(); + + Ok(McpConnection { + name: config.name.clone(), + peer, + _service: Some(Box::new(service)), + }) + } + McpTransport::Sse | McpTransport::StreamableHttp => { + let url = config + .url + .as_ref() + .context("sse/streamable-http transport requires 'url'")?; + + let mut headers_map = HashMap::new(); + for (k, v) in &config.headers { + if let (Ok(name), Ok(value)) = ( + HeaderName::from_bytes(k.as_bytes()), + HeaderValue::from_str(v), + ) { + headers_map.insert(name, value); + } + } + + let transport = if headers_map.is_empty() { + StreamableHttpClientTransport::from_uri(url.to_string()) + } else { + StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(url.to_string()) + .custom_headers(headers_map) + ) + }; + + let service = () + .serve(transport) + .await + .context("failed to connect to HTTP/SSE MCP server")?; + + let peer = service.peer().clone(); + + Ok(McpConnection { + name: config.name.clone(), + peer, + _service: Some(Box::new(service)), + }) + } + } +} + +async fn list_tools( + connection: &McpConnection, +) -> anyhow::Result> { + let tools = connection + .peer + .list_all_tools() + .await + .context("failed to list MCP tools")?; + + Ok(tools + .into_iter() + .map(|tool| { + ( + tool.name.to_string(), + tool.description.map(|d| d.to_string()).unwrap_or_default(), + serde_json::Value::Object((*tool.input_schema).clone()), + ) + }) + .collect()) +} diff --git a/src/mcp/tool_wrapper.rs b/src/mcp/tool_wrapper.rs new file mode 100644 index 0000000..5b59127 --- /dev/null +++ b/src/mcp/tool_wrapper.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::tools::{Tool, ToolResult}; + +use super::McpConnection; + +pub struct McpToolWrapper { + full_name: String, + description: String, + parameters_schema: serde_json::Value, + original_tool_name: String, + connection: Arc, +} + +impl McpToolWrapper { + pub fn new( + server_name: &str, + original_tool_name: String, + description: String, + parameters_schema: serde_json::Value, + connection: Arc, + ) -> Self { + Self { + full_name: format!("{}__{}", server_name, original_tool_name), + description, + parameters_schema, + original_tool_name, + connection, + } + } +} + +#[async_trait] +impl Tool for McpToolWrapper { + fn name(&self) -> &str { + &self.full_name + } + + fn description(&self) -> &str { + &self.description + } + + fn parameters_schema(&self) -> serde_json::Value { + self.parameters_schema.clone() + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + self.connection + .call_tool(&self.original_tool_name, args) + .await + } +} diff --git a/src/session/session.rs b/src/session/session.rs index 9cd174c..0920060 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use tokio::sync::Mutex; use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; +use crate::mcp::get_mcp_status; use crate::storage::{Storage, StorageError}; use std::sync::Arc as StdArc; @@ -799,6 +800,11 @@ pub static SLASH_COMMANDS: &[SlashCommand] = &[ description: "显示帮助", aliases: &["/?", "/help"], }, + SlashCommand { + name: "mcp", + description: "显示 MCP 服务状态和工具列表", + aliases: &["/mcp"], + }, ]; impl SessionManager { @@ -987,6 +993,34 @@ impl SessionManager { }).collect(); Ok((None, format!("可用命令:\n{}", lines.join("\n")))) } + "mcp" => { + let servers = get_mcp_status(); + if servers.is_empty() { + return Ok((None, "未配置 MCP 服务。".to_string())); + } + let lines: Vec = servers.iter().map(|s| { + let status = if s.connected { + format!("✅ 已连接 ({})", s.transport) + } else { + format!("❌ 连接失败: {}", s.error.as_deref().unwrap_or("未知错误")) + }; + let tool_lines: Vec = s.tools.iter().map(|t| { + let desc = if t.description.is_empty() { + "无描述".to_string() + } else { + t.description.chars().take(60).collect::() + }; + format!(" - {}: {}", t.name, desc) + }).collect(); + let tools_section = if tool_lines.is_empty() { + String::new() + } else { + format!("\n{}", tool_lines.join("\n")) + }; + format!("{} {}{}", s.name, status, tools_section) + }).collect(); + Ok((None, format!("MCP 服务:\n\n{}", lines.join("\n\n")))) + } _ => Err(AgentError::Other(format!("未知命令:/{}。输入 /? 获取帮助。", cmd.name))), } }