Compare commits
No commits in common. "4605c2dad38266872c81edf1e3f47c83a2839a80" and "ef7e899584e0af0109524509439443aa519ed9f3" have entirely different histories.
4605c2dad3
...
ef7e899584
75
README.md
75
README.md
@ -534,73 +534,58 @@ PicoBot 的 Agent 是围绕工具调用构建的。当前默认注册的工具
|
||||
|
||||
### 8.1 MCP 工具集成
|
||||
|
||||
PicoBot 支持通过 MCP (Model Context Protocol) 扩展工具能力,可以连接外部 MCP servers 并自动发现其提供的工具。配置格式兼容 Claude Desktop / Cursor。
|
||||
PicoBot 支持通过 MCP (Model Context Protocol) 扩展工具能力,可以连接外部 MCP servers 并自动发现其提供的工具。
|
||||
|
||||
**支持的 Transport 类型:**
|
||||
|
||||
| Transport | type 值 | 说明 | 适用场景 |
|
||||
|-----------|---------|------|----------|
|
||||
| **Stdio** | `stdio` | 启动子进程,通过 stdin/stdout 通信 | 本地 MCP servers(如 npm 包) |
|
||||
| **HTTP** | `streamableHttp` 或 `http` | 通过 HTTP/SSE 连接远程服务器 | 远程 MCP servers、云服务 |
|
||||
| Transport | 说明 | 适用场景 |
|
||||
|-----------|------|----------|
|
||||
| **Stdio** | 启动子进程,通过 stdin/stdout 通信 | 本地 MCP servers(如 npm 包) |
|
||||
| **HTTP** | 通过 HTTP/SSE 连接远程服务器 | 远程 MCP servers、云服务 |
|
||||
|
||||
**配置示例:**
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"mcp": {
|
||||
"enabled": true,
|
||||
"servers": [
|
||||
{
|
||||
"name": "filesystem",
|
||||
"enabled": true,
|
||||
"description": "本地文件系统操作",
|
||||
"transport": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/home/user"],
|
||||
"isActive": true
|
||||
"env": {}
|
||||
}
|
||||
},
|
||||
"WebSearch": {
|
||||
"type": "streamableHttp",
|
||||
"baseUrl": "https://dashscope.aliyuncs.com/api/v1/mcps/WebSearch/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer ${DASHSCOPE_API_KEY}"
|
||||
},
|
||||
"isActive": true,
|
||||
"name": "AliyunBailianMCP_WebSearch"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**配置字段说明:**
|
||||
|
||||
| 字段 | 说明 | 必填 |
|
||||
|------|------|------|
|
||||
| `type` | Transport 类型:`stdio` 或 `streamableHttp`/`http` | 是 |
|
||||
| `isActive` | 是否启用此 server(默认 `true`) | 否 |
|
||||
| `name` | Server 显示名称(默认使用 map key) | 否 |
|
||||
| `command` | Stdio: 要执行的命令 | Stdio 必填 |
|
||||
| `args` | Stdio: 命令参数 | 否 |
|
||||
| `env` | Stdio: 环境变量 | 否 |
|
||||
| `baseUrl` | HTTP: MCP server URL | HTTP 必填 |
|
||||
| `headers` | HTTP: 自定义请求头(支持 `${ENV_VAR}` 占位符) | 否 |
|
||||
|
||||
**环境变量占位符:**
|
||||
|
||||
配置中支持两种占位符语法:
|
||||
- `${ENV_VAR}` - Claude Desktop 风格,推荐用于 MCP headers
|
||||
- `<ENV_VAR>` - PicoBot 原有风格,用于其他配置项
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "remote-tools",
|
||||
"enabled": true,
|
||||
"description": "远程 MCP server",
|
||||
"transport": {
|
||||
"type": "http",
|
||||
"url": "http://api.example.com/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer ${API_KEY}"
|
||||
"Authorization": "Bearer your-token"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**工具命名规则:**
|
||||
|
||||
MCP 工具会自动注册到 ToolRegistry,命名格式为 `mcp_{server_key}_{tool_name}`:
|
||||
MCP 工具会自动注册到 ToolRegistry,命名格式为 `mcp_{server_name}_{tool_name}`:
|
||||
|
||||
- `mcp_filesystem_read_file` - server key 为 "filesystem"
|
||||
- `mcp_filesystem_read_file`
|
||||
- `mcp_filesystem_write_file`
|
||||
- `mcp_WebSearch_search` - server key 为 "WebSearch"
|
||||
- `mcp_filesystem_list_directory`
|
||||
- `mcp_remote-tools_custom_query`
|
||||
|
||||
**架构特点:**
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ impl InitWizard {
|
||||
skills: crate::config::SkillsConfig::default(),
|
||||
tools: crate::config::ToolsConfig::default(),
|
||||
memory_maintenance: crate::config::MemoryMaintenanceConfig::default(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp: crate::mcp::McpConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -827,7 +827,7 @@ impl InitWizard {
|
||||
skills: existing.skills.clone(),
|
||||
tools: existing.tools.clone(),
|
||||
memory_maintenance: existing.memory_maintenance.clone(),
|
||||
mcp_servers: existing.mcp_servers.clone(),
|
||||
mcp: existing.mcp.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -114,21 +114,10 @@ async fn handle_get_current_session(
|
||||
let last_active = format_time_ago(topic.last_active_at);
|
||||
let created_at = format_time_ago(topic.created_at);
|
||||
|
||||
let description_line = if let Some(ref desc) = topic.description {
|
||||
if !desc.is_empty() {
|
||||
format!("\n Description: {}", desc)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let message = format!(
|
||||
"Current Topic:\n\n Topic ID: {}\n Title: {}{}\n Messages: {}\n Tokens: ~{} (系统提示词: ~{}, 用户消息: ~{})\n Created: {}\n Last Active: {}",
|
||||
"Current Topic:\n\n Topic ID: {}\n Title: {}\n Messages: {}\n Tokens: ~{} (系统提示词: ~{}, 用户消息: ~{})\n Created: {}\n Last Active: {}",
|
||||
topic.id,
|
||||
topic.title,
|
||||
description_line,
|
||||
actual_message_count,
|
||||
total_tokens,
|
||||
system_prompt_tokens,
|
||||
|
||||
@ -80,13 +80,6 @@ async fn handle_list_sessions(
|
||||
"{}. {}{} ({})",
|
||||
num, topic.title, marker, msg_count
|
||||
));
|
||||
|
||||
// 显示描述(如果有)
|
||||
if let Some(ref desc) = topic.description {
|
||||
if !desc.is_empty() {
|
||||
lines.push(format!(" {}", desc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lines.push(String::new());
|
||||
|
||||
@ -29,8 +29,8 @@ pub struct Config {
|
||||
pub tools: ToolsConfig,
|
||||
#[serde(default)]
|
||||
pub memory_maintenance: MemoryMaintenanceConfig,
|
||||
#[serde(default, rename = "mcpServers")]
|
||||
pub mcp_servers: HashMap<String, crate::mcp::McpServerConfig>,
|
||||
#[serde(default)]
|
||||
pub mcp: crate::mcp::McpConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -813,15 +813,6 @@ impl Config {
|
||||
let content = resolve_env_placeholders(&content);
|
||||
let config: Config = serde_json::from_str(&content)?;
|
||||
config.time.parse_timezone()?;
|
||||
|
||||
// Log MCP servers count if any
|
||||
if !config.mcp_servers.is_empty() {
|
||||
tracing::info!(
|
||||
mcp_servers = config.mcp_servers.len(),
|
||||
"MCP servers loaded from config"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@ -916,16 +907,8 @@ fn load_env_file() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
|
||||
fn resolve_env_placeholders(content: &str) -> String {
|
||||
// Support both ${ENV_VAR} (Claude Desktop style) and <ENV_VAR> (legacy style)
|
||||
let re_braces = Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("invalid regex");
|
||||
let re_angle = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
|
||||
|
||||
let content = re_braces.replace_all(content, |caps: ®ex::Captures| {
|
||||
let var_name = &caps[1];
|
||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
||||
});
|
||||
|
||||
re_angle.replace_all(&content, |caps: ®ex::Captures| {
|
||||
let re = Regex::new(r"<([A-Z_]+)>").expect("invalid regex");
|
||||
re.replace_all(content, |caps: ®ex::Captures| {
|
||||
let var_name = &caps[1];
|
||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
||||
})
|
||||
@ -1976,145 +1959,4 @@ mod tests {
|
||||
.is_err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_env_placeholders_brace_syntax() {
|
||||
// Test ${ENV_VAR} syntax (Claude Desktop style)
|
||||
unsafe { env::set_var("TEST_API_KEY", "my-secret-key") };
|
||||
|
||||
let content = r#"{"api_key": "${TEST_API_KEY}", "other": "${MISSING_VAR}"}"#;
|
||||
let resolved = resolve_env_placeholders(content);
|
||||
|
||||
assert!(resolved.contains("my-secret-key"));
|
||||
assert!(resolved.contains("${MISSING_VAR}")); // Unresolved stays as-is
|
||||
|
||||
// Clean up
|
||||
unsafe { env::remove_var("TEST_API_KEY") };
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_env_placeholders_angle_syntax() {
|
||||
// Test <ENV_VAR> syntax (legacy style)
|
||||
unsafe { env::set_var("LEGACY_KEY", "legacy-value") };
|
||||
|
||||
let content = r#"{"api_key": "<LEGACY_KEY>", "other": "<MISSING>"}"#;
|
||||
let resolved = resolve_env_placeholders(content);
|
||||
|
||||
assert!(resolved.contains("legacy-value"));
|
||||
assert!(resolved.contains("<MISSING>")); // Unresolved stays as-is
|
||||
|
||||
// Clean up
|
||||
unsafe { env::remove_var("LEGACY_KEY") };
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_env_placeholders_mixed_syntax() {
|
||||
// Test both syntaxes in the same content
|
||||
unsafe { env::set_var("BRACE_VAR", "brace-value") };
|
||||
unsafe { env::set_var("ANGLE_VAR", "angle-value") };
|
||||
|
||||
let content = r#"{"brace": "${BRACE_VAR}", "angle": "<ANGLE_VAR>"}"#;
|
||||
let resolved = resolve_env_placeholders(content);
|
||||
|
||||
assert!(resolved.contains("brace-value"));
|
||||
assert!(resolved.contains("angle-value"));
|
||||
|
||||
// Clean up
|
||||
unsafe { env::remove_var("BRACE_VAR") };
|
||||
unsafe { env::remove_var("ANGLE_VAR") };
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_root_level_mcp_servers_merging() {
|
||||
// Test that mcpServers at root level is loaded correctly
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
"providers": {
|
||||
"aliyun": {
|
||||
"type": "openai",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"api_key": "test-key",
|
||||
"extra_headers": {}
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"qwen-plus": {
|
||||
"model_id": "qwen-plus"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "aliyun",
|
||||
"model": "qwen-plus"
|
||||
}
|
||||
},
|
||||
"mcpServers": {
|
||||
"WebSearch": {
|
||||
"type": "streamableHttp",
|
||||
"baseUrl": "https://api.example.com/mcp",
|
||||
"isActive": true
|
||||
},
|
||||
"filesystem": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"isActive": true
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
|
||||
// Should have 2 servers
|
||||
assert_eq!(config.mcp_servers.len(), 2);
|
||||
assert!(config.mcp_servers.contains_key("WebSearch"));
|
||||
assert!(config.mcp_servers.contains_key("filesystem"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_root_level_mcp_servers_only() {
|
||||
// Test that mcpServers at root level works
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
"providers": {
|
||||
"aliyun": {
|
||||
"type": "openai",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"api_key": "test-key",
|
||||
"extra_headers": {}
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"qwen-plus": {
|
||||
"model_id": "qwen-plus"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "aliyun",
|
||||
"model": "qwen-plus"
|
||||
}
|
||||
},
|
||||
"mcpServers": {
|
||||
"WebSearch": {
|
||||
"type": "streamableHttp",
|
||||
"baseUrl": "https://api.example.com/mcp",
|
||||
"isActive": true
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
|
||||
// Should have 1 server from root level
|
||||
assert_eq!(config.mcp_servers.len(), 1);
|
||||
assert!(config.mcp_servers.contains_key("WebSearch"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,10 +72,6 @@ impl GatewayState {
|
||||
let channel_manager = ChannelManager::new();
|
||||
let bus = channel_manager.bus();
|
||||
|
||||
let mcp_config = crate::mcp::McpConfig {
|
||||
mcp_servers: config.mcp_servers.clone(),
|
||||
};
|
||||
|
||||
let (session_manager, task_repository) = build_session_manager_with_sender(
|
||||
agent_prompt_reinject_every,
|
||||
show_tool_results,
|
||||
@ -88,7 +84,7 @@ impl GatewayState {
|
||||
config.tools.task.clone(),
|
||||
config.memory_maintenance.clone(),
|
||||
session_ttl_hours,
|
||||
mcp_config,
|
||||
config.mcp.clone(),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
|
||||
@ -17,10 +17,8 @@ use crate::command::handlers::session::SessionCommandHandler;
|
||||
use crate::command::handlers::switch_session::SwitchSessionCommandHandler;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||
use crate::providers::{create_provider, ProviderRuntimeConfig};
|
||||
use crate::skills::SkillPromptProvider;
|
||||
use crate::storage::persistent_session_id;
|
||||
use crate::topic_description::generate_topic_description;
|
||||
|
||||
use super::session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
@ -29,7 +27,7 @@ pub struct InboundProcessor {
|
||||
bus: Arc<MessageBus>,
|
||||
session_manager: SessionManager,
|
||||
semaphore: Arc<Semaphore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
_provider_config: LLMProviderConfig,
|
||||
command_router: Arc<CommandRouter>,
|
||||
}
|
||||
|
||||
@ -101,7 +99,7 @@ impl InboundProcessor {
|
||||
bus,
|
||||
session_manager,
|
||||
semaphore,
|
||||
provider_config,
|
||||
_provider_config: provider_config,
|
||||
command_router: Arc::new(command_router),
|
||||
}
|
||||
}
|
||||
@ -245,37 +243,6 @@ impl InboundProcessor {
|
||||
tracing::error!(error = %error, "Failed to publish outbound");
|
||||
}
|
||||
}
|
||||
|
||||
// 异步生成 topic 描述(仅第一条消息后触发一次)
|
||||
if let Some(ref topic_id) = current_topic {
|
||||
let store = self.session_manager.store();
|
||||
if let Ok(Some(topic)) = store.get_topic(topic_id) {
|
||||
if topic.description.is_none() || topic.description.as_ref().map(|d| d.is_empty()).unwrap_or(true) {
|
||||
let provider_config = self.provider_config.clone();
|
||||
let topic_id_clone = topic_id.clone();
|
||||
let first_message = inbound.content.clone();
|
||||
let store_clone = store.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let runtime_config: ProviderRuntimeConfig = provider_config.into();
|
||||
if let Ok(provider) = create_provider(runtime_config) {
|
||||
match generate_topic_description(provider.as_ref(), &first_message).await {
|
||||
Ok(description) => {
|
||||
if let Err(e) = store_clone.update_topic_description(&topic_id_clone, &description) {
|
||||
tracing::error!(error = %e, topic_id = %topic_id_clone, "Failed to update topic description");
|
||||
} else {
|
||||
tracing::info!(topic_id = %topic_id_clone, description = %description, "Topic description generated");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, topic_id = %topic_id_clone, "Failed to generate topic description");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!(error = %error, "Failed to handle message");
|
||||
|
||||
@ -112,7 +112,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
|
||||
// Create MCP Initializer (async, non-blocking)
|
||||
// MCP servers connect in background task
|
||||
let mut mcp_initializer = McpInitializer::with_config(mcp_config);
|
||||
let mcp_initializer = McpInitializer::with_config(mcp_config);
|
||||
|
||||
// Add MCP manager to factory (if enabled)
|
||||
let factory = if let Some(manager) = mcp_initializer.manager() {
|
||||
|
||||
@ -966,7 +966,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1019,7 +1018,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1086,7 +1084,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1171,7 +1168,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1258,7 +1254,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1344,7 +1339,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1412,7 +1406,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1489,7 +1482,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1553,7 +1545,6 @@ mod tests {
|
||||
crate::config::TaskConfig::default(),
|
||||
crate::config::MemoryMaintenanceConfig::default(),
|
||||
Some(24),
|
||||
crate::mcp::McpConfig::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@ -18,5 +18,4 @@ pub mod scheduler;
|
||||
pub mod skills;
|
||||
pub mod storage;
|
||||
pub mod text;
|
||||
pub mod topic_description;
|
||||
pub mod tools;
|
||||
|
||||
@ -21,17 +21,6 @@ use http::{HeaderName, HeaderValue};
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::mcp::config::{McpServerConfig, McpTransportConfig};
|
||||
use std::env;
|
||||
|
||||
/// Resolve ${ENV_VAR} placeholders in a value string
|
||||
fn resolve_env_placeholders_in_value(value: &str) -> String {
|
||||
let re = regex::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("invalid regex");
|
||||
re.replace_all(value, |caps: ®ex::Captures| {
|
||||
let var_name = &caps[1];
|
||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Type alias for the MCP client service
|
||||
pub type McpClient = RunningService<RoleClient, ()>;
|
||||
@ -39,10 +28,8 @@ pub type McpClient = RunningService<RoleClient, ()>;
|
||||
/// Information about a connected MCP server
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct McpServerInfo {
|
||||
/// Server name (effective name from config)
|
||||
/// Server name
|
||||
pub name: String,
|
||||
/// Server key (the key in mcpServers map)
|
||||
pub key: String,
|
||||
/// Server information from MCP protocol
|
||||
pub info: Option<ServerInfo>,
|
||||
/// Available tools
|
||||
@ -57,9 +44,9 @@ pub struct McpServerInfo {
|
||||
/// - Calling tools on connected servers
|
||||
/// - Connection lifecycle management
|
||||
pub struct McpClientManager {
|
||||
/// Connected clients keyed by server key
|
||||
/// Connected clients keyed by server name
|
||||
clients: RwLock<HashMap<String, Arc<McpClient>>>,
|
||||
/// Server information cache keyed by server key
|
||||
/// Server information cache
|
||||
server_info: RwLock<HashMap<String, McpServerInfo>>,
|
||||
}
|
||||
|
||||
@ -76,19 +63,17 @@ impl McpClientManager {
|
||||
///
|
||||
/// This method is designed to be called asynchronously without
|
||||
/// blocking the main gateway startup flow.
|
||||
/// Takes a list of (key, config) pairs from the mcpServers map.
|
||||
pub async fn connect_all(&self, servers: Vec<(String, McpServerConfig)>) -> anyhow::Result<()> {
|
||||
for (key, config) in servers {
|
||||
if !config.is_active {
|
||||
tracing::info!(key = %key, "Skipping inactive MCP server");
|
||||
pub async fn connect_all(&self, servers: &[McpServerConfig]) -> anyhow::Result<()> {
|
||||
for server in servers {
|
||||
if !server.enabled {
|
||||
tracing::info!(name = %server.name, "Skipping disabled MCP server");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Each server connection is independent
|
||||
match self.connect_server(&key, &config).await {
|
||||
match self.connect_server(server).await {
|
||||
Ok(info) => {
|
||||
tracing::info!(
|
||||
key = %key,
|
||||
name = %info.name,
|
||||
tools_count = info.tools.len(),
|
||||
"Connected to MCP server"
|
||||
@ -97,7 +82,7 @@ impl McpClientManager {
|
||||
Err(e) => {
|
||||
// Log error but continue with other servers
|
||||
tracing::error!(
|
||||
key = %key,
|
||||
name = %server.name,
|
||||
error = %e,
|
||||
"Failed to connect to MCP server"
|
||||
);
|
||||
@ -108,17 +93,15 @@ impl McpClientManager {
|
||||
}
|
||||
|
||||
/// Connect to a single MCP server
|
||||
pub async fn connect_server(&self, key: &str, config: &McpServerConfig) -> anyhow::Result<McpServerInfo> {
|
||||
let effective_name = config.effective_name(key);
|
||||
tracing::info!(key = %key, name = %effective_name, transport_type = %config.transport_type, "Connecting to MCP server");
|
||||
pub async fn connect_server(&self, config: &McpServerConfig) -> anyhow::Result<McpServerInfo> {
|
||||
tracing::info!(name = %config.name, transport = ?config.transport, "Connecting to MCP server");
|
||||
|
||||
let transport = config.transport().map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
let client = match transport {
|
||||
let client = match &config.transport {
|
||||
McpTransportConfig::Stdio { command, args, env } => {
|
||||
self.connect_stdio(&command, &args, &env).await?
|
||||
self.connect_stdio(command, args, env).await?
|
||||
}
|
||||
McpTransportConfig::Http { url, headers } => {
|
||||
self.connect_http(&url, &headers).await?
|
||||
self.connect_http(url, headers).await?
|
||||
}
|
||||
};
|
||||
|
||||
@ -129,8 +112,7 @@ impl McpClientManager {
|
||||
let tools = client.list_all_tools().await?;
|
||||
|
||||
let server_info = McpServerInfo {
|
||||
key: key.to_string(),
|
||||
name: effective_name,
|
||||
name: config.name.clone(),
|
||||
info,
|
||||
tools,
|
||||
};
|
||||
@ -138,11 +120,11 @@ impl McpClientManager {
|
||||
// Store the client and info
|
||||
{
|
||||
let mut clients = self.clients.write().await;
|
||||
clients.insert(key.to_string(), Arc::new(client));
|
||||
clients.insert(config.name.clone(), Arc::new(client));
|
||||
}
|
||||
{
|
||||
let mut info_map = self.server_info.write().await;
|
||||
info_map.insert(key.to_string(), server_info.clone());
|
||||
info_map.insert(config.name.clone(), server_info.clone());
|
||||
}
|
||||
|
||||
Ok(server_info)
|
||||
@ -177,22 +159,8 @@ impl McpClientManager {
|
||||
url: &str,
|
||||
headers: &HashMap<String, String>,
|
||||
) -> anyhow::Result<McpClient> {
|
||||
// Resolve env placeholders in headers
|
||||
let resolved_headers: HashMap<String, String> = headers
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
// Resolve ${ENV_VAR} placeholders
|
||||
let resolved = if value.contains("${") {
|
||||
resolve_env_placeholders_in_value(value)
|
||||
} else {
|
||||
value.clone()
|
||||
};
|
||||
(key.clone(), resolved)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build custom headers
|
||||
let custom_headers: HashMap<HeaderName, HeaderValue> = resolved_headers
|
||||
let custom_headers: HashMap<HeaderName, HeaderValue> = headers
|
||||
.iter()
|
||||
.filter_map(|(key, value)| {
|
||||
// Try to parse header name and value
|
||||
@ -222,50 +190,49 @@ impl McpClientManager {
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Get a client by server key
|
||||
pub async fn get_client(&self, key: &str) -> Option<Arc<McpClient>> {
|
||||
/// Get a client by server name
|
||||
pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
|
||||
let clients = self.clients.read().await;
|
||||
clients.get(key).cloned()
|
||||
clients.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get server info by key
|
||||
pub async fn get_server_info(&self, key: &str) -> Option<McpServerInfo> {
|
||||
/// Get server info by name
|
||||
pub async fn get_server_info(&self, name: &str) -> Option<McpServerInfo> {
|
||||
let info_map = self.server_info.read().await;
|
||||
info_map.get(key).cloned()
|
||||
info_map.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get all connected server keys
|
||||
/// Get all connected server names
|
||||
pub async fn connected_servers(&self) -> Vec<String> {
|
||||
let clients = self.clients.read().await;
|
||||
clients.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get all tools from all connected servers
|
||||
/// Returns (server_key, tool) pairs for tool registration
|
||||
pub async fn all_tools(&self) -> Vec<(String, Tool)> {
|
||||
let info_map = self.server_info.read().await;
|
||||
info_map
|
||||
.values()
|
||||
.flat_map(|info| {
|
||||
info.tools.iter().map(|tool| (info.key.clone(), tool.clone()))
|
||||
info.tools.iter().map(|tool| (info.name.clone(), tool.clone()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Call a tool on a specific server by key
|
||||
/// Call a tool on a specific server
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
server_key: impl Into<String>,
|
||||
server_name: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
args: serde_json::Value,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
let server_key = server_key.into();
|
||||
let server_name = server_name.into();
|
||||
let tool_name = tool_name.into();
|
||||
|
||||
let client = self
|
||||
.get_client(&server_key)
|
||||
.get_client(&server_name)
|
||||
.await
|
||||
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' not connected", server_key))?;
|
||||
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' not connected", server_name))?;
|
||||
|
||||
// Convert Value to JsonObject if it's an object
|
||||
let arguments = if args.is_object() {
|
||||
@ -283,22 +250,22 @@ impl McpClientManager {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Disconnect from a server by key
|
||||
pub async fn disconnect(&self, key: impl Into<String>) -> anyhow::Result<()> {
|
||||
let key = key.into();
|
||||
/// Disconnect from a server
|
||||
pub async fn disconnect(&self, name: impl Into<String>) -> anyhow::Result<()> {
|
||||
let name = name.into();
|
||||
let mut clients = self.clients.write().await;
|
||||
if clients.remove(&key).is_some() {
|
||||
tracing::info!(key = %key, "Disconnected MCP server");
|
||||
if clients.remove(&name).is_some() {
|
||||
tracing::info!(name = %name, "Disconnected MCP server");
|
||||
}
|
||||
self.server_info.write().await.remove(&key);
|
||||
self.server_info.write().await.remove(&name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect from all servers
|
||||
pub async fn disconnect_all(&self) -> anyhow::Result<()> {
|
||||
let mut clients = self.clients.write().await;
|
||||
for (key, _client) in clients.drain() {
|
||||
tracing::info!(key = %key, "Disconnected MCP server");
|
||||
for (name, _client) in clients.drain() {
|
||||
tracing::info!(name = %name, "Disconnected MCP server");
|
||||
}
|
||||
self.server_info.write().await.clear();
|
||||
Ok(())
|
||||
@ -343,18 +310,18 @@ impl McpInitializer {
|
||||
/// This spawns a background task to connect to MCP servers,
|
||||
/// allowing the gateway to start immediately.
|
||||
pub fn with_config(config: crate::mcp::McpConfig) -> Self {
|
||||
if !config.has_active_servers() {
|
||||
if !config.has_enabled_servers() {
|
||||
return Self::disabled();
|
||||
}
|
||||
|
||||
let manager = Arc::new(McpClientManager::new());
|
||||
let servers = config.active_servers();
|
||||
let servers: Vec<_> = config.enabled_servers().into_iter().cloned().collect();
|
||||
|
||||
// Spawn background connection task
|
||||
let manager_clone = manager.clone();
|
||||
let connection_task = tokio::spawn(async move {
|
||||
tracing::info!("Starting MCP connection task...");
|
||||
manager_clone.connect_all(servers).await
|
||||
manager_clone.connect_all(&servers).await
|
||||
});
|
||||
|
||||
Self {
|
||||
@ -388,14 +355,14 @@ impl McpInitializer {
|
||||
/// Register MCP tools to the tool registry
|
||||
///
|
||||
/// This should be called after the gateway is ready to accept tools.
|
||||
/// Waits for connections to complete before registering tools.
|
||||
pub async fn register_tools(&mut self, registry: &mut crate::tools::ToolRegistry) -> anyhow::Result<()> {
|
||||
if let Some(manager) = self.manager.clone() {
|
||||
// Wait for connections to complete first
|
||||
self.wait_for_connections().await?;
|
||||
/// The method handles the case where connections are still in progress.
|
||||
pub async fn register_tools(&self, registry: &mut crate::tools::ToolRegistry) -> anyhow::Result<()> {
|
||||
if let Some(manager) = &self.manager {
|
||||
// Give a small grace period for connections if still in progress
|
||||
// This allows tools to be registered even if connection task is running
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
tracing::info!("Registering MCP tools after connections completed");
|
||||
crate::mcp::register_mcp_tools(manager, registry).await?;
|
||||
crate::mcp::register_mcp_tools(manager.clone(), registry).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1,131 +1,79 @@
|
||||
//! MCP Server configuration structures
|
||||
//!
|
||||
//! This module provides configuration compatible with Claude Desktop/Cursor format:
|
||||
//! - Uses `mcpServers` object (HashMap) instead of array
|
||||
//! - Uses `isActive` instead of `enabled`
|
||||
//! - Uses `baseUrl` instead of `url` for HTTP
|
||||
//! - Uses `streamableHttp` type (compatible with Claude Desktop)
|
||||
//! - Supports `${ENV_VAR}` placeholder syntax for headers
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// MCP integration configuration
|
||||
///
|
||||
/// Compatible with Claude Desktop format using `mcpServers` object.
|
||||
/// Example config:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "mcpServers": {
|
||||
/// "WebSearch": {
|
||||
/// "type": "streamableHttp",
|
||||
/// "baseUrl": "https://api.example.com/mcp",
|
||||
/// "headers": { "Authorization": "Bearer ${API_KEY}" },
|
||||
/// "isActive": true,
|
||||
/// "name": "WebSearch"
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct McpConfig {
|
||||
/// MCP servers as a map (Claude Desktop compatible format)
|
||||
/// The key is used as the server identifier
|
||||
#[serde(default, rename = "mcpServers")]
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
/// Whether MCP integration is enabled
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// List of MCP servers to connect
|
||||
#[serde(default)]
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
/// Configuration for a single MCP server
|
||||
///
|
||||
/// Supports both stdio and HTTP (streamableHttp) transports.
|
||||
/// Configuration is flattened (no nested transport object).
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpServerConfig {
|
||||
/// Server name (optional, defaults to the key in mcpServers)
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
/// Unique name for this server (used in tool naming)
|
||||
pub name: String,
|
||||
|
||||
/// Transport type: "stdio" or "streamableHttp" (or "http")
|
||||
#[serde(rename = "type")]
|
||||
pub transport_type: String,
|
||||
/// Transport configuration
|
||||
pub transport: McpTransportConfig,
|
||||
|
||||
/// Whether this server is active (Claude Desktop compatible)
|
||||
#[serde(default = "default_is_active", alias = "enabled", alias = "isActive")]
|
||||
pub is_active: bool,
|
||||
|
||||
// Stdio transport fields
|
||||
/// Command to execute for stdio transport (e.g., "npx", "cargo")
|
||||
#[serde(default)]
|
||||
pub command: Option<String>,
|
||||
/// Arguments for stdio transport
|
||||
#[serde(default)]
|
||||
pub args: Option<Vec<String>>,
|
||||
/// Environment variables for stdio transport
|
||||
#[serde(default)]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
|
||||
// HTTP transport fields
|
||||
/// Base URL for HTTP transport (Claude Desktop compatible naming)
|
||||
#[serde(default, alias = "url", alias = "baseUrl")]
|
||||
pub base_url: Option<String>,
|
||||
/// Headers for HTTP transport (supports ${ENV_VAR} placeholders)
|
||||
#[serde(default)]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
/// Whether this server is enabled
|
||||
#[serde(default = "default_server_enabled")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Optional description for the server
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
fn default_is_active() -> bool {
|
||||
fn default_server_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Transport configuration for connecting to MCP servers
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum McpTransportConfig {
|
||||
/// Stdio transport: spawn a child process
|
||||
Stdio {
|
||||
/// Command to execute (e.g., "npx", "cargo")
|
||||
command: String,
|
||||
/// Arguments to pass to the command
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
/// Optional environment variables to set
|
||||
#[serde(default)]
|
||||
env: HashMap<String, String>,
|
||||
},
|
||||
|
||||
/// HTTP transport: connect to a remote server
|
||||
Http {
|
||||
/// URL of the MCP server endpoint
|
||||
url: String,
|
||||
/// Optional headers to include in requests
|
||||
#[serde(default)]
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpServerConfig {
|
||||
/// Get the effective server name (uses key if name not specified)
|
||||
pub fn effective_name(&self, key: &str) -> String {
|
||||
self.name.clone().unwrap_or_else(|| key.to_string())
|
||||
}
|
||||
|
||||
/// Parse transport type to internal enum
|
||||
pub fn transport(&self) -> Result<McpTransportConfig, String> {
|
||||
match self.transport_type.as_str() {
|
||||
"stdio" => {
|
||||
let command = self.command.clone().unwrap_or_default();
|
||||
if command.is_empty() {
|
||||
return Err("stdio transport requires 'command' field".to_string());
|
||||
}
|
||||
Ok(McpTransportConfig::Stdio {
|
||||
command,
|
||||
args: self.args.clone().unwrap_or_default(),
|
||||
env: self.env.clone().unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
"http" | "streamableHttp" => {
|
||||
let url = self.base_url.clone().unwrap_or_default();
|
||||
if url.is_empty() {
|
||||
return Err("HTTP transport requires 'baseUrl' field".to_string());
|
||||
}
|
||||
Ok(McpTransportConfig::Http {
|
||||
url,
|
||||
headers: self.headers.clone().unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
other => Err(format!("unknown transport type: {}", other)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a stdio server config
|
||||
pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
|
||||
Self {
|
||||
name: Some(name.into()),
|
||||
transport_type: "stdio".to_string(),
|
||||
is_active: true,
|
||||
command: Some(command.into()),
|
||||
args: Some(args),
|
||||
env: Some(HashMap::new()),
|
||||
base_url: None,
|
||||
headers: None,
|
||||
name: name.into(),
|
||||
transport: McpTransportConfig::Stdio {
|
||||
command: command.into(),
|
||||
args,
|
||||
env: HashMap::new(),
|
||||
},
|
||||
enabled: true,
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
@ -133,57 +81,26 @@ impl McpServerConfig {
|
||||
/// Create an HTTP server config
|
||||
pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: Some(name.into()),
|
||||
transport_type: "streamableHttp".to_string(),
|
||||
is_active: true,
|
||||
command: None,
|
||||
args: None,
|
||||
env: None,
|
||||
base_url: Some(url.into()),
|
||||
headers: Some(HashMap::new()),
|
||||
name: name.into(),
|
||||
transport: McpTransportConfig::Http {
|
||||
url: url.into(),
|
||||
headers: HashMap::new(),
|
||||
},
|
||||
enabled: true,
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Transport configuration for connecting to MCP servers
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum McpTransportConfig {
|
||||
/// Stdio transport: spawn a child process
|
||||
Stdio {
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
env: HashMap<String, String>,
|
||||
},
|
||||
/// HTTP transport: connect to a remote server (Streamable HTTP)
|
||||
Http {
|
||||
url: String,
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpConfig {
|
||||
/// Get active servers as a list
|
||||
pub fn active_servers(&self) -> Vec<(String, McpServerConfig)> {
|
||||
self.mcp_servers
|
||||
.iter()
|
||||
.filter(|(_, config)| config.is_active)
|
||||
.map(|(key, config)| (key.clone(), config.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if there are any active servers
|
||||
pub fn has_active_servers(&self) -> bool {
|
||||
self.mcp_servers.iter().any(|(_, config)| config.is_active)
|
||||
}
|
||||
|
||||
/// Get enabled servers with resolved transport
|
||||
/// Get enabled servers
|
||||
pub fn enabled_servers(&self) -> Vec<&McpServerConfig> {
|
||||
self.mcp_servers
|
||||
.iter()
|
||||
.filter(|(_, config)| config.is_active)
|
||||
.map(|(_, config)| config)
|
||||
.collect()
|
||||
self.servers.iter().filter(|s| s.enabled).collect()
|
||||
}
|
||||
|
||||
/// Check if there are any enabled servers
|
||||
pub fn has_enabled_servers(&self) -> bool {
|
||||
self.enabled && self.servers.iter().any(|s| s.enabled)
|
||||
}
|
||||
}
|
||||
|
||||
@ -196,193 +113,61 @@ mod tests {
|
||||
let config = McpServerConfig::stdio(
|
||||
"filesystem",
|
||||
"npx",
|
||||
vec!["-y".to_string(), "@modelcontextprotocol/server-filesystem".to_string(), "/tmp".to_string()],
|
||||
vec!["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
|
||||
);
|
||||
|
||||
assert_eq!(config.name, Some("filesystem".to_string()));
|
||||
assert!(config.is_active);
|
||||
assert_eq!(config.transport_type, "stdio");
|
||||
|
||||
let transport = config.transport().unwrap();
|
||||
assert!(matches!(transport, McpTransportConfig::Stdio { .. }));
|
||||
assert_eq!(config.name, "filesystem");
|
||||
assert!(config.enabled);
|
||||
assert!(matches!(config.transport, McpTransportConfig::Stdio { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_config_creation() {
|
||||
let config = McpServerConfig::http("custom", "http://localhost:8000/mcp");
|
||||
|
||||
assert_eq!(config.name, Some("custom".to_string()));
|
||||
assert!(config.is_active);
|
||||
assert_eq!(config.transport_type, "streamableHttp");
|
||||
|
||||
let transport = config.transport().unwrap();
|
||||
assert!(matches!(transport, McpTransportConfig::Http { .. }));
|
||||
assert_eq!(config.name, "custom");
|
||||
assert!(config.enabled);
|
||||
assert!(matches!(config.transport, McpTransportConfig::Http { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_desktop_format_deserialization() {
|
||||
// Claude Desktop/Cursor compatible format
|
||||
fn test_config_deserialization() {
|
||||
let json = r#"{
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"enabled": true,
|
||||
"servers": [
|
||||
{
|
||||
"name": "filesystem",
|
||||
"transport": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/home/user"],
|
||||
"isActive": true
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
}
|
||||
},
|
||||
"WebSearch": {
|
||||
"type": "streamableHttp",
|
||||
"baseUrl": "https://dashscope.aliyuncs.com/api/v1/mcps/WebSearch/mcp",
|
||||
{
|
||||
"name": "http-server",
|
||||
"enabled": false,
|
||||
"transport": {
|
||||
"type": "http",
|
||||
"url": "http://localhost:8000/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer ${DASHSCOPE_API_KEY}"
|
||||
},
|
||||
"isActive": true,
|
||||
"name": "AliyunBailianMCP_WebSearch"
|
||||
},
|
||||
"disabled-server": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "some-server"],
|
||||
"isActive": false
|
||||
"Authorization": "Bearer token"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let config: McpConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(config.mcp_servers.len(), 3);
|
||||
assert_eq!(config.active_servers().len(), 2);
|
||||
assert!(config.enabled);
|
||||
assert_eq!(config.servers.len(), 2);
|
||||
assert_eq!(config.enabled_servers().len(), 1);
|
||||
|
||||
// Check filesystem server
|
||||
let fs = config.mcp_servers.get("filesystem").unwrap();
|
||||
assert_eq!(fs.transport_type, "stdio");
|
||||
assert!(fs.is_active);
|
||||
let transport = fs.transport().unwrap();
|
||||
match transport {
|
||||
McpTransportConfig::Stdio { command, args, .. } => {
|
||||
assert_eq!(command, "npx");
|
||||
assert_eq!(args, vec![
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/home/user"
|
||||
]);
|
||||
}
|
||||
_ => panic!("Expected stdio transport"),
|
||||
}
|
||||
let fs_server = &config.servers[0];
|
||||
assert_eq!(fs_server.name, "filesystem");
|
||||
assert!(fs_server.enabled);
|
||||
|
||||
// Check WebSearch server (streamableHttp)
|
||||
let websearch = config.mcp_servers.get("WebSearch").unwrap();
|
||||
assert_eq!(websearch.transport_type, "streamableHttp");
|
||||
assert_eq!(websearch.name, Some("AliyunBailianMCP_WebSearch".to_string()));
|
||||
assert!(websearch.is_active);
|
||||
let transport = websearch.transport().unwrap();
|
||||
match transport {
|
||||
McpTransportConfig::Http { url, headers } => {
|
||||
assert_eq!(url, "https://dashscope.aliyuncs.com/api/v1/mcps/WebSearch/mcp");
|
||||
assert_eq!(
|
||||
headers.get("Authorization"),
|
||||
Some(&"Bearer ${DASHSCOPE_API_KEY}".to_string())
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected HTTP transport"),
|
||||
}
|
||||
|
||||
// Check disabled server
|
||||
let disabled = config.mcp_servers.get("disabled-server").unwrap();
|
||||
assert!(!disabled.is_active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_effective_name_uses_key_when_name_missing() {
|
||||
let config = McpServerConfig {
|
||||
name: None,
|
||||
transport_type: "stdio".to_string(),
|
||||
is_active: true,
|
||||
command: Some("npx".to_string()),
|
||||
args: Some(vec!["-y".to_string(), "server".to_string()]),
|
||||
env: None,
|
||||
base_url: None,
|
||||
headers: None,
|
||||
description: None,
|
||||
};
|
||||
|
||||
assert_eq!(config.effective_name("my-key"), "my-key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_effective_name_uses_name_when_specified() {
|
||||
let config = McpServerConfig {
|
||||
name: Some("MyServer".to_string()),
|
||||
transport_type: "stdio".to_string(),
|
||||
is_active: true,
|
||||
command: Some("npx".to_string()),
|
||||
args: Some(vec!["-y".to_string(), "server".to_string()]),
|
||||
env: None,
|
||||
base_url: None,
|
||||
headers: None,
|
||||
description: None,
|
||||
};
|
||||
|
||||
assert_eq!(config.effective_name("my-key"), "MyServer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transport_validation() {
|
||||
// Missing command for stdio
|
||||
let config = McpServerConfig {
|
||||
name: Some("test".to_string()),
|
||||
transport_type: "stdio".to_string(),
|
||||
is_active: true,
|
||||
command: None,
|
||||
args: None,
|
||||
env: None,
|
||||
base_url: None,
|
||||
headers: None,
|
||||
description: None,
|
||||
};
|
||||
assert!(config.transport().is_err());
|
||||
|
||||
// Missing baseUrl for HTTP
|
||||
let config = McpServerConfig {
|
||||
name: Some("test".to_string()),
|
||||
transport_type: "streamableHttp".to_string(),
|
||||
is_active: true,
|
||||
command: None,
|
||||
args: None,
|
||||
env: None,
|
||||
base_url: None,
|
||||
headers: None,
|
||||
description: None,
|
||||
};
|
||||
assert!(config.transport().is_err());
|
||||
|
||||
// Unknown transport type
|
||||
let config = McpServerConfig {
|
||||
name: Some("test".to_string()),
|
||||
transport_type: "unknown".to_string(),
|
||||
is_active: true,
|
||||
command: Some("cmd".to_string()),
|
||||
args: None,
|
||||
env: None,
|
||||
base_url: None,
|
||||
headers: None,
|
||||
description: None,
|
||||
};
|
||||
assert!(config.transport().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_type_alias() {
|
||||
// Both "http" and "streamableHttp" should work
|
||||
let json_http = r#"{"mcpServers": {"test": {"type": "http", "baseUrl": "http://localhost"}}}"#;
|
||||
let json_streamable = r#"{"mcpServers": {"test": {"type": "streamableHttp", "baseUrl": "http://localhost"}}}"#;
|
||||
|
||||
let config_http: McpConfig = serde_json::from_str(json_http).unwrap();
|
||||
let config_streamable: McpConfig = serde_json::from_str(json_streamable).unwrap();
|
||||
|
||||
let transport_http = config_http.mcp_servers.get("test").unwrap().transport().unwrap();
|
||||
let transport_streamable = config_streamable.mcp_servers.get("test").unwrap().transport().unwrap();
|
||||
|
||||
assert!(matches!(transport_http, McpTransportConfig::Http { .. }));
|
||||
assert!(matches!(transport_streamable, McpTransportConfig::Http { .. }));
|
||||
let http_server = &config.servers[1];
|
||||
assert_eq!(http_server.name, "http-server");
|
||||
assert!(!http_server.enabled);
|
||||
}
|
||||
}
|
||||
@ -12,11 +12,11 @@ use crate::tools::traits::{Tool as PicoBotTool, ToolResult};
|
||||
pub struct McpToolWrapper {
|
||||
/// The MCP client manager
|
||||
manager: Arc<McpClientManager>,
|
||||
/// The server key this tool belongs to (from mcpServers map key)
|
||||
server_key: String,
|
||||
/// The server name this tool belongs to
|
||||
server_name: String,
|
||||
/// The original tool name on the MCP server
|
||||
tool_name: String,
|
||||
/// The full tool name with namespace (mcp_{key}_{tool})
|
||||
/// The full tool name with namespace (mcp_{server}_{tool})
|
||||
full_name: String,
|
||||
/// Tool information from MCP server
|
||||
tool_info: Tool,
|
||||
@ -26,23 +26,23 @@ impl McpToolWrapper {
|
||||
/// Create a new tool wrapper
|
||||
pub fn new(
|
||||
manager: Arc<McpClientManager>,
|
||||
server_key: String,
|
||||
server_name: String,
|
||||
tool_info: Tool,
|
||||
) -> Self {
|
||||
let tool_name = tool_info.name.clone().into_owned();
|
||||
let full_name = format!("mcp_{}_{}", server_key, tool_name);
|
||||
let full_name = format!("mcp_{}_{}", server_name, tool_name);
|
||||
Self {
|
||||
manager,
|
||||
server_key,
|
||||
server_name,
|
||||
tool_name,
|
||||
full_name,
|
||||
tool_info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the server key
|
||||
pub fn server_key(&self) -> &str {
|
||||
&self.server_key
|
||||
/// Get the server name
|
||||
pub fn server_name(&self) -> &str {
|
||||
&self.server_name
|
||||
}
|
||||
|
||||
/// Get the original tool name
|
||||
@ -69,14 +69,14 @@ impl PicoBotTool for McpToolWrapper {
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
tracing::debug!(
|
||||
server_key = %self.server_key,
|
||||
server = %self.server_name,
|
||||
tool = %self.tool_name,
|
||||
"Calling MCP tool"
|
||||
);
|
||||
|
||||
let result = self
|
||||
.manager
|
||||
.call_tool(&self.server_key, &self.tool_name, args)
|
||||
.call_tool(&self.server_name, &self.tool_name, args)
|
||||
.await?;
|
||||
|
||||
// Convert MCP CallToolResult to PicoBot ToolResult
|
||||
@ -126,16 +126,16 @@ pub async fn register_mcp_tools(
|
||||
) -> anyhow::Result<()> {
|
||||
let all_tools = manager.all_tools().await;
|
||||
|
||||
for (server_key, tool_info) in all_tools {
|
||||
for (server_name, tool_info) in all_tools {
|
||||
let wrapper = McpToolWrapper::new(
|
||||
manager.clone(),
|
||||
server_key.clone(),
|
||||
server_name.clone(),
|
||||
tool_info,
|
||||
);
|
||||
|
||||
tracing::info!(
|
||||
name = %wrapper.name(),
|
||||
server_key = %server_key,
|
||||
server = %server_name,
|
||||
"Registering MCP tool"
|
||||
);
|
||||
|
||||
@ -165,24 +165,22 @@ mod tests {
|
||||
fn test_extract_text_content_empty() {
|
||||
let result = CallToolResult::success(vec![]);
|
||||
let text = extract_text_content(&result);
|
||||
// When content is empty, the function serializes the result to JSON
|
||||
// which contains an empty content array
|
||||
assert!(text.contains("content") || text.contains("Empty result"));
|
||||
assert!(text.contains("Empty result"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tool_wrapper_name() {
|
||||
let manager = Arc::new(McpClientManager::new());
|
||||
// Create a minimal tool info using rmcp's Tool constructor
|
||||
let schema: serde_json::Map<String, serde_json::Value> = serde_json::json!({"type": "object"})
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
let tool_info = Tool::new("echo", "Echo tool", schema);
|
||||
let tool_info = Tool {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool".into()),
|
||||
input_schema: serde_json::json!({"type": "object"}).as_object().unwrap().clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let wrapper = McpToolWrapper::new(manager, "filesystem".to_string(), tool_info);
|
||||
assert_eq!(wrapper.name(), "mcp_filesystem_echo");
|
||||
assert_eq!(wrapper.original_name(), "echo");
|
||||
assert_eq!(wrapper.server_key(), "filesystem");
|
||||
assert_eq!(wrapper.server_name(), "filesystem");
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,5 @@
|
||||
use crate::domain::messages::{ContentBlock, ToolCall};
|
||||
use crate::domain::tools::Tool;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@ -19,23 +18,6 @@ pub struct ProviderRuntimeConfig {
|
||||
pub model_extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl From<LLMProviderConfig> for ProviderRuntimeConfig {
|
||||
fn from(config: LLMProviderConfig) -> Self {
|
||||
Self {
|
||||
provider_type: config.provider_type,
|
||||
name: config.name,
|
||||
base_url: config.base_url,
|
||||
api_key: config.api_key,
|
||||
extra_headers: config.extra_headers,
|
||||
llm_timeout_secs: config.llm_timeout_secs,
|
||||
model_id: config.model_id,
|
||||
temperature: config.temperature,
|
||||
max_tokens: config.max_tokens,
|
||||
model_extra: config.model_extra,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
@ -462,16 +462,6 @@ impl SessionStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_topic_description(&self, topic_id: &str, description: &str) -> Result<(), StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
conn.execute(
|
||||
"UPDATE topics SET description = ?2, updated_at = ?3 WHERE id = ?1",
|
||||
params![topic_id, description, now],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn delete_topic(&self, topic_id: &str) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
// Messages 的 topic_id 会被设为 NULL(ON DELETE SET NULL)
|
||||
|
||||
@ -1,27 +0,0 @@
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message};
|
||||
|
||||
pub async fn generate_topic_description(
|
||||
provider: &dyn LLMProvider,
|
||||
first_user_message: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let prompt = format!(
|
||||
"请根据用户的第一句话,用简短的词语(不超过15字)描述这个对话的主题或意图。只输出描述内容,不要其他解释。\n\n用户消息:{}",
|
||||
first_user_message
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::user(prompt)],
|
||||
temperature: Some(0.3),
|
||||
max_tokens: Some(50),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = provider.chat(request).await?;
|
||||
let description = response.content.trim();
|
||||
|
||||
if description.len() > 50 {
|
||||
Ok(description.chars().take(50).collect())
|
||||
} else {
|
||||
Ok(description.to_string())
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user