重构: 添加技能加载和获取工具,优化技能管理
This commit is contained in:
parent
401a7b6473
commit
ac2333900a
@ -1,20 +1,23 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||
use crate::agent::context_compressor::ContextCompressionConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::providers::{create_provider, LLMProvider};
|
||||
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
|
||||
use crate::session::events::DialogInfo;
|
||||
use crate::skills::{Skill, SkillsLoader};
|
||||
use crate::skills::SkillsLoader;
|
||||
use crate::storage::{SessionRecord, SessionStore};
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
HttpRequestTool, ToolRegistry, WebFetchTool,
|
||||
GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
|
||||
/// Generate a short ID (8 characters) from a UUID
|
||||
@ -47,6 +50,11 @@ impl Session {
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||
|
||||
let compressor_config = ContextCompressionConfig {
|
||||
protect_first_n: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
messages: Vec::new(),
|
||||
@ -54,7 +62,7 @@ impl Session {
|
||||
provider_config: provider_config.clone(),
|
||||
provider: provider.clone(),
|
||||
tools,
|
||||
compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit),
|
||||
compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config),
|
||||
store,
|
||||
})
|
||||
}
|
||||
@ -179,7 +187,7 @@ pub struct SessionManager {
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
store: Arc<SessionStore>,
|
||||
skills: Vec<Skill>,
|
||||
skills_loader: Arc<SkillsLoader>,
|
||||
}
|
||||
|
||||
struct SessionManagerInner {
|
||||
@ -189,7 +197,7 @@ struct SessionManagerInner {
|
||||
session_ttl: Duration,
|
||||
}
|
||||
|
||||
fn default_tools() -> ToolRegistry {
|
||||
fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(FileReadTool::new());
|
||||
@ -197,12 +205,13 @@ fn default_tools() -> ToolRegistry {
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(BashTool::new());
|
||||
registry.register(HttpRequestTool::new(
|
||||
vec!["*".to_string()], // 允许所有域名,实际使用时建议限制
|
||||
1_000_000, // max_response_size
|
||||
30, // timeout_secs
|
||||
false, // allow_private_hosts
|
||||
vec!["*".to_string()],
|
||||
1_000_000,
|
||||
30,
|
||||
false,
|
||||
));
|
||||
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
registry.register(GetSkillTool::new(skills_loader));
|
||||
registry
|
||||
}
|
||||
|
||||
@ -241,9 +250,11 @@ impl SessionManager {
|
||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||
);
|
||||
|
||||
// Load skills from standard locations
|
||||
let skills_loader = SkillsLoader::new();
|
||||
let skills = skills_loader.load_skills();
|
||||
skills_loader.load_skills();
|
||||
let skills_loader = Arc::new(skills_loader);
|
||||
|
||||
let tools = Arc::new(create_default_tools(skills_loader.clone()));
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||||
@ -252,9 +263,9 @@ impl SessionManager {
|
||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||||
})),
|
||||
provider_config,
|
||||
tools: Arc::new(default_tools()),
|
||||
tools,
|
||||
store,
|
||||
skills,
|
||||
skills_loader,
|
||||
})
|
||||
}
|
||||
|
||||
@ -276,7 +287,6 @@ impl SessionManager {
|
||||
chat_id: &str,
|
||||
current_session_id: Option<&UnifiedSessionId>,
|
||||
) -> Result<(Option<UnifiedSessionId>, String), AgentError> {
|
||||
// 查找匹配的 command
|
||||
let cmd = SLASH_COMMANDS
|
||||
.iter()
|
||||
.find(|c| c.name == command)
|
||||
@ -284,7 +294,6 @@ impl SessionManager {
|
||||
|
||||
match cmd.name {
|
||||
"reset" => {
|
||||
// Archive current session if exists
|
||||
if let Some(sid) = current_session_id {
|
||||
let unified_str = sid.to_string();
|
||||
self.store
|
||||
@ -292,7 +301,6 @@ impl SessionManager {
|
||||
.map_err(|e| AgentError::Other(format!("archive session error: {}", e)))?;
|
||||
}
|
||||
|
||||
// Create new dialog
|
||||
let (new_id, _title) = self.create_session(channel, chat_id, None).await?;
|
||||
Ok((Some(new_id), "Starting a fresh conversation...".to_string()))
|
||||
}
|
||||
@ -343,20 +351,15 @@ impl SessionManager {
|
||||
pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.clear_messages(session_id)
|
||||
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
|
||||
.map_err(|err| AgentError::Other(format!("clear session messages error: {}", err)))
|
||||
}
|
||||
|
||||
pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
|
||||
self.store
|
||||
.load_messages(session_id)
|
||||
.map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
|
||||
.map_err(|err| AgentError::Other(format!("load session messages error: {}", err)))
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Dialog management methods (UnifiedSessionId based)
|
||||
// =========================================================================
|
||||
|
||||
/// Create a new session (dialog) and return (session_id, title)
|
||||
pub async fn create_session(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -373,12 +376,10 @@ impl SessionManager {
|
||||
.map(ToOwned::to_owned)
|
||||
.unwrap_or_else(|| format!("Dialog {}", &dialog_id));
|
||||
|
||||
// Ensure storage record exists
|
||||
self.store
|
||||
.ensure_channel_session(channel, chat_id, &dialog_id)
|
||||
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))?;
|
||||
|
||||
// Create session instance
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(
|
||||
unified_id.clone(),
|
||||
@ -386,8 +387,7 @@ impl SessionManager {
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
self.store.clone(),
|
||||
)
|
||||
.await?;
|
||||
).await?;
|
||||
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
let inner = &mut *self.inner.lock().await;
|
||||
@ -397,21 +397,16 @@ impl SessionManager {
|
||||
Ok((unified_id, title))
|
||||
}
|
||||
|
||||
/// Get or create a session by UnifiedSessionId
|
||||
pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result<Arc<Mutex<Session>>, AgentError> {
|
||||
let session_id_str = unified_id.to_string();
|
||||
let inner = &mut *self.inner.lock().await;
|
||||
|
||||
// Check if session exists
|
||||
if let Some(session) = inner.sessions.get(&session_id_str) {
|
||||
// Update timestamp
|
||||
inner.session_timestamps.insert(session_id_str, Instant::now());
|
||||
return Ok(session.clone());
|
||||
}
|
||||
|
||||
// Check if session exists in storage
|
||||
if let Ok(Some(_)) = self.store.get_session(&session_id_str) {
|
||||
// Create session instance from storage
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(
|
||||
unified_id.clone(),
|
||||
@ -419,8 +414,7 @@ impl SessionManager {
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
self.store.clone(),
|
||||
)
|
||||
.await?;
|
||||
).await?;
|
||||
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||||
@ -428,7 +422,6 @@ impl SessionManager {
|
||||
return Ok(arc);
|
||||
}
|
||||
|
||||
// Session doesn't exist - create new directly
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(
|
||||
unified_id.clone(),
|
||||
@ -436,8 +429,7 @@ impl SessionManager {
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
self.store.clone(),
|
||||
)
|
||||
.await?;
|
||||
).await?;
|
||||
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||||
@ -445,7 +437,6 @@ impl SessionManager {
|
||||
Ok(arc)
|
||||
}
|
||||
|
||||
/// List all dialogs for a chat scope (internal)
|
||||
async fn list_dialogs_for_chat(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -459,7 +450,6 @@ impl SessionManager {
|
||||
let dialogs: Vec<DialogInfo> = records
|
||||
.into_iter()
|
||||
.filter(|r| {
|
||||
// Filter to only dialogs for this chat_id
|
||||
if let Some(sid) = UnifiedSessionId::parse(&r.id) {
|
||||
sid.chat_id == chat_id
|
||||
} else {
|
||||
@ -482,7 +472,6 @@ impl SessionManager {
|
||||
Ok(dialogs)
|
||||
}
|
||||
|
||||
/// Get the most recent dialog for a chat scope (from storage)
|
||||
pub async fn get_most_recent_dialog(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -506,14 +495,12 @@ impl SessionManager {
|
||||
Ok(most_recent.map(|r| UnifiedSessionId::parse(&r.id).unwrap()))
|
||||
}
|
||||
|
||||
/// Rename a dialog
|
||||
pub fn rename_dialog(&self, session_id: &UnifiedSessionId, title: &str) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.rename_session(&session_id.to_string(), title)
|
||||
.map_err(|err| AgentError::Other(format!("rename dialog error: {}", err)))
|
||||
}
|
||||
|
||||
/// Create a new dialog (wrapper for create_session to match gateway interface)
|
||||
pub async fn create_dialog(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -523,7 +510,6 @@ impl SessionManager {
|
||||
self.create_session(channel, chat_id, title).await
|
||||
}
|
||||
|
||||
/// Get current dialog for a chat (wrapper for get_most_recent_dialog)
|
||||
pub async fn get_current_dialog(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -532,8 +518,6 @@ impl SessionManager {
|
||||
self.get_most_recent_dialog(channel, chat_id).await
|
||||
}
|
||||
|
||||
/// Switch to a different dialog - not applicable in new architecture
|
||||
/// Each Session IS a dialog, so switching is just loading that session
|
||||
pub async fn switch_dialog(
|
||||
&self,
|
||||
_channel: &str,
|
||||
@ -543,7 +527,6 @@ impl SessionManager {
|
||||
Err(AgentError::Other("switch_dialog not applicable in new architecture".to_string()))
|
||||
}
|
||||
|
||||
/// List all dialogs for a chat scope (returns tuple for gateway compatibility)
|
||||
pub async fn list_dialogs(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -555,28 +538,24 @@ impl SessionManager {
|
||||
Ok((dialogs, current.map(|id| id.to_string())))
|
||||
}
|
||||
|
||||
/// Archive a dialog
|
||||
pub fn archive_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.archive_session(&session_id.to_string())
|
||||
.map_err(|err| AgentError::Other(format!("archive dialog error: {}", err)))
|
||||
}
|
||||
|
||||
/// Delete a dialog
|
||||
pub fn delete_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.delete_session(&session_id.to_string())
|
||||
.map_err(|err| AgentError::Other(format!("delete dialog error: {}", err)))
|
||||
}
|
||||
|
||||
/// Clear dialog history
|
||||
pub fn clear_dialog_history(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.clear_messages(&session_id.to_string())
|
||||
.map_err(|err| AgentError::Other(format!("clear dialog history error: {}", err)))
|
||||
}
|
||||
|
||||
/// 处理消息:路由到对应 session 的 agent
|
||||
pub async fn handle_message(
|
||||
&self,
|
||||
channel: &str,
|
||||
@ -586,21 +565,14 @@ impl SessionManager {
|
||||
content: &str,
|
||||
media: Vec<crate::bus::MediaItem>,
|
||||
) -> Result<String, AgentError> {
|
||||
// 确定 dialog_id
|
||||
let dialog_id = dialog_id.unwrap_or(DEFAULT_DIALOG_ID);
|
||||
|
||||
// 获取或创建 session
|
||||
let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id);
|
||||
let session = self.get_or_create_session(&unified_id).await?;
|
||||
|
||||
// 处理消息
|
||||
let response: String = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
// 确保 session 持久化记录存在
|
||||
session_guard.ensure_persistent_session()?;
|
||||
|
||||
// 添加用户消息到历史
|
||||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||||
#[cfg(debug_assertions)]
|
||||
if !media_refs.is_empty() {
|
||||
@ -611,32 +583,24 @@ impl SessionManager {
|
||||
session_guard.add_message(user_message.clone());
|
||||
session_guard.append_message(&user_message)?;
|
||||
|
||||
// 加载历史
|
||||
session_guard.load_history()?;
|
||||
|
||||
// 构建历史消息
|
||||
let mut history = session_guard.get_history().to_vec();
|
||||
|
||||
// Prepend skills as a system message if skills are available
|
||||
if !self.skills.is_empty() {
|
||||
let skills_prompt = SkillsLoader::build_skills_prompt_from_skills(&self.skills);
|
||||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||||
if !skills_prompt.is_empty() {
|
||||
let skills_message = ChatMessage::system(skills_prompt);
|
||||
history.insert(0, skills_message);
|
||||
tracing::debug!(skill_count = self.skills.len(), "Injected skills into context");
|
||||
}
|
||||
tracing::debug!("Injected skills into context");
|
||||
}
|
||||
|
||||
// 压缩历史(如果需要)
|
||||
let history = session_guard.compressor
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = session_guard.create_agent()?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
// 持久化 assistant 消息
|
||||
for msg in &result.emitted_messages {
|
||||
session_guard.append_message(msg)?;
|
||||
}
|
||||
@ -655,7 +619,6 @@ impl SessionManager {
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// 清除指定 session 的所有历史
|
||||
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||||
let session = self.get_or_create_session(unified_id).await?;
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// Skill definition
|
||||
#[derive(Debug, Clone)]
|
||||
@ -6,17 +8,41 @@ pub struct Skill {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub content: String,
|
||||
pub always: bool,
|
||||
pub path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
struct SkillMarkdownMeta {
|
||||
name: Option<String>,
|
||||
description: Option<String>,
|
||||
always: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SkillsState {
|
||||
loaded_skills: Vec<Skill>,
|
||||
last_picobot_mtime: Option<SystemTime>,
|
||||
last_agent_mtime: Option<SystemTime>,
|
||||
last_load_time: SystemTime,
|
||||
}
|
||||
|
||||
impl Default for SkillsState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
loaded_skills: Vec::new(),
|
||||
last_picobot_mtime: None,
|
||||
last_agent_mtime: None,
|
||||
last_load_time: SystemTime::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Skills loader - loads skills from multiple directories
|
||||
#[derive(Clone)]
|
||||
pub struct SkillsLoader {
|
||||
picobot_skills_dir: PathBuf,
|
||||
agent_skills_dir: PathBuf,
|
||||
state: Arc<Mutex<SkillsState>>,
|
||||
}
|
||||
|
||||
impl SkillsLoader {
|
||||
@ -26,12 +52,23 @@ impl SkillsLoader {
|
||||
Self {
|
||||
picobot_skills_dir: home.join(".picobot/skills"),
|
||||
agent_skills_dir: home.join(".agent/skills"),
|
||||
state: Arc::new(Mutex::new(SkillsState::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all skills from both directories
|
||||
pub fn load_skills(&self) -> Vec<Skill> {
|
||||
let mut skills = Vec::new();
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new_for_testing(picobot_dir: PathBuf, agent_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
picobot_skills_dir: picobot_dir,
|
||||
agent_skills_dir: agent_dir,
|
||||
state: Arc::new(Mutex::new(SkillsState::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all skills from both directories and record modification times
|
||||
pub fn load_skills(&self) {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
state.loaded_skills.clear();
|
||||
|
||||
// Load from ~/.picobot/skills
|
||||
if self.picobot_skills_dir.exists() {
|
||||
@ -41,7 +78,8 @@ impl SkillsLoader {
|
||||
count = loaded.len(),
|
||||
"Loaded skills from picobot directory"
|
||||
);
|
||||
skills.extend(loaded);
|
||||
state.loaded_skills.extend(loaded);
|
||||
state.last_picobot_mtime = Self::get_dir_mtime(&self.picobot_skills_dir);
|
||||
}
|
||||
|
||||
// Load from ~/.agent/skills
|
||||
@ -52,16 +90,199 @@ impl SkillsLoader {
|
||||
count = loaded.len(),
|
||||
"Loaded skills from agent directory"
|
||||
);
|
||||
skills.extend(loaded);
|
||||
state.loaded_skills.extend(loaded);
|
||||
state.last_agent_mtime = Self::get_dir_mtime(&self.agent_skills_dir);
|
||||
}
|
||||
|
||||
if skills.is_empty() {
|
||||
state.last_load_time = SystemTime::now();
|
||||
|
||||
if state.loaded_skills.is_empty() {
|
||||
tracing::debug!("No skills found in any skills directory");
|
||||
} else {
|
||||
tracing::info!(count = skills.len(), "Loaded {} skills total", skills.len());
|
||||
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
|
||||
}
|
||||
}
|
||||
|
||||
skills
|
||||
/// Check if skills directories have been modified since last load
|
||||
fn has_changed(&self) -> bool {
|
||||
let state = self.state.lock().unwrap();
|
||||
let picobot_changed = if self.picobot_skills_dir.exists() {
|
||||
let current_mtime = Self::get_dir_mtime(&self.picobot_skills_dir);
|
||||
current_mtime != state.last_picobot_mtime
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let agent_changed = if self.agent_skills_dir.exists() {
|
||||
let current_mtime = Self::get_dir_mtime(&self.agent_skills_dir);
|
||||
current_mtime != state.last_agent_mtime
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
picobot_changed || agent_changed
|
||||
}
|
||||
|
||||
/// Reload skills if changes are detected
|
||||
pub fn reload_if_changed(&self) -> bool {
|
||||
if self.has_changed() {
|
||||
tracing::info!("Skills directories changed, reloading...");
|
||||
self.load_skills();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the latest modification time of a directory or any of its children
|
||||
fn get_dir_mtime(dir: &Path) -> Option<SystemTime> {
|
||||
let mut max_mtime = None;
|
||||
|
||||
if let Ok(metadata) = std::fs::metadata(dir) {
|
||||
if let Ok(mtime) = metadata.modified() {
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if let Ok(metadata) = std::fs::metadata(&path) {
|
||||
if let Ok(mtime) = metadata.modified() {
|
||||
if max_mtime.map_or(true, |current| mtime > current) {
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
max_mtime
|
||||
}
|
||||
|
||||
/// Get a copy of loaded skills (checks for changes first)
|
||||
pub fn get_loaded_skills(&self) -> Vec<Skill> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills.clone()
|
||||
}
|
||||
|
||||
/// Get skills marked as always (checks for changes first)
|
||||
pub fn get_always_skills(&self) -> Vec<Skill> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
|
||||
}
|
||||
|
||||
/// Get a specific skill by name (checks for changes first)
|
||||
pub fn get_skill(&self, name: &str) -> Option<Skill> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills.iter().find(|s| s.name == name).cloned()
|
||||
}
|
||||
|
||||
/// List all skills (name + description) (checks for changes first)
|
||||
pub fn list_skills(&self) -> Vec<(String, String)> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills
|
||||
.iter()
|
||||
.map(|s| (s.name.clone(), s.description.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build XML summary of all skills (for progressive disclosure) (checks for changes first)
|
||||
pub fn build_skills_summary(&self) -> String {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
|
||||
if state.loaded_skills.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut lines = vec!["<skills>".to_string()];
|
||||
|
||||
for skill in &state.loaded_skills {
|
||||
if skill.always {
|
||||
continue;
|
||||
}
|
||||
lines.push(" <skill>".to_string());
|
||||
lines.push(format!(" <name>{}</name>", escape_xml(&skill.name)));
|
||||
lines.push(format!(
|
||||
" <description>{}</description>",
|
||||
escape_xml(&skill.description)
|
||||
));
|
||||
if let Some(path) = &skill.path {
|
||||
lines.push(format!(" <path>{}</path>", escape_xml(&path.to_string_lossy())));
|
||||
}
|
||||
lines.push(" </skill>".to_string());
|
||||
}
|
||||
|
||||
lines.push("</skills>".to_string());
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
/// Build prompt for always-injected skills (checks for changes first)
|
||||
pub fn build_always_skills_prompt(&self) -> String {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
|
||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
||||
if always_skills.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut parts = Vec::new();
|
||||
for skill in always_skills {
|
||||
parts.push(format!("## Skill: {}\n\n{}", skill.name, skill.content));
|
||||
}
|
||||
|
||||
parts.join("\n\n---\n\n")
|
||||
}
|
||||
|
||||
/// Build full skills prompt combining always skills and summary (checks for changes first)
|
||||
pub fn build_skills_prompt(&self) -> String {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
||||
if !always_skills.is_empty() {
|
||||
let mut parts = Vec::new();
|
||||
for skill in always_skills {
|
||||
parts.push(format!("## Skill: {}\n\n{}", skill.name, skill.content));
|
||||
}
|
||||
prompt.push_str(&parts.join("\n\n---\n\n"));
|
||||
prompt.push_str("\n\n");
|
||||
}
|
||||
|
||||
let has_other_skills = state.loaded_skills.iter().any(|s| !s.always);
|
||||
if has_other_skills {
|
||||
prompt.push_str("## Available Skills\n\n");
|
||||
prompt.push_str("Skills teach the agent how to use specific capabilities. Use the `get_skill` tool to load a skill's full content when needed.\n\n");
|
||||
|
||||
let mut lines = vec!["<skills>".to_string()];
|
||||
for skill in &state.loaded_skills {
|
||||
if skill.always {
|
||||
continue;
|
||||
}
|
||||
lines.push(" <skill>".to_string());
|
||||
lines.push(format!(" <name>{}</name>", escape_xml(&skill.name)));
|
||||
lines.push(format!(
|
||||
" <description>{}</description>",
|
||||
escape_xml(&skill.description)
|
||||
));
|
||||
if let Some(path) = &skill.path {
|
||||
lines.push(format!(" <path>{}</path>", escape_xml(&path.to_string_lossy())));
|
||||
}
|
||||
lines.push(" </skill>".to_string());
|
||||
}
|
||||
lines.push("</skills>".to_string());
|
||||
prompt.push_str(&lines.join("\n"));
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Load skills from a specific directory
|
||||
@ -91,6 +312,7 @@ impl SkillsLoader {
|
||||
tracing::debug!(
|
||||
skill = %skill.name,
|
||||
path = %skill_file.display(),
|
||||
always = skill.always,
|
||||
"Loaded skill"
|
||||
);
|
||||
skills.push(skill);
|
||||
@ -116,78 +338,6 @@ impl SkillsLoader {
|
||||
skills
|
||||
}
|
||||
|
||||
/// List all skills (name + description)
|
||||
pub fn list_skills(&self) -> Vec<(String, String)> {
|
||||
self.load_skills()
|
||||
.into_iter()
|
||||
.map(|s| (s.name, s.description))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific skill by name
|
||||
pub fn get_skill(&self, name: &str) -> Option<Skill> {
|
||||
// Check picobot_skills first
|
||||
let picobot_path = self.picobot_skills_dir.join(name).join("SKILL.md");
|
||||
if picobot_path.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&picobot_path) {
|
||||
let dir = self.picobot_skills_dir.join(name);
|
||||
return self.parse_skill(&dir, &content);
|
||||
}
|
||||
}
|
||||
|
||||
// Check agent_skills
|
||||
let agent_path = self.agent_skills_dir.join(name).join("SKILL.md");
|
||||
if agent_path.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&agent_path) {
|
||||
let dir = self.agent_skills_dir.join(name);
|
||||
return self.parse_skill(&dir, &content);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Build skills prompt for agent context (reloads from disk)
|
||||
pub fn build_skills_prompt(&self) -> String {
|
||||
let skills = self.load_skills();
|
||||
Self::format_skills_prompt(&skills)
|
||||
}
|
||||
|
||||
/// Build skills prompt from already-loaded skills (no disk I/O)
|
||||
pub fn build_skills_prompt_from_skills(skills: &[Skill]) -> String {
|
||||
Self::format_skills_prompt(skills)
|
||||
}
|
||||
|
||||
/// Format skills into a prompt string
|
||||
fn format_skills_prompt(skills: &[Skill]) -> String {
|
||||
if skills.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut prompt = String::from("## Available Skills\n\n");
|
||||
prompt.push_str("Skills teach the agent how to use specific capabilities.\n\n");
|
||||
prompt.push_str("<skills>\n");
|
||||
|
||||
for skill in skills {
|
||||
prompt.push_str(" <skill>\n");
|
||||
prompt.push_str(&format!(" <name>{}</name>\n", escape_xml(&skill.name)));
|
||||
prompt.push_str(&format!(
|
||||
" <description>{}</description>\n",
|
||||
escape_xml(&skill.description)
|
||||
));
|
||||
prompt.push_str(" <instructions>\n");
|
||||
prompt.push_str(&format!(
|
||||
" <instruction>{}</instruction>\n",
|
||||
escape_xml(&skill.content)
|
||||
));
|
||||
prompt.push_str(" </instructions>\n");
|
||||
prompt.push_str(" </skill>\n");
|
||||
}
|
||||
|
||||
prompt.push_str("</skills>\n");
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Parse a skill from markdown content
|
||||
fn parse_skill(&self, dir: &Path, content: &str) -> Option<Skill> {
|
||||
let (meta, body) = self.parse_skill_markdown(content);
|
||||
@ -206,6 +356,8 @@ impl SkillsLoader {
|
||||
name,
|
||||
description,
|
||||
content: body,
|
||||
always: meta.always.unwrap_or(false),
|
||||
path: Some(dir.to_path_buf()),
|
||||
})
|
||||
}
|
||||
|
||||
@ -242,6 +394,13 @@ impl SkillsLoader {
|
||||
match key {
|
||||
"name" => meta.name = Some(val.to_string()),
|
||||
"description" => meta.description = Some(val.to_string()),
|
||||
"always" => {
|
||||
meta.always = match val.to_lowercase().as_str() {
|
||||
"true" | "1" | "yes" | "on" => Some(true),
|
||||
"false" | "0" | "no" | "off" => Some(false),
|
||||
_ => None,
|
||||
};
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@ -261,6 +420,7 @@ impl Default for SkillMarkdownMeta {
|
||||
Self {
|
||||
name: None,
|
||||
description: None,
|
||||
always: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -311,6 +471,7 @@ mod tests {
|
||||
let content = r#"---
|
||||
name: test-skill
|
||||
description: A test skill
|
||||
always: true
|
||||
---
|
||||
# Test Skill
|
||||
|
||||
@ -321,6 +482,7 @@ This is the content.
|
||||
|
||||
assert_eq!(meta.name, Some("test-skill".to_string()));
|
||||
assert_eq!(meta.description, Some("A test skill".to_string()));
|
||||
assert_eq!(meta.always, Some(true));
|
||||
assert!(body.contains("Test Skill"));
|
||||
}
|
||||
|
||||
@ -339,12 +501,4 @@ This is the content.
|
||||
);
|
||||
assert_eq!(extract_description("# Title"), "No description");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_skills_from_empty_dir() {
|
||||
let loader = SkillsLoader::new();
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let skills = loader.load_skills_from_dir(temp_dir.path());
|
||||
assert!(skills.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
159
src/tools/get_skill.rs
Normal file
159
src/tools/get_skill.rs
Normal file
@ -0,0 +1,159 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::skills::{Skill, SkillsLoader};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct GetSkillTool {
|
||||
skills_loader: Arc<SkillsLoader>,
|
||||
}
|
||||
|
||||
impl GetSkillTool {
|
||||
pub fn new(skills_loader: Arc<SkillsLoader>) -> Self {
|
||||
Self { skills_loader }
|
||||
}
|
||||
|
||||
fn format_skill(&self, skill: &Skill) -> String {
|
||||
let mut result = format!("# Skill: {}\n\n{}", skill.name, skill.description);
|
||||
|
||||
if let Some(path) = &skill.path {
|
||||
result.push_str(&format!(
|
||||
"\n\n**Skill Root Directory:** `{}`\n\nAll files and references in this skill are relative to this directory.",
|
||||
path.to_string_lossy()
|
||||
));
|
||||
}
|
||||
|
||||
result.push_str(&format!("\n\n---\n\n{}", skill.content));
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for GetSkillTool {
|
||||
fn name(&self) -> &str {
|
||||
"get_skill"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Get complete content and guidance for a specified skill. Use this when you need detailed instructions for a specific type of task."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"skill_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the skill to retrieve"
|
||||
}
|
||||
},
|
||||
"required": ["skill_name"]
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let skill_name = match args.get("skill_name").and_then(|v| v.as_str()) {
|
||||
Some(name) => name,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: skill_name".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match self.skills_loader.get_skill(skill_name) {
|
||||
Some(skill) => {
|
||||
let formatted = self.format_skill(&skill);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: formatted,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
None => {
|
||||
let available = self.skills_loader.list_skills();
|
||||
let available_str = if available.is_empty() {
|
||||
"No skills available".to_string()
|
||||
} else {
|
||||
available
|
||||
.iter()
|
||||
.map(|(name, _)| name.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
};
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Skill '{}' not found. Available skills: {}",
|
||||
skill_name, available_str
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_existing_skill() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
|
||||
let skill_dir = temp_dir.path().join("test-skill");
|
||||
std::fs::create_dir(&skill_dir).unwrap();
|
||||
|
||||
let mut skill_file = File::create(skill_dir.join("SKILL.md")).unwrap();
|
||||
writeln!(skill_file, "---").unwrap();
|
||||
writeln!(skill_file, "name: test-skill").unwrap();
|
||||
writeln!(skill_file, "description: A test skill").unwrap();
|
||||
writeln!(skill_file, "---").unwrap();
|
||||
writeln!(skill_file, "# Test Skill").unwrap();
|
||||
writeln!(skill_file, "This is the test content.").unwrap();
|
||||
|
||||
let mut loader = SkillsLoader::new_for_testing(
|
||||
temp_dir.path().to_path_buf(),
|
||||
PathBuf::from("/nonexistent"),
|
||||
);
|
||||
loader.load_skills();
|
||||
|
||||
let tool = GetSkillTool::new(Arc::new(loader));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({ "skill_name": "test-skill" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("test-skill"));
|
||||
assert!(result.output.contains("test content"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_nonexistent_skill() {
|
||||
let loader = SkillsLoader::new();
|
||||
let tool = GetSkillTool::new(Arc::new(loader));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({ "skill_name": "nonexistent" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.is_some());
|
||||
}
|
||||
}
|
||||
@ -3,6 +3,7 @@ pub mod calculator;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod get_skill;
|
||||
pub mod http_request;
|
||||
pub mod registry;
|
||||
pub mod schema;
|
||||
@ -14,6 +15,7 @@ pub use calculator::CalculatorTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use get_skill::GetSkillTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user