diff --git a/src/session/session.rs b/src/session/session.rs index d47faaa..b201e16 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -1,20 +1,23 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; + use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; + use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; +use crate::agent::context_compressor::ContextCompressionConfig; use crate::protocol::WsOutbound; use crate::providers::{create_provider, LLMProvider}; use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; use crate::session::events::DialogInfo; -use crate::skills::{Skill, SkillsLoader}; +use crate::skills::SkillsLoader; use crate::storage::{SessionRecord, SessionStore}; use crate::tools::{ BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, - HttpRequestTool, ToolRegistry, WebFetchTool, + GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool, }; /// Generate a short ID (8 characters) from a UUID @@ -47,6 +50,11 @@ impl Session { .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; let provider: Arc = Arc::from(provider_box); + let compressor_config = ContextCompressionConfig { + protect_first_n: 2, + ..Default::default() + }; + Ok(Self { id, messages: Vec::new(), @@ -54,7 +62,7 @@ impl Session { provider_config: provider_config.clone(), provider: provider.clone(), tools, - compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit), + compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config), store, }) } @@ -179,7 +187,7 @@ pub struct SessionManager { provider_config: LLMProviderConfig, tools: Arc, store: Arc, - skills: Vec, + skills_loader: Arc, } struct SessionManagerInner { @@ -189,7 +197,7 @@ struct SessionManagerInner { session_ttl: Duration, } -fn default_tools() -> ToolRegistry { +fn create_default_tools(skills_loader: Arc) -> ToolRegistry { let mut registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); registry.register(FileReadTool::new()); @@ -197,12 +205,13 @@ fn default_tools() -> ToolRegistry { registry.register(FileEditTool::new()); registry.register(BashTool::new()); registry.register(HttpRequestTool::new( - vec!["*".to_string()], // 允许所有域名,实际使用时建议限制 - 1_000_000, // max_response_size - 30, // timeout_secs - false, // allow_private_hosts + vec!["*".to_string()], + 1_000_000, + 30, + false, )); - registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs + registry.register(WebFetchTool::new(50_000, 30)); + registry.register(GetSkillTool::new(skills_loader)); registry } @@ -241,9 +250,11 @@ impl SessionManager { .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, ); - // Load skills from standard locations let skills_loader = SkillsLoader::new(); - let skills = skills_loader.load_skills(); + skills_loader.load_skills(); + let skills_loader = Arc::new(skills_loader); + + let tools = Arc::new(create_default_tools(skills_loader.clone())); Ok(Self { inner: Arc::new(Mutex::new(SessionManagerInner { @@ -252,9 +263,9 @@ impl SessionManager { session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), provider_config, - tools: Arc::new(default_tools()), + tools, store, - skills, + skills_loader, }) } @@ -276,7 +287,6 @@ impl SessionManager { chat_id: &str, current_session_id: Option<&UnifiedSessionId>, ) -> Result<(Option, String), AgentError> { - // 查找匹配的 command let cmd = SLASH_COMMANDS .iter() .find(|c| c.name == command) @@ -284,7 +294,6 @@ impl SessionManager { match cmd.name { "reset" => { - // Archive current session if exists if let Some(sid) = current_session_id { let unified_str = sid.to_string(); self.store @@ -292,7 +301,6 @@ impl SessionManager { .map_err(|e| AgentError::Other(format!("archive session error: {}", e)))?; } - // Create new dialog let (new_id, _title) = self.create_session(channel, chat_id, None).await?; Ok((Some(new_id), "Starting a fresh conversation...".to_string())) } @@ -343,20 +351,15 @@ impl SessionManager { pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> { self.store .clear_messages(session_id) - .map_err(|err| AgentError::Other(format!("clear session error: {}", err))) + .map_err(|err| AgentError::Other(format!("clear session messages error: {}", err))) } pub fn load_session_messages(&self, session_id: &str) -> Result, AgentError> { self.store .load_messages(session_id) - .map_err(|err| AgentError::Other(format!("load messages error: {}", err))) + .map_err(|err| AgentError::Other(format!("load session messages error: {}", err))) } - // ========================================================================= - // Dialog management methods (UnifiedSessionId based) - // ========================================================================= - - /// Create a new session (dialog) and return (session_id, title) pub async fn create_session( &self, channel: &str, @@ -373,12 +376,10 @@ impl SessionManager { .map(ToOwned::to_owned) .unwrap_or_else(|| format!("Dialog {}", &dialog_id)); - // Ensure storage record exists self.store .ensure_channel_session(channel, chat_id, &dialog_id) .map_err(|err| AgentError::Other(format!("create session error: {}", err)))?; - // Create session instance let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), @@ -386,8 +387,7 @@ impl SessionManager { user_tx, self.tools.clone(), self.store.clone(), - ) - .await?; + ).await?; let arc = Arc::new(Mutex::new(session)); let inner = &mut *self.inner.lock().await; @@ -397,21 +397,16 @@ impl SessionManager { Ok((unified_id, title)) } - /// Get or create a session by UnifiedSessionId pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result>, AgentError> { let session_id_str = unified_id.to_string(); let inner = &mut *self.inner.lock().await; - // Check if session exists if let Some(session) = inner.sessions.get(&session_id_str) { - // Update timestamp inner.session_timestamps.insert(session_id_str, Instant::now()); return Ok(session.clone()); } - // Check if session exists in storage if let Ok(Some(_)) = self.store.get_session(&session_id_str) { - // Create session instance from storage let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), @@ -419,8 +414,7 @@ impl SessionManager { user_tx, self.tools.clone(), self.store.clone(), - ) - .await?; + ).await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(session_id_str.clone(), arc.clone()); @@ -428,7 +422,6 @@ impl SessionManager { return Ok(arc); } - // Session doesn't exist - create new directly let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), @@ -436,8 +429,7 @@ impl SessionManager { user_tx, self.tools.clone(), self.store.clone(), - ) - .await?; + ).await?; let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(session_id_str.clone(), arc.clone()); @@ -445,7 +437,6 @@ impl SessionManager { Ok(arc) } - /// List all dialogs for a chat scope (internal) async fn list_dialogs_for_chat( &self, channel: &str, @@ -459,7 +450,6 @@ impl SessionManager { let dialogs: Vec = records .into_iter() .filter(|r| { - // Filter to only dialogs for this chat_id if let Some(sid) = UnifiedSessionId::parse(&r.id) { sid.chat_id == chat_id } else { @@ -482,7 +472,6 @@ impl SessionManager { Ok(dialogs) } - /// Get the most recent dialog for a chat scope (from storage) pub async fn get_most_recent_dialog( &self, channel: &str, @@ -506,14 +495,12 @@ impl SessionManager { Ok(most_recent.map(|r| UnifiedSessionId::parse(&r.id).unwrap())) } - /// Rename a dialog pub fn rename_dialog(&self, session_id: &UnifiedSessionId, title: &str) -> Result<(), AgentError> { self.store .rename_session(&session_id.to_string(), title) .map_err(|err| AgentError::Other(format!("rename dialog error: {}", err))) } - /// Create a new dialog (wrapper for create_session to match gateway interface) pub async fn create_dialog( &self, channel: &str, @@ -523,7 +510,6 @@ impl SessionManager { self.create_session(channel, chat_id, title).await } - /// Get current dialog for a chat (wrapper for get_most_recent_dialog) pub async fn get_current_dialog( &self, channel: &str, @@ -532,8 +518,6 @@ impl SessionManager { self.get_most_recent_dialog(channel, chat_id).await } - /// Switch to a different dialog - not applicable in new architecture - /// Each Session IS a dialog, so switching is just loading that session pub async fn switch_dialog( &self, _channel: &str, @@ -543,7 +527,6 @@ impl SessionManager { Err(AgentError::Other("switch_dialog not applicable in new architecture".to_string())) } - /// List all dialogs for a chat scope (returns tuple for gateway compatibility) pub async fn list_dialogs( &self, channel: &str, @@ -555,28 +538,24 @@ impl SessionManager { Ok((dialogs, current.map(|id| id.to_string()))) } - /// Archive a dialog pub fn archive_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { self.store .archive_session(&session_id.to_string()) .map_err(|err| AgentError::Other(format!("archive dialog error: {}", err))) } - /// Delete a dialog pub fn delete_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { self.store .delete_session(&session_id.to_string()) .map_err(|err| AgentError::Other(format!("delete dialog error: {}", err))) } - /// Clear dialog history pub fn clear_dialog_history(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> { self.store .clear_messages(&session_id.to_string()) .map_err(|err| AgentError::Other(format!("clear dialog history error: {}", err))) } - /// 处理消息:路由到对应 session 的 agent pub async fn handle_message( &self, channel: &str, @@ -586,21 +565,14 @@ impl SessionManager { content: &str, media: Vec, ) -> Result { - // 确定 dialog_id let dialog_id = dialog_id.unwrap_or(DEFAULT_DIALOG_ID); - - // 获取或创建 session let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id); let session = self.get_or_create_session(&unified_id).await?; - // 处理消息 let response: String = { let mut session_guard = session.lock().await; - - // 确保 session 持久化记录存在 session_guard.ensure_persistent_session()?; - // 添加用户消息到历史 let media_refs: Vec = media.iter().map(|m| m.path.clone()).collect(); #[cfg(debug_assertions)] if !media_refs.is_empty() { @@ -611,32 +583,24 @@ impl SessionManager { session_guard.add_message(user_message.clone()); session_guard.append_message(&user_message)?; - // 加载历史 session_guard.load_history()?; - // 构建历史消息 let mut history = session_guard.get_history().to_vec(); - // Prepend skills as a system message if skills are available - if !self.skills.is_empty() { - let skills_prompt = SkillsLoader::build_skills_prompt_from_skills(&self.skills); - if !skills_prompt.is_empty() { - let skills_message = ChatMessage::system(skills_prompt); - history.insert(0, skills_message); - tracing::debug!(skill_count = self.skills.len(), "Injected skills into context"); - } + let skills_prompt = self.skills_loader.build_skills_prompt(); + if !skills_prompt.is_empty() { + let skills_message = ChatMessage::system(skills_prompt); + history.insert(0, skills_message); + tracing::debug!("Injected skills into context"); } - // 压缩历史(如果需要) let history = session_guard.compressor .compress_if_needed(history) .await?; - // 创建 agent 并处理 let agent = session_guard.create_agent()?; let result = agent.process(history).await?; - // 持久化 assistant 消息 for msg in &result.emitted_messages { session_guard.append_message(msg)?; } @@ -655,7 +619,6 @@ impl SessionManager { Ok(response) } - /// 清除指定 session 的所有历史 pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> { let session = self.get_or_create_session(unified_id).await?; let mut session_guard = session.lock().await; diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 6ffe844..c4030be 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -1,4 +1,6 @@ use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::SystemTime; /// Skill definition #[derive(Debug, Clone)] @@ -6,17 +8,41 @@ pub struct Skill { pub name: String, pub description: String, pub content: String, + pub always: bool, + pub path: Option, } struct SkillMarkdownMeta { name: Option, description: Option, + always: Option, +} + +#[derive(Clone)] +struct SkillsState { + loaded_skills: Vec, + last_picobot_mtime: Option, + last_agent_mtime: Option, + last_load_time: SystemTime, +} + +impl Default for SkillsState { + fn default() -> Self { + Self { + loaded_skills: Vec::new(), + last_picobot_mtime: None, + last_agent_mtime: None, + last_load_time: SystemTime::now(), + } + } } /// Skills loader - loads skills from multiple directories +#[derive(Clone)] pub struct SkillsLoader { picobot_skills_dir: PathBuf, agent_skills_dir: PathBuf, + state: Arc>, } impl SkillsLoader { @@ -26,12 +52,23 @@ impl SkillsLoader { Self { picobot_skills_dir: home.join(".picobot/skills"), agent_skills_dir: home.join(".agent/skills"), + state: Arc::new(Mutex::new(SkillsState::default())), } } - /// Load all skills from both directories - pub fn load_skills(&self) -> Vec { - let mut skills = Vec::new(); + #[cfg(test)] + pub(crate) fn new_for_testing(picobot_dir: PathBuf, agent_dir: PathBuf) -> Self { + Self { + picobot_skills_dir: picobot_dir, + agent_skills_dir: agent_dir, + state: Arc::new(Mutex::new(SkillsState::default())), + } + } + + /// Load all skills from both directories and record modification times + pub fn load_skills(&self) { + let mut state = self.state.lock().unwrap(); + state.loaded_skills.clear(); // Load from ~/.picobot/skills if self.picobot_skills_dir.exists() { @@ -41,7 +78,8 @@ impl SkillsLoader { count = loaded.len(), "Loaded skills from picobot directory" ); - skills.extend(loaded); + state.loaded_skills.extend(loaded); + state.last_picobot_mtime = Self::get_dir_mtime(&self.picobot_skills_dir); } // Load from ~/.agent/skills @@ -52,16 +90,199 @@ impl SkillsLoader { count = loaded.len(), "Loaded skills from agent directory" ); - skills.extend(loaded); + state.loaded_skills.extend(loaded); + state.last_agent_mtime = Self::get_dir_mtime(&self.agent_skills_dir); } - if skills.is_empty() { + state.last_load_time = SystemTime::now(); + + if state.loaded_skills.is_empty() { tracing::debug!("No skills found in any skills directory"); } else { - tracing::info!(count = skills.len(), "Loaded {} skills total", skills.len()); + tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len()); + } + } + + /// Check if skills directories have been modified since last load + fn has_changed(&self) -> bool { + let state = self.state.lock().unwrap(); + let picobot_changed = if self.picobot_skills_dir.exists() { + let current_mtime = Self::get_dir_mtime(&self.picobot_skills_dir); + current_mtime != state.last_picobot_mtime + } else { + false + }; + + let agent_changed = if self.agent_skills_dir.exists() { + let current_mtime = Self::get_dir_mtime(&self.agent_skills_dir); + current_mtime != state.last_agent_mtime + } else { + false + }; + + picobot_changed || agent_changed + } + + /// Reload skills if changes are detected + pub fn reload_if_changed(&self) -> bool { + if self.has_changed() { + tracing::info!("Skills directories changed, reloading..."); + self.load_skills(); + true + } else { + false + } + } + + /// Get the latest modification time of a directory or any of its children + fn get_dir_mtime(dir: &Path) -> Option { + let mut max_mtime = None; + + if let Ok(metadata) = std::fs::metadata(dir) { + if let Ok(mtime) = metadata.modified() { + max_mtime = Some(mtime); + } } - skills + if let Ok(entries) = std::fs::read_dir(dir) { + for entry in entries.flatten() { + let path = entry.path(); + if let Ok(metadata) = std::fs::metadata(&path) { + if let Ok(mtime) = metadata.modified() { + if max_mtime.map_or(true, |current| mtime > current) { + max_mtime = Some(mtime); + } + } + } + } + } + + max_mtime + } + + /// Get a copy of loaded skills (checks for changes first) + pub fn get_loaded_skills(&self) -> Vec { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + state.loaded_skills.clone() + } + + /// Get skills marked as always (checks for changes first) + pub fn get_always_skills(&self) -> Vec { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + state.loaded_skills.iter().filter(|s| s.always).cloned().collect() + } + + /// Get a specific skill by name (checks for changes first) + pub fn get_skill(&self, name: &str) -> Option { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + state.loaded_skills.iter().find(|s| s.name == name).cloned() + } + + /// List all skills (name + description) (checks for changes first) + pub fn list_skills(&self) -> Vec<(String, String)> { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + state.loaded_skills + .iter() + .map(|s| (s.name.clone(), s.description.clone())) + .collect() + } + + /// Build XML summary of all skills (for progressive disclosure) (checks for changes first) + pub fn build_skills_summary(&self) -> String { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + + if state.loaded_skills.is_empty() { + return String::new(); + } + + let mut lines = vec!["".to_string()]; + + for skill in &state.loaded_skills { + if skill.always { + continue; + } + lines.push(" ".to_string()); + lines.push(format!(" {}", escape_xml(&skill.name))); + lines.push(format!( + " {}", + escape_xml(&skill.description) + )); + if let Some(path) = &skill.path { + lines.push(format!(" {}", escape_xml(&path.to_string_lossy()))); + } + lines.push(" ".to_string()); + } + + lines.push("".to_string()); + lines.join("\n") + } + + /// Build prompt for always-injected skills (checks for changes first) + pub fn build_always_skills_prompt(&self) -> String { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + + let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect(); + if always_skills.is_empty() { + return String::new(); + } + + let mut parts = Vec::new(); + for skill in always_skills { + parts.push(format!("## Skill: {}\n\n{}", skill.name, skill.content)); + } + + parts.join("\n\n---\n\n") + } + + /// Build full skills prompt combining always skills and summary (checks for changes first) + pub fn build_skills_prompt(&self) -> String { + self.reload_if_changed(); + let state = self.state.lock().unwrap(); + + let mut prompt = String::new(); + + let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect(); + if !always_skills.is_empty() { + let mut parts = Vec::new(); + for skill in always_skills { + parts.push(format!("## Skill: {}\n\n{}", skill.name, skill.content)); + } + prompt.push_str(&parts.join("\n\n---\n\n")); + prompt.push_str("\n\n"); + } + + let has_other_skills = state.loaded_skills.iter().any(|s| !s.always); + if has_other_skills { + prompt.push_str("## Available Skills\n\n"); + prompt.push_str("Skills teach the agent how to use specific capabilities. Use the `get_skill` tool to load a skill's full content when needed.\n\n"); + + let mut lines = vec!["".to_string()]; + for skill in &state.loaded_skills { + if skill.always { + continue; + } + lines.push(" ".to_string()); + lines.push(format!(" {}", escape_xml(&skill.name))); + lines.push(format!( + " {}", + escape_xml(&skill.description) + )); + if let Some(path) = &skill.path { + lines.push(format!(" {}", escape_xml(&path.to_string_lossy()))); + } + lines.push(" ".to_string()); + } + lines.push("".to_string()); + prompt.push_str(&lines.join("\n")); + } + + prompt } /// Load skills from a specific directory @@ -91,6 +312,7 @@ impl SkillsLoader { tracing::debug!( skill = %skill.name, path = %skill_file.display(), + always = skill.always, "Loaded skill" ); skills.push(skill); @@ -116,78 +338,6 @@ impl SkillsLoader { skills } - /// List all skills (name + description) - pub fn list_skills(&self) -> Vec<(String, String)> { - self.load_skills() - .into_iter() - .map(|s| (s.name, s.description)) - .collect() - } - - /// Get a specific skill by name - pub fn get_skill(&self, name: &str) -> Option { - // Check picobot_skills first - let picobot_path = self.picobot_skills_dir.join(name).join("SKILL.md"); - if picobot_path.exists() { - if let Ok(content) = std::fs::read_to_string(&picobot_path) { - let dir = self.picobot_skills_dir.join(name); - return self.parse_skill(&dir, &content); - } - } - - // Check agent_skills - let agent_path = self.agent_skills_dir.join(name).join("SKILL.md"); - if agent_path.exists() { - if let Ok(content) = std::fs::read_to_string(&agent_path) { - let dir = self.agent_skills_dir.join(name); - return self.parse_skill(&dir, &content); - } - } - - None - } - - /// Build skills prompt for agent context (reloads from disk) - pub fn build_skills_prompt(&self) -> String { - let skills = self.load_skills(); - Self::format_skills_prompt(&skills) - } - - /// Build skills prompt from already-loaded skills (no disk I/O) - pub fn build_skills_prompt_from_skills(skills: &[Skill]) -> String { - Self::format_skills_prompt(skills) - } - - /// Format skills into a prompt string - fn format_skills_prompt(skills: &[Skill]) -> String { - if skills.is_empty() { - return String::new(); - } - - let mut prompt = String::from("## Available Skills\n\n"); - prompt.push_str("Skills teach the agent how to use specific capabilities.\n\n"); - prompt.push_str("\n"); - - for skill in skills { - prompt.push_str(" \n"); - prompt.push_str(&format!(" {}\n", escape_xml(&skill.name))); - prompt.push_str(&format!( - " {}\n", - escape_xml(&skill.description) - )); - prompt.push_str(" \n"); - prompt.push_str(&format!( - " {}\n", - escape_xml(&skill.content) - )); - prompt.push_str(" \n"); - prompt.push_str(" \n"); - } - - prompt.push_str("\n"); - prompt - } - /// Parse a skill from markdown content fn parse_skill(&self, dir: &Path, content: &str) -> Option { let (meta, body) = self.parse_skill_markdown(content); @@ -206,6 +356,8 @@ impl SkillsLoader { name, description, content: body, + always: meta.always.unwrap_or(false), + path: Some(dir.to_path_buf()), }) } @@ -242,6 +394,13 @@ impl SkillsLoader { match key { "name" => meta.name = Some(val.to_string()), "description" => meta.description = Some(val.to_string()), + "always" => { + meta.always = match val.to_lowercase().as_str() { + "true" | "1" | "yes" | "on" => Some(true), + "false" | "0" | "no" | "off" => Some(false), + _ => None, + }; + } _ => {} } } @@ -261,6 +420,7 @@ impl Default for SkillMarkdownMeta { Self { name: None, description: None, + always: None, } } } @@ -311,6 +471,7 @@ mod tests { let content = r#"--- name: test-skill description: A test skill +always: true --- # Test Skill @@ -321,6 +482,7 @@ This is the content. assert_eq!(meta.name, Some("test-skill".to_string())); assert_eq!(meta.description, Some("A test skill".to_string())); + assert_eq!(meta.always, Some(true)); assert!(body.contains("Test Skill")); } @@ -339,12 +501,4 @@ This is the content. ); assert_eq!(extract_description("# Title"), "No description"); } - - #[test] - fn test_load_skills_from_empty_dir() { - let loader = SkillsLoader::new(); - let temp_dir = tempfile::tempdir().unwrap(); - let skills = loader.load_skills_from_dir(temp_dir.path()); - assert!(skills.is_empty()); - } } diff --git a/src/tools/get_skill.rs b/src/tools/get_skill.rs new file mode 100644 index 0000000..8f071d3 --- /dev/null +++ b/src/tools/get_skill.rs @@ -0,0 +1,159 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use serde_json::json; + +use crate::skills::{Skill, SkillsLoader}; +use crate::tools::traits::{Tool, ToolResult}; + +pub struct GetSkillTool { + skills_loader: Arc, +} + +impl GetSkillTool { + pub fn new(skills_loader: Arc) -> Self { + Self { skills_loader } + } + + fn format_skill(&self, skill: &Skill) -> String { + let mut result = format!("# Skill: {}\n\n{}", skill.name, skill.description); + + if let Some(path) = &skill.path { + result.push_str(&format!( + "\n\n**Skill Root Directory:** `{}`\n\nAll files and references in this skill are relative to this directory.", + path.to_string_lossy() + )); + } + + result.push_str(&format!("\n\n---\n\n{}", skill.content)); + result + } +} + +#[async_trait] +impl Tool for GetSkillTool { + fn name(&self) -> &str { + "get_skill" + } + + fn description(&self) -> &str { + "Get complete content and guidance for a specified skill. Use this when you need detailed instructions for a specific type of task." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "skill_name": { + "type": "string", + "description": "Name of the skill to retrieve" + } + }, + "required": ["skill_name"] + }) + } + + fn read_only(&self) -> bool { + true + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let skill_name = match args.get("skill_name").and_then(|v| v.as_str()) { + Some(name) => name, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing required parameter: skill_name".to_string()), + }); + } + }; + + match self.skills_loader.get_skill(skill_name) { + Some(skill) => { + let formatted = self.format_skill(&skill); + Ok(ToolResult { + success: true, + output: formatted, + error: None, + }) + } + None => { + let available = self.skills_loader.list_skills(); + let available_str = if available.is_empty() { + "No skills available".to_string() + } else { + available + .iter() + .map(|(name, _)| name.as_str()) + .collect::>() + .join(", ") + }; + Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Skill '{}' not found. Available skills: {}", + skill_name, available_str + )), + }) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + use std::fs::File; + use std::io::Write; + use std::path::PathBuf; + + #[tokio::test] + async fn test_get_existing_skill() { + let temp_dir = tempdir().unwrap(); + + let skill_dir = temp_dir.path().join("test-skill"); + std::fs::create_dir(&skill_dir).unwrap(); + + let mut skill_file = File::create(skill_dir.join("SKILL.md")).unwrap(); + writeln!(skill_file, "---").unwrap(); + writeln!(skill_file, "name: test-skill").unwrap(); + writeln!(skill_file, "description: A test skill").unwrap(); + writeln!(skill_file, "---").unwrap(); + writeln!(skill_file, "# Test Skill").unwrap(); + writeln!(skill_file, "This is the test content.").unwrap(); + + let mut loader = SkillsLoader::new_for_testing( + temp_dir.path().to_path_buf(), + PathBuf::from("/nonexistent"), + ); + loader.load_skills(); + + let tool = GetSkillTool::new(Arc::new(loader)); + + let result = tool + .execute(json!({ "skill_name": "test-skill" })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("test-skill")); + assert!(result.output.contains("test content")); + } + + #[tokio::test] + async fn test_get_nonexistent_skill() { + let loader = SkillsLoader::new(); + let tool = GetSkillTool::new(Arc::new(loader)); + + let result = tool + .execute(json!({ "skill_name": "nonexistent" })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.is_some()); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 035c8f7..ab1c361 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -3,6 +3,7 @@ pub mod calculator; pub mod file_edit; pub mod file_read; pub mod file_write; +pub mod get_skill; pub mod http_request; pub mod registry; pub mod schema; @@ -14,6 +15,7 @@ pub use calculator::CalculatorTool; pub use file_edit::FileEditTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; +pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr};