系统提示词精简,拆分固定部分与可变部分,固化工具skill顺序,提升缓存命中率

This commit is contained in:
xiaoxixi 2026-06-17 23:25:28 +08:00
parent f08bf85b37
commit fdd7f47305
9 changed files with 432 additions and 235 deletions

View File

@ -427,14 +427,8 @@ impl AgentLoop {
// Build and inject system prompt if not present
let has_system = messages.first().is_some_and(|m| m.role == "system");
if !has_system {
let system_prompt = build_system_prompt(
&self.workspace_dir,
&self.model_name,
&self.tools,
None,
None,
false,
);
let system_prompt =
build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools);
#[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt));

View File

@ -719,6 +719,9 @@ mod tests {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cached_tokens: None,
cache_read_input_tokens: None,
cache_creation_input_tokens: None,
},
})
}

View File

@ -16,11 +16,6 @@ pub struct PromptContext<'a> {
pub workspace_dir: &'a Path,
pub model_name: &'a str,
pub tools: &'a ToolRegistry,
pub session_id: Option<&'a str>,
/// Pre-fetched memory context string to inject.
pub memory_context: Option<&'a str>,
/// Whether this session has compressed history available via timeline_recall.
pub has_compressed_history: bool,
}
/// Trait for system prompt sections.
@ -43,14 +38,14 @@ impl SystemPromptBuilder {
Box::new(AgentProfileSection),
Box::new(UserProfileSection),
Box::new(RuntimeSection),
Box::new(DateTimeSection),
Box::new(WorkspaceSection),
Box::new(YourTaskSection),
Box::new(DecisionOrderSection),
Box::new(ToolHonestySection),
Box::new(ToolUsageSection),
Box::new(SafetySection),
Box::new(CrossChannelSection),
Box::new(MemorySection),
Box::new(HistorySection),
Box::new(DelegationSection),
],
}
@ -72,7 +67,6 @@ impl SystemPromptBuilder {
Box::new(SafetySection),
Box::new(SubAgentToolsSection { http_get_only }),
Box::new(WorkspaceSection),
Box::new(DateTimeSection),
];
if let Some(sp) = skills_prompt {
sections.push(Box::new(SubAgentSkillsSection { skills_prompt: sp }));
@ -114,9 +108,27 @@ impl PromptSection for ToolHonestySection {
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 关键规则:工具诚实性
- \"没有找到结果\"
-
- "
-
- \"没有找到结果\";如果工具失败,直接报告错误。
- "
.to_string()
}
}
/// Tool calls should stay invisible to the user.
pub struct ToolUsageSection;
impl PromptSection for ToolUsageSection {
fn name(&self) -> &str {
"tool_usage"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 工具使用方式
-
- 使
- "
.to_string()
}
}
@ -132,10 +144,29 @@ impl PromptSection for YourTaskSection {
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 你的任务
使skill来完成目标
使
"
使 skill
-
-
- "
.to_string()
}
}
/// Explicit decision order for real user scenarios.
pub struct DecisionOrderSection;
impl PromptSection for DecisionOrderSection {
fn name(&self) -> &str {
"decision_order"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 决策顺序
1.
2. 使
3. //"
.to_string()
}
}
@ -229,24 +260,6 @@ impl PromptSection for AgentProfileSection {
}
}
/// Current date and time.
pub struct DateTimeSection;
impl PromptSection for DateTimeSection {
fn name(&self) -> &str {
"datetime"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
let now = chrono::Local::now();
format!(
"## 当前日期与时间\n\n{} ({})",
now.format("%Y-%m-%d %H:%M:%S"),
now.format("%Z")
)
}
}
/// Cross-channel messaging and system notification guidance for LLM.
pub struct CrossChannelSection;
@ -255,49 +268,14 @@ impl PromptSection for CrossChannelSection {
"cross_channel"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
let session_line = if let Some(id) = ctx.session_id {
format!("当前会话的 ID 是 `{}`。\n", id)
} else {
String::new()
};
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 关于会话和跨渠道消息
format!(
r#"## 关于会话和跨渠道消息
### ID
session ID<channel>:<chat_id>:<dialog_id>
- channel: "cli_chat""feishu"
- chat_id: /
- dialog_id: chat dialog
{}###
`[message from X]` assistant
send_message
- X: ID "unknown"
### send_message
- target_chat_id: <channel>:<chat_id> <channel>:<chat_id>:<dialog_id>
- content:
### chat_manager
- action = "list_sessions" offset/count
- action = "list_channels"
- action = "list_messages" session
- session_id (): ID
- count (): 20 100
- offset (): N 0
- before_time (): Unix
- after_time (): Unix
使 list_messages offset "#,
session_line
)
- `[message from X]`
- 使 `send_message``target_chat_id` `<channel>:<chat_id>` `<channel>:<chat_id>:<dialog_id>`
- 使 `chat_manager`
- `chat_manager` `list_messages` "
.to_string()
}
}
@ -310,13 +288,8 @@ impl PromptSection for RuntimeSection {
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
let host = hostname::get()
.map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "unknown".to_string());
format!(
"## 运行环境\n\n主机: {} | 操作系统: {} | 模型: {}",
host,
std::env::consts::OS,
"## 运行环境\n\n使用的模型是 `{}`。所有文件操作都应默认针对当前工作目录。",
ctx.model_name
)
}
@ -330,47 +303,16 @@ impl PromptSection for MemorySection {
"memory"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
fn build(&self, _ctx: &PromptContext<'_>) -> String {
let guide = r#"## 记忆系统
###
- **Knowledge**
- **Timeline线** timeline_recall
###
- **memory_recall** query
- **timeline_recall** session_id
###
使 `memory_store`
-
- 使
-
"#;
match ctx.memory_context {
Some(context) if !context.is_empty() => {
format!("{}\n\n### 记忆上下文\n\n{}", guide, context)
}
_ => guide.to_string(),
}
}
}
/// Prompt agent to use timeline_recall if compressed history exists.
pub struct HistorySection;
impl PromptSection for HistorySection {
fn name(&self) -> &str {
"history"
}
fn build(&self, ctx: &PromptContext<'_>) -> String {
if ctx.has_compressed_history {
"## 历史会话\n之前的对话摘要已归档。如需回顾历史上下文,使用 `timeline_recall` 工具搜索。".to_string()
} else {
String::new()
}
- **Knowledge**
- **Timeline线** `timeline_recall`
- **memory_recall**
- **timeline_recall**
-
- "#;
guide.to_string()
}
}
@ -384,37 +326,12 @@ impl PromptSection for DelegationSection {
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## 子 Agent 委托原则\n\n\
使 delegate Agent\n\
\n\
### \n\
- 使 mode=\"parallel\"\n\
- 使 mode=\"background\"\n\
- \n\
\n\
### \n\
- **** Agent \n\
- **** file_readfile_searchweb_fetch bashfile_writefile_edit\n\
- **** delegate Agent\n\
- **** Agent \n\
\n\
### Skill \n\
- skill allowed_tools get_skill\n\
- prompt Agent 使 get_skill \n\
- \"使用 get_skill action='get' skill_name='pdf' 加载 PDF 处理技能后完成任务\"\n\
\n\
### \n\
- prompt \n\
- prompt \"跳过 .tmp 文件\"\n\
- \n\
\n\
### \n\
- 使 mode=\"parallel\",任务定义在 tasks 数组中\n\
- \n\
- 5 \n\
\n\
### \n\
- 30s 使 mode=\"background\"\n\
- ".to_string()
- \n\
- Agent \n\
- delegate Agent\n\
- prompt \n\
- background"
.to_string()
}
}
@ -434,16 +351,16 @@ impl PromptSection for SubAgentIdentitySection {
fn build(&self, _ctx: &PromptContext<'_>) -> String {
format!(
"## 子 Agent\n\n\
Agent Agent Agent\n\
Agent\n\
\n\
## \n\n\
{}\n\
\n\
## \n\
- \n\
- \n\
- 使\n\
- 使 delegate \n\
- \n\
- 使 delegate \n\
- \n\
- \n\
- {}",
self.task, self.timeout,
@ -525,25 +442,41 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
}
/// Build a complete system prompt with default configuration.
pub fn build_system_prompt(
workspace_dir: &Path,
model_name: &str,
tools: &ToolRegistry,
session_id: Option<&str>,
memory_context: Option<&str>,
has_compressed_history: bool,
) -> String {
pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry) -> String {
let ctx = PromptContext {
workspace_dir,
model_name,
tools,
session_id,
memory_context,
has_compressed_history,
};
SystemPromptBuilder::with_defaults().build(&ctx)
}
/// Build a runtime context tail that should be appended to the latest user message.
pub fn build_runtime_context(session_id: Option<&str>, memory_context: Option<&str>) -> String {
let mut sections = Vec::new();
let now = chrono::Local::now();
sections.push(format!(
"## 运行时上下文\n\n- 当前日期与时间: {} ({})",
now.format("%Y-%m-%d %H:%M:%S"),
now.format("%Z")
));
if let Some(id) = session_id {
sections.push(format!("- 会话 ID: `{}`", id));
}
if let Some(context) = memory_context.filter(|s| !s.trim().is_empty()) {
sections.push(format!("### 记忆上下文\n\n{}", context));
}
if sections.is_empty() {
String::new()
} else {
sections.join("\n")
}
}
/// Build a system prompt for a sub-agent with all relevant operational sections.
pub fn build_sub_agent_system_prompt(
task: &str,
@ -558,9 +491,6 @@ pub fn build_sub_agent_system_prompt(
workspace_dir,
model_name,
tools,
session_id: None,
memory_context: None,
has_compressed_history: false,
};
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
.build(&ctx)
@ -579,9 +509,6 @@ mod tests {
workspace_dir: &temp_dir,
model_name: "test-model",
tools: &tools,
session_id: None,
memory_context: None,
has_compressed_history: false,
};
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
@ -589,7 +516,6 @@ mod tests {
assert!(prompt.contains("## 关键规则:工具诚实性"));
assert!(prompt.contains("## 安全规则"));
assert!(prompt.contains("## 工作目录"));
assert!(prompt.contains("## 当前日期与时间"));
assert!(prompt.contains("## 运行环境"));
}
@ -611,46 +537,58 @@ mod tests {
let temp_dir = std::env::temp_dir();
let tools = ToolRegistry::new();
let prompt = build_system_prompt(&temp_dir, "test-model", &tools, None, None, false);
let prompt = build_system_prompt(&temp_dir, "test-model", &tools);
assert!(!prompt.is_empty());
assert!(prompt.contains("test-model"));
}
#[test]
fn test_memory_section_with_context() {
fn test_prompt_contains_decision_order_section() {
let temp_dir = std::env::temp_dir();
let tools = ToolRegistry::new();
let ctx = PromptContext {
workspace_dir: &temp_dir,
model_name: "test",
tools: &tools,
session_id: None,
memory_context: Some("- user_pref: Prefers Rust"),
has_compressed_history: false,
};
let prompt = build_system_prompt(&temp_dir, "test-model", &tools);
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
assert!(prompt.contains("## 记忆上下文"));
assert!(prompt.contains("## 决策顺序"));
assert!(prompt.contains("直接回答"));
assert!(prompt.contains("使用工具"));
assert!(prompt.contains("追问用户"));
}
#[test]
fn test_build_system_prompt_is_stable_across_calls() {
let temp_dir = std::env::temp_dir();
let tools = ToolRegistry::new();
let prompt_a = build_system_prompt(&temp_dir, "test-model", &tools);
let prompt_b = build_system_prompt(&temp_dir, "test-model", &tools);
assert_eq!(prompt_a, prompt_b);
}
#[test]
fn test_runtime_context_with_memory() {
let temp_dir = std::env::temp_dir();
let tools = ToolRegistry::new();
let _ = (temp_dir, tools);
let prompt = build_runtime_context(Some("session-123"), Some("- user_pref: Prefers Rust"));
assert!(prompt.contains("## 运行时上下文"));
assert!(prompt.contains("session-123"));
assert!(prompt.contains("Prefers Rust"));
}
#[test]
fn test_memory_section_without_context() {
fn test_runtime_context_without_memory() {
let temp_dir = std::env::temp_dir();
let tools = ToolRegistry::new();
let ctx = PromptContext {
workspace_dir: &temp_dir,
model_name: "test",
tools: &tools,
session_id: None,
memory_context: None,
has_compressed_history: false,
};
let _ = (temp_dir, tools);
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
assert!(!prompt.contains("## 记忆上下文"));
let prompt = build_runtime_context(None, None);
assert!(prompt.contains("## 运行时上下文"));
assert!(prompt.contains("当前日期与时间"));
}
}

View File

@ -12,12 +12,34 @@ use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
#[derive(Serialize)]
struct CacheControl {
#[serde(rename = "type")]
cache_type: String,
}
impl CacheControl {
fn ephemeral() -> Self {
Self {
cache_type: "ephemeral".to_string(),
}
}
}
fn convert_content_blocks(blocks: &[ContentBlock], cacheable: bool) -> Vec<serde_json::Value> {
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
if cacheable {
serde_json::json!({
"type": "text",
"text": text,
"cache_control": CacheControl::ephemeral(),
})
} else {
serde_json::json!({ "type": "text", "text": text })
}
}
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
})
@ -120,6 +142,8 @@ struct AnthropicTool {
name: String,
description: String,
input_schema: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
}
#[derive(Deserialize)]
@ -161,6 +185,10 @@ struct AnthropicUsage {
input_tokens: u32,
#[serde(default)]
output_tokens: u32,
#[serde(default)]
cache_read_input_tokens: Option<u32>,
#[serde(default)]
cache_creation_input_tokens: Option<u32>,
}
#[async_trait]
@ -180,6 +208,7 @@ impl LLMProvider for AnthropicProvider {
name: t.function.name.clone(),
description: t.function.description.clone(),
input_schema: t.function.parameters.clone(),
cache_control: Some(CacheControl::ephemeral()),
})
.collect()
});
@ -213,7 +242,7 @@ impl LLMProvider for AnthropicProvider {
"content": output,
})]
} else {
let mut blocks = convert_content_blocks(&m.content);
let mut blocks = convert_content_blocks(&m.content, m.role == "system");
// Append tool_use blocks from assistant messages with tool calls
if let Some(tool_calls) = m.tool_calls.as_ref().filter(|c| !c.is_empty()) {
for tc in tool_calls {
@ -369,6 +398,18 @@ impl LLMProvider for AnthropicProvider {
.as_ref()
.map(|u| u.input_tokens + u.output_tokens)
.unwrap_or(0),
cached_tokens: anthropic_resp
.usage
.as_ref()
.and_then(|u| u.cache_read_input_tokens),
cache_read_input_tokens: anthropic_resp
.usage
.as_ref()
.and_then(|u| u.cache_read_input_tokens),
cache_creation_input_tokens: anthropic_resp
.usage
.as_ref()
.and_then(|u| u.cache_creation_input_tokens),
},
};
@ -400,3 +441,39 @@ impl LLMProvider for AnthropicProvider {
&self.model_id
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_convert_content_blocks_adds_cache_control_for_system_text() {
let blocks = vec![ContentBlock::text("hello")];
let serialized = convert_content_blocks(&blocks, true);
assert_eq!(serialized[0]["type"], "text");
assert_eq!(serialized[0]["cache_control"]["type"], "ephemeral");
}
#[test]
fn test_convert_content_blocks_leaves_user_text_uncached() {
let blocks = vec![ContentBlock::text("hello")];
let serialized = convert_content_blocks(&blocks, false);
assert!(serialized[0].get("cache_control").is_none());
}
#[test]
fn test_anthropic_tool_serializes_cache_control() {
let tool = AnthropicTool {
name: "alpha".to_string(),
description: "desc".to_string(),
input_schema: json!({}),
cache_control: Some(CacheControl::ephemeral()),
};
let value = serde_json::to_value(tool).unwrap();
assert_eq!(value["cache_control"]["type"], "ephemeral");
}
}

View File

@ -188,6 +188,16 @@ struct OpenAIUsage {
completion_tokens: u32,
#[serde(default)]
total_tokens: u32,
#[serde(default)]
cached_tokens: Option<u32>,
#[serde(default)]
prompt_tokens_details: Option<OpenAIPromptTokensDetails>,
}
#[derive(Deserialize, Default)]
struct OpenAIPromptTokensDetails {
#[serde(default)]
cached_tokens: Option<u32>,
}
#[async_trait]
@ -332,6 +342,12 @@ impl LLMProvider for OpenAIProvider {
})
.collect();
let usage = openai_resp.usage;
let nested_cached_tokens = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
let cached_tokens = nested_cached_tokens.or(usage.cached_tokens);
let response = ChatCompletionResponse {
id: openai_resp.id,
model: openai_resp.model,
@ -339,9 +355,12 @@ impl LLMProvider for OpenAIProvider {
reasoning_content: first_choice.message.reasoning_content,
tool_calls,
usage: Usage {
prompt_tokens: openai_resp.usage.prompt_tokens,
completion_tokens: openai_resp.usage.completion_tokens,
total_tokens: openai_resp.usage.total_tokens,
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
cached_tokens: cached_tokens,
cache_read_input_tokens: None,
cache_creation_input_tokens: None,
},
};
@ -463,4 +482,39 @@ mod tests {
assert!(message.tool_calls.is_empty());
assert_eq!(response.usage.total_tokens, 11806);
}
#[test]
fn test_decode_response_exposes_cached_tokens() {
let text = r#"{
"id": "d21abaa6552741949e2aba76bde59359",
"choices": [{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "你好!",
"role": "assistant",
"tool_calls": null
}
}],
"created": 1781622889,
"model": "mimo-v2.5",
"object": "chat.completion",
"usage": {
"completion_tokens": 65,
"prompt_tokens": 11741,
"total_tokens": 11806,
"prompt_tokens_details": {"cached_tokens": 1200}
}
}"#;
let response: OpenAIResponse = serde_json::from_str(text).unwrap();
assert_eq!(
response
.usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens),
Some(1200)
);
}
}

View File

@ -121,6 +121,12 @@ pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_read_input_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_creation_input_tokens: Option<u32>,
}
#[async_trait]

View File

@ -29,7 +29,7 @@ pub enum HandleResult {
AgentProcessing,
}
use crate::agent::context_compressor::ContextCompressionConfig;
use crate::agent::system_prompt::build_system_prompt;
use crate::agent::system_prompt::{build_runtime_context, build_system_prompt};
use crate::agent::{AgentError, AgentLoop, ContextCompressor};
use crate::channels::slash_command::parse_slash_command;
use crate::config::BrowserConfig;
@ -472,6 +472,18 @@ impl Session {
}
}
fn append_runtime_context_to_user_message(message: &mut ChatMessage, runtime_context: &str) {
if runtime_context.trim().is_empty() {
return;
}
if message.content.trim().is_empty() {
message.content = runtime_context.to_string();
} else {
message.content = format!("{}\n\n{}", message.content, runtime_context);
}
}
pub fn create_user_message_with_source(
&self,
content: &str,
@ -615,14 +627,11 @@ impl Session {
}
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills + memory
pub fn build_system_prompt(&self, skills_prompt: &str, memory_context: Option<&str>) -> String {
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
let base_prompt = build_system_prompt(
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
&self.tools,
Some(&self.id.to_string()),
memory_context,
self.last_compressed_message_at.is_some(),
);
if skills_prompt.trim().is_empty() {
@ -1266,7 +1275,7 @@ impl SessionManager {
// Build the same system prompt that would be injected to the model
let skills_prompt = self.skills_loader.build_skills_prompt();
let system_prompt = session_guard.build_system_prompt(&skills_prompt, None);
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
let filepath = session_guard
.dump_to_file(&system_prompt)
@ -1989,8 +1998,7 @@ fn spawn_agent_worker(
let media_refs: Vec<MediaRef> =
task.media.iter().map(|m| m.to_media_ref()).collect();
let user_message =
guard.create_user_message(&task.content, media_refs);
let user_message = guard.create_user_message(&task.content, media_refs);
let user_persist = guard.add_message_in_memory(user_message, true);
drop(guard);
if let Err(e) = persist_added_message(user_persist).await {
@ -2071,12 +2079,15 @@ fn spawn_agent_worker(
_ => None,
};
let runtime_context =
build_runtime_context(Some(unified_str.as_str()), memory_context.as_deref());
let system_prompt_out = {
let guard = session.lock().await;
if guard.worker_generation != worker_gen {
return;
}
guard.build_system_prompt(&skills_prompt, memory_context.as_deref())
guard.build_system_prompt(&skills_prompt)
};
let compression_result = compressor.compress_if_needed(history_raw).await;
@ -2119,6 +2130,9 @@ fn spawn_agent_worker(
}
};
history_out.insert(0, ChatMessage::system(system_prompt_out.clone()));
if let Some(last_msg) = history_out.iter_mut().rev().find(|m| m.role == "user") {
Session::append_runtime_context_to_user_message(last_msg, &runtime_context);
}
// Phase 2 + 3: LLM call with cancellation
let session2 = session.clone();
@ -2205,6 +2219,13 @@ fn spawn_agent_worker(
0,
ChatMessage::system(system_prompt_out.clone()),
);
if let Some(last_msg) = retry.iter_mut().rev().find(|m| m.role == "user")
{
Session::append_runtime_context_to_user_message(
last_msg,
&runtime_context,
);
}
retry
};
@ -2312,9 +2333,6 @@ impl SessionManager {
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
&self.tools,
Some(&format!("cron:{}:{}", job_name, job_id)),
None,
false,
);
let cron_context = format!(
"## 定时任务执行\n\n\

View File

@ -252,19 +252,21 @@ impl SkillsLoader {
pub fn get_loaded_skills(&self) -> Vec<Skill> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state.loaded_skills.clone()
Self::sort_skills(state.loaded_skills.clone())
}
/// Get skills marked as always (checks for changes first)
pub fn get_always_skills(&self) -> Vec<Skill> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state
.loaded_skills
.iter()
.filter(|s| s.always)
.cloned()
.collect()
Self::sort_skills(
state
.loaded_skills
.iter()
.filter(|s| s.always)
.cloned()
.collect(),
)
}
/// Get a specific skill by name (checks for changes first)
@ -278,9 +280,8 @@ impl SkillsLoader {
pub fn list_skills(&self) -> Vec<(String, String)> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state
.loaded_skills
.iter()
Self::sort_skills(state.loaded_skills.clone())
.into_iter()
.map(|s| (s.name.clone(), s.description.clone()))
.collect()
}
@ -294,6 +295,7 @@ impl SkillsLoader {
return String::new();
}
let loaded_skills = Self::sort_skills(state.loaded_skills.clone());
let mut prompt = String::from("## Skills\n\n");
// Directory conventions
@ -308,7 +310,7 @@ impl SkillsLoader {
);
// Always skills summary
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
let always_skills: Vec<_> = loaded_skills.iter().filter(|s| s.always).collect();
if !always_skills.is_empty() {
prompt.push_str("### 常用技能\n\n");
for skill in &always_skills {
@ -348,6 +350,22 @@ impl SkillsLoader {
prompt
}
fn sort_skills(mut skills: Vec<Skill>) -> Vec<Skill> {
skills.sort_by(|a, b| {
b.always
.cmp(&a.always)
.then_with(|| a.name.cmp(&b.name))
.then_with(|| a.description.cmp(&b.description))
.then_with(|| {
a.path
.as_ref()
.map(|p| p.to_string_lossy().to_string())
.cmp(&b.path.as_ref().map(|p| p.to_string_lossy().to_string()))
})
});
skills
}
/// Load skills from a specific directory
fn load_skills_from_dir(&self, dir: &Path) -> Vec<Skill> {
let mut skills = Vec::new();
@ -529,4 +547,43 @@ This is the content.
);
assert_eq!(extract_description("# Title"), "No description");
}
#[test]
fn test_build_skills_prompt_is_sorted() {
let loader = SkillsLoader::new_for_testing(
PathBuf::from("/tmp/picobot"),
PathBuf::from("/tmp/agents"),
);
{
let mut state = loader.state.lock().unwrap();
state.loaded_skills = vec![
Skill {
name: "zeta".to_string(),
description: "Z".to_string(),
content: "Z".to_string(),
always: false,
path: None,
},
Skill {
name: "beta".to_string(),
description: "B".to_string(),
content: "B".to_string(),
always: true,
path: None,
},
Skill {
name: "alpha".to_string(),
description: "A".to_string(),
content: "A".to_string(),
always: true,
path: None,
},
];
}
let prompt = loader.build_skills_prompt();
let alpha_pos = prompt.find("**alpha**").unwrap();
let beta_pos = prompt.find("**beta**").unwrap();
assert!(alpha_pos < beta_pos);
}
}

View File

@ -39,7 +39,8 @@ impl ToolRegistry {
}
pub fn get_definitions(&self) -> Vec<Tool> {
self.tools
let mut defs: Vec<Tool> = self
.tools
.lock()
.unwrap()
.values()
@ -51,7 +52,10 @@ impl ToolRegistry {
parameters: tool.parameters_schema(),
},
})
.collect()
.collect();
defs.sort_by(|a, b| a.function.name.cmp(&b.function.name));
defs
}
pub fn has_tools(&self) -> bool {
@ -88,3 +92,49 @@ impl Default for ToolRegistry {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tools::traits::ToolResult;
use async_trait::async_trait;
use serde_json::json;
struct TestTool(&'static str);
#[async_trait]
impl ToolTrait for TestTool {
fn name(&self) -> &str {
self.0
}
fn description(&self) -> &str {
self.0
}
fn parameters_schema(&self) -> serde_json::Value {
json!({})
}
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
Ok(ToolResult {
success: true,
output: "ok".to_string(),
error: None,
})
}
}
#[test]
fn test_get_definitions_sorted_by_name() {
let registry = ToolRegistry::new();
registry.register(TestTool("zeta"));
registry.register(TestTool("alpha"));
registry.register(TestTool("beta"));
let defs = registry.get_definitions();
let names: Vec<_> = defs.into_iter().map(|tool| tool.function.name).collect();
assert_eq!(names, vec!["alpha", "beta", "zeta"]);
}
}