141 lines
3.7 KiB
Rust
141 lines
3.7 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use crate::providers::{Tool, ToolFunction};
|
|
|
|
use super::traits::Tool as ToolTrait;
|
|
|
|
pub struct ToolRegistry {
|
|
tools: Mutex<HashMap<String, Arc<dyn ToolTrait>>>,
|
|
}
|
|
|
|
impl ToolRegistry {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
tools: Mutex::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
|
self.tools
|
|
.lock()
|
|
.unwrap()
|
|
.insert(tool.name().to_string(), Arc::new(tool));
|
|
}
|
|
|
|
/// Register an existing Arc-wrapped tool by name
|
|
pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
|
|
self.tools.lock().unwrap().insert(name, tool);
|
|
}
|
|
|
|
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
|
|
self.tools.lock().unwrap().get(name).cloned()
|
|
}
|
|
|
|
/// Get all registered tools.
|
|
/// Used for concurrent tool execution when we need to look up tools by name.
|
|
pub fn get_all(&self) -> Vec<Arc<dyn ToolTrait>> {
|
|
self.tools.lock().unwrap().values().cloned().collect()
|
|
}
|
|
|
|
pub fn get_definitions(&self) -> Vec<Tool> {
|
|
let mut defs: Vec<Tool> = self
|
|
.tools
|
|
.lock()
|
|
.unwrap()
|
|
.values()
|
|
.map(|tool| Tool {
|
|
tool_type: "function".to_string(),
|
|
function: ToolFunction {
|
|
name: tool.name().to_string(),
|
|
description: tool.description().to_string(),
|
|
parameters: tool.parameters_schema(),
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
defs.sort_by(|a, b| a.function.name.cmp(&b.function.name));
|
|
defs
|
|
}
|
|
|
|
pub fn has_tools(&self) -> bool {
|
|
!self.tools.lock().unwrap().is_empty()
|
|
}
|
|
|
|
pub fn tool_names(&self) -> Vec<String> {
|
|
self.tools.lock().unwrap().keys().cloned().collect()
|
|
}
|
|
|
|
pub fn iter(&self) -> Vec<(String, Arc<dyn ToolTrait>)> {
|
|
self.tools
|
|
.lock()
|
|
.unwrap()
|
|
.iter()
|
|
.map(|(k, v)| (k.clone(), v.clone()))
|
|
.collect()
|
|
}
|
|
|
|
/// 生成工具列表描述,用于子 Agent 系统提示词
|
|
pub fn describe_for_prompt(&self) -> String {
|
|
let mut entries: Vec<String> = self
|
|
.iter()
|
|
.into_iter()
|
|
.map(|(name, tool)| format!("- {}: {}", name, tool.description()))
|
|
.collect();
|
|
entries.sort();
|
|
entries.join("\n")
|
|
}
|
|
}
|
|
|
|
impl Default for ToolRegistry {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::traits::ToolResult;
    use async_trait::async_trait;
    use serde_json::json;

    /// Minimal tool whose name and description are both the wrapped string.
    struct TestTool(&'static str);

    #[async_trait]
    impl ToolTrait for TestTool {
        fn name(&self) -> &str {
            self.0
        }

        fn description(&self) -> &str {
            self.0
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({})
        }

        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            Ok(ToolResult {
                success: true,
                output: "ok".to_string(),
                error: None,
            })
        }
    }

    #[test]
    fn test_get_definitions_sorted_by_name() {
        let registry = ToolRegistry::new();
        registry.register(TestTool("zeta"));
        registry.register(TestTool("alpha"));
        registry.register(TestTool("beta"));

        let defs = registry.get_definitions();
        let names: Vec<_> = defs.into_iter().map(|tool| tool.function.name).collect();

        assert_eq!(names, vec!["alpha", "beta", "zeta"]);
    }

    #[test]
    fn test_new_and_default_are_empty() {
        assert!(!ToolRegistry::new().has_tools());
        assert!(!ToolRegistry::default().has_tools());
        assert!(ToolRegistry::new().get_definitions().is_empty());
        assert!(ToolRegistry::new().get_all().is_empty());
        assert_eq!(ToolRegistry::new().describe_for_prompt(), "");
    }

    #[test]
    fn test_register_replaces_tool_with_same_name() {
        let registry = ToolRegistry::new();
        registry.register(TestTool("alpha"));
        registry.register(TestTool("alpha"));

        assert!(registry.has_tools());
        assert_eq!(registry.tool_names(), vec!["alpha"]);
        assert_eq!(registry.get_all().len(), 1);
        assert_eq!(registry.iter().len(), 1);
    }

    #[test]
    fn test_get_and_register_raw() {
        let registry = ToolRegistry::new();
        registry.register(TestTool("alpha"));
        registry.register_raw("raw".to_string(), std::sync::Arc::new(TestTool("raw")));

        assert_eq!(
            registry.get("alpha").map(|tool| tool.name().to_string()),
            Some("alpha".to_string())
        );
        assert!(registry.get("raw").is_some());
        assert!(registry.get("missing").is_none());

        let mut names = registry.tool_names();
        names.sort();
        assert_eq!(names, vec!["alpha", "raw"]);
    }

    #[test]
    fn test_describe_for_prompt_lists_sorted_entries() {
        let registry = ToolRegistry::new();
        registry.register(TestTool("zeta"));
        registry.register(TestTool("alpha"));

        assert_eq!(registry.describe_for_prompt(), "- alpha: alpha\n- zeta: zeta");
    }
}
|