PicoBot/src/tools/registry.rs

141 lines
3.7 KiB
Rust

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::providers::{Tool, ToolFunction};
use super::traits::Tool as ToolTrait;
/// Name-keyed collection of tools, guarded by a `Mutex`.
///
/// Tools are kept as shared trait objects so lookups can hand out cheap `Arc`
/// clones, and the interior `Mutex` lets registration happen through `&self`.
pub struct ToolRegistry {
    /// Registered tools, keyed by the name each one was registered under.
    tools: Mutex<HashMap<String, Arc<dyn ToolTrait>>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the tool map, recovering the guard if the lock was poisoned.
    ///
    /// This impl only mutates the map through single `insert` calls, so a panic
    /// in some other lock holder cannot leave it half-updated. Recovering keeps
    /// one panicking caller from turning every later registry call into a panic.
    fn lock_tools(&self) -> std::sync::MutexGuard<'_, HashMap<String, Arc<dyn ToolTrait>>> {
        self.tools
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `tool` under the name returned by its `name()` method.
    ///
    /// Any tool previously registered under that name is replaced.
    pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
        // Ask for the name *before* locking so tool code never runs under the lock.
        let name = tool.name().to_string();
        self.register_raw(name, Arc::new(tool));
    }

    /// Registers an existing `Arc`-wrapped tool under an explicit `name`.
    ///
    /// Any tool previously registered under that name is replaced. `name` is the
    /// key used by [`get`](Self::get); it is not checked against the tool's own
    /// `name()`.
    pub fn register_raw(&self, name: String, tool: Arc<dyn ToolTrait>) {
        // Bind the displaced entry (if any) so its destructor runs after the guard,
        // a temporary of this statement, has already been released.
        let _replaced = self.lock_tools().insert(name, tool);
    }

    /// Returns the tool registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
        self.lock_tools().get(name).cloned()
    }

    /// Returns every registered tool, in no particular order.
    ///
    /// Used for concurrent tool execution when tools need to be looked up by name.
    pub fn get_all(&self) -> Vec<Arc<dyn ToolTrait>> {
        self.lock_tools().values().cloned().collect()
    }

    /// Builds the provider-facing `function` definitions for every registered
    /// tool, sorted by function name.
    ///
    /// Function names come from each tool's own `name()`. For tools added via
    /// [`register_raw`](Self::register_raw), that name may differ from the
    /// registry key.
    pub fn get_definitions(&self) -> Vec<Tool> {
        // Snapshot the tools first. `name()`, `description()` and
        // `parameters_schema()` are arbitrary tool code. Running them under the
        // lock would deadlock any tool that consults this registry, and a
        // panicking tool would poison the lock.
        let mut defs: Vec<Tool> = self
            .get_all()
            .into_iter()
            .map(|tool| Tool {
                tool_type: "function".to_string(),
                function: ToolFunction {
                    name: tool.name().to_string(),
                    description: tool.description().to_string(),
                    parameters: tool.parameters_schema(),
                },
            })
            .collect();
        defs.sort_by(|a, b| a.function.name.cmp(&b.function.name));
        defs
    }

    /// Returns `true` if at least one tool is registered.
    pub fn has_tools(&self) -> bool {
        !self.lock_tools().is_empty()
    }

    /// Returns the registry keys, sorted so the output is deterministic.
    pub fn tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.lock_tools().keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns a snapshot of `(key, tool)` pairs, in no particular order.
    ///
    /// The lock is released before this returns, so callers may freely invoke
    /// tool methods on the result.
    pub fn iter(&self) -> Vec<(String, Arc<dyn ToolTrait>)> {
        self.lock_tools()
            .iter()
            .map(|(k, v)| (k.clone(), Arc::clone(v)))
            .collect()
    }

    /// Builds a tool-list description for a sub-agent's system prompt.
    ///
    /// Each tool becomes one `- <key>: <description>` line. Lines are sorted
    /// and joined with `\n`.
    pub fn describe_for_prompt(&self) -> String {
        let mut entries: Vec<String> = self
            .iter()
            .into_iter()
            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
            .collect();
        entries.sort();
        entries.join("\n")
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::traits::ToolResult;
    use async_trait::async_trait;
    use serde_json::json;

    /// Minimal tool whose name and description are both the wrapped label.
    struct NamedTool(&'static str);

    #[async_trait]
    impl ToolTrait for NamedTool {
        fn name(&self) -> &str {
            self.0
        }

        fn description(&self) -> &str {
            self.0
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({})
        }

        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            let result = ToolResult {
                success: true,
                output: "ok".to_string(),
                error: None,
            };
            Ok(result)
        }
    }

    #[test]
    fn test_get_definitions_sorted_by_name() {
        let registry = ToolRegistry::new();
        // Deliberately registered out of order.
        for label in ["zeta", "alpha", "beta"] {
            registry.register(NamedTool(label));
        }

        let sorted_names: Vec<String> = registry
            .get_definitions()
            .into_iter()
            .map(|definition| definition.function.name)
            .collect();

        assert_eq!(sorted_names, vec!["alpha", "beta", "zeta"]);
    }
}