refactor: remove unused functions and improve path resolution in tools
This commit is contained in:
parent
5aec8cefb9
commit
5ef89cd667
@ -10,7 +10,6 @@
|
|||||||
//! - USER.md — user preferences and profile
|
//! - USER.md — user preferences and profile
|
||||||
|
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
use std::fmt::Write;
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
/// Maximum characters per injected workspace file.
|
/// Maximum characters per injected workspace file.
|
||||||
@ -101,27 +100,6 @@ impl PromptSection for ToolHonestySection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List of available tools.
|
|
||||||
pub struct ToolsSection;
|
|
||||||
|
|
||||||
impl PromptSection for ToolsSection {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"tools"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build(&self, ctx: &PromptContext<'_>) -> String {
|
|
||||||
if !ctx.tools.has_tools() {
|
|
||||||
return String::new();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n");
|
|
||||||
for (name, tool) in ctx.tools.iter() {
|
|
||||||
let _ = writeln!(output, "- **{}**: {}", name, tool.description());
|
|
||||||
}
|
|
||||||
output
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Instructions for the task.
|
/// Instructions for the task.
|
||||||
pub struct YourTaskSection;
|
pub struct YourTaskSection;
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage};
|
||||||
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId};
|
||||||
@ -9,11 +8,6 @@ use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo};
|
|||||||
|
|
||||||
use super::base::{Channel, ChannelError};
|
use super::base::{Channel, ChannelError};
|
||||||
|
|
||||||
/// Generate a short ID (8 characters) from a UUID
|
|
||||||
fn short_id() -> String {
|
|
||||||
Uuid::new_v4().to_string()[..8].to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Client - Connected CLI client
|
// Client - Connected CLI client
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@ -49,7 +43,7 @@ impl CliChatChannel {
|
|||||||
/// Register a new client connection, returns (session_id, client)
|
/// Register a new client connection, returns (session_id, client)
|
||||||
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
|
pub(crate) async fn register_client(&self, sender: mpsc::Sender<WsOutbound>) -> (String, Arc<Client>) {
|
||||||
// Generate connection ID (used as chat_id) - use short ID
|
// Generate connection ID (used as chat_id) - use short ID
|
||||||
let connection_id = short_id();
|
let connection_id = crate::util::short_id();
|
||||||
|
|
||||||
let client = Arc::new(Client {
|
let client = Arc::new(Client {
|
||||||
sender,
|
sender,
|
||||||
@ -122,7 +116,7 @@ impl CliChatChannel {
|
|||||||
let msg = InboundMessage {
|
let msg = InboundMessage {
|
||||||
channel: self.name().to_string(),
|
channel: self.name().to_string(),
|
||||||
sender_id: "cli".to_string(),
|
sender_id: "cli".to_string(),
|
||||||
chat_id: chat_id.unwrap_or_else(short_id),
|
chat_id: chat_id.unwrap_or_else(crate::util::short_id),
|
||||||
content,
|
content,
|
||||||
timestamp: crate::bus::message::current_timestamp(),
|
timestamp: crate::bus::message::current_timestamp(),
|
||||||
media: Vec::new(),
|
media: Vec::new(),
|
||||||
@ -166,7 +160,7 @@ impl CliChatChannel {
|
|||||||
WsInbound::CreateSession { title } => {
|
WsInbound::CreateSession { title } => {
|
||||||
// Use current session's chat_id if available, otherwise generate new one
|
// Use current session's chat_id if available, otherwise generate new one
|
||||||
let chat_id = current_session_guard.clone()
|
let chat_id = current_session_guard.clone()
|
||||||
.unwrap_or_else(short_id);
|
.unwrap_or_else(crate::util::short_id);
|
||||||
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
|
let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?;
|
||||||
*current_session_guard = Some(new_id.clone());
|
*current_session_guard = Some(new_id.clone());
|
||||||
let _ = client
|
let _ = client
|
||||||
@ -491,7 +485,7 @@ impl Channel for CliChatChannel {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
WsOutbound::AssistantResponse {
|
WsOutbound::AssistantResponse {
|
||||||
id: short_id(),
|
id: crate::util::short_id(),
|
||||||
content: msg.content.clone(),
|
content: msg.content.clone(),
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,3 +14,4 @@ pub mod scheduler;
|
|||||||
pub mod skills;
|
pub mod skills;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
pub mod util;
|
||||||
|
|||||||
@ -133,8 +133,6 @@ struct OpenAIMessage {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
name: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
tool_calls: Vec<OpenAIToolCall>,
|
tool_calls: Vec<OpenAIToolCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,8 +141,6 @@ struct OpenAIToolCall {
|
|||||||
id: String,
|
id: String,
|
||||||
#[serde(rename = "function")]
|
#[serde(rename = "function")]
|
||||||
function: OAIFunction,
|
function: OAIFunction,
|
||||||
#[serde(default)]
|
|
||||||
index: Option<u32>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -219,11 +215,13 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
let error = format!("API error {}: {}", status, text);
|
let error = format!("API error {}: {}", status, text);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let _ = storage.append_llm_call(
|
if let Err(e) = storage.append_llm_call(
|
||||||
&self.name, &self.model_id, &req_body_str,
|
&self.name, &self.model_id, &req_body_str,
|
||||||
Some(&text), Some(&error),
|
Some(&text), Some(&error),
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
).await;
|
).await {
|
||||||
|
tracing::warn!("failed to persist LLM call: {}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Err(error.into());
|
return Err(error.into());
|
||||||
}
|
}
|
||||||
@ -240,20 +238,25 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let err = err_msg.clone();
|
let err = err_msg.clone();
|
||||||
let s = storage.clone();
|
let s = storage.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await;
|
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
|
||||||
|
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
err_msg
|
err_msg
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let content = openai_resp.choices[0]
|
let first_choice = openai_resp.choices.into_iter().next()
|
||||||
|
.ok_or("no choices in response")?;
|
||||||
|
|
||||||
|
let content = first_choice
|
||||||
.message
|
.message
|
||||||
.content
|
.content
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap_or(&String::new())
|
.unwrap_or(&String::new())
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
let tool_calls: Vec<ToolCall> = openai_resp.choices[0]
|
let tool_calls: Vec<ToolCall> = first_choice
|
||||||
.message
|
.message
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.iter()
|
.iter()
|
||||||
@ -277,11 +280,13 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let _ = storage.append_llm_call(
|
if let Err(e) = storage.append_llm_call(
|
||||||
&self.name, &self.model_id, &req_body_str,
|
&self.name, &self.model_id, &req_body_str,
|
||||||
Some(&text), None,
|
Some(&text), None,
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
).await;
|
).await {
|
||||||
|
tracing::warn!("failed to persist LLM call: {}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
|
|||||||
@ -144,7 +144,9 @@ impl Scheduler {
|
|||||||
media: vec![],
|
media: vec![],
|
||||||
metadata: std::collections::HashMap::new(),
|
metadata: std::collections::HashMap::new(),
|
||||||
};
|
};
|
||||||
let _ = self.bus.publish_outbound(outbound).await;
|
if let Err(e) = self.bus.publish_outbound(outbound).await {
|
||||||
|
tracing::warn!(job_id = %job.id, "scheduler: failed to publish outbound: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
let output_truncated = if output.len() > 8000 {
|
let output_truncated = if output.len() > 8000 {
|
||||||
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
||||||
@ -186,7 +188,9 @@ impl Scheduler {
|
|||||||
media: vec![],
|
media: vec![],
|
||||||
metadata: std::collections::HashMap::new(),
|
metadata: std::collections::HashMap::new(),
|
||||||
};
|
};
|
||||||
let _ = self.bus.publish_outbound(outbound).await;
|
if let Err(e) = self.bus.publish_outbound(outbound).await {
|
||||||
|
tracing::warn!(job_id = %job.id, "scheduler: failed to publish outbound: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
let run = JobRun {
|
let run = JobRun {
|
||||||
id: 0,
|
id: 0,
|
||||||
|
|||||||
@ -2,7 +2,6 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
||||||
use crate::storage::{Storage, StorageError};
|
use crate::storage::{Storage, StorageError};
|
||||||
@ -41,11 +40,6 @@ use crate::bus::MessageBus;
|
|||||||
use crate::tools::OutboundMessenger;
|
use crate::tools::OutboundMessenger;
|
||||||
use crate::tools::SendMessageTool;
|
use crate::tools::SendMessageTool;
|
||||||
|
|
||||||
/// Generate a short ID (8 characters) from a UUID
|
|
||||||
fn short_id() -> String {
|
|
||||||
Uuid::new_v4().to_string()[..8].to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Session = 一个 dialog
|
/// Session = 一个 dialog
|
||||||
/// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history
|
/// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
@ -988,7 +982,7 @@ impl SessionManager {
|
|||||||
title: Option<&str>,
|
title: Option<&str>,
|
||||||
routing_info: String,
|
routing_info: String,
|
||||||
) -> Result<(UnifiedSessionId, String), AgentError> {
|
) -> Result<(UnifiedSessionId, String), AgentError> {
|
||||||
let dialog_id = short_id();
|
let dialog_id = crate::util::short_id();
|
||||||
let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id);
|
let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id);
|
||||||
let session_id_str = unified_id.to_string();
|
let session_id_str = unified_id.to_string();
|
||||||
|
|
||||||
|
|||||||
@ -233,55 +233,6 @@ impl SkillsLoader {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build XML summary of all skills (for progressive disclosure) (checks for changes first)
|
|
||||||
pub fn build_skills_summary(&self) -> String {
|
|
||||||
self.reload_if_changed();
|
|
||||||
let state = self.state.lock().unwrap();
|
|
||||||
|
|
||||||
if state.loaded_skills.is_empty() {
|
|
||||||
return String::new();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut lines = vec!["<skills>".to_string()];
|
|
||||||
|
|
||||||
for skill in &state.loaded_skills {
|
|
||||||
if skill.always {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
lines.push(" <skill>".to_string());
|
|
||||||
lines.push(format!(" <name>{}</name>", escape_xml(&skill.name)));
|
|
||||||
lines.push(format!(
|
|
||||||
" <description>{}</description>",
|
|
||||||
escape_xml(&skill.description)
|
|
||||||
));
|
|
||||||
if let Some(path) = &skill.path {
|
|
||||||
lines.push(format!(" <path>{}</path>", escape_xml(&path.to_string_lossy())));
|
|
||||||
}
|
|
||||||
lines.push(" </skill>".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
lines.push("</skills>".to_string());
|
|
||||||
lines.join("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build prompt for always-injected skills (checks for changes first)
|
|
||||||
pub fn build_always_skills_prompt(&self) -> String {
|
|
||||||
self.reload_if_changed();
|
|
||||||
let state = self.state.lock().unwrap();
|
|
||||||
|
|
||||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
|
||||||
if always_skills.is_empty() {
|
|
||||||
return String::new();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut parts = Vec::new();
|
|
||||||
for skill in always_skills {
|
|
||||||
parts.push(format!("## Skill: {}\n\n{}", skill.name, skill.content));
|
|
||||||
}
|
|
||||||
|
|
||||||
parts.join("\n\n---\n\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build full skills prompt: directory conventions, always-skill summary, always-skill content
|
/// Build full skills prompt: directory conventions, always-skill summary, always-skill content
|
||||||
pub fn build_skills_prompt(&self) -> String {
|
pub fn build_skills_prompt(&self) -> String {
|
||||||
self.reload_if_changed();
|
self.reload_if_changed();
|
||||||
@ -474,22 +425,6 @@ fn extract_description(content: &str) -> String {
|
|||||||
.unwrap_or_else(|| "No description".to_string())
|
.unwrap_or_else(|| "No description".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Escape XML special characters
|
|
||||||
fn escape_xml(s: &str) -> String {
|
|
||||||
let mut result = String::with_capacity(s.len());
|
|
||||||
for c in s.chars() {
|
|
||||||
match c {
|
|
||||||
'&' => result.push_str("&"),
|
|
||||||
'<' => result.push_str("<"),
|
|
||||||
'>' => result.push_str(">"),
|
|
||||||
'"' => result.push_str("""),
|
|
||||||
'\'' => result.push_str("'"),
|
|
||||||
_ => result.push(c),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -526,13 +461,6 @@ This is the content.
|
|||||||
assert!(body.contains("Test Skill"));
|
assert!(body.contains("Test Skill"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_escape_xml() {
|
|
||||||
assert_eq!(escape_xml("a & b"), "a & b");
|
|
||||||
assert_eq!(escape_xml("<tag>"), "<tag>");
|
|
||||||
assert_eq!(escape_xml("\"quote\""), ""quote"");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_extract_description() {
|
fn test_extract_description() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|||||||
@ -28,12 +28,12 @@ impl ContentSearchTool {
|
|||||||
|
|
||||||
fn truncate_output(&self, lines: &[String]) -> String {
|
fn truncate_output(&self, lines: &[String]) -> String {
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
for line in lines {
|
for (i, line) in lines.iter().enumerate() {
|
||||||
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
||||||
|
let omitted = lines.len() - i;
|
||||||
output.push_str(&format!(
|
output.push_str(&format!(
|
||||||
"\n... ({} chars truncated, {} matches omitted) ...",
|
"\n... ({} matches omitted) ...",
|
||||||
output.len(),
|
omitted
|
||||||
lines.len()
|
|
||||||
));
|
));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::path_utils;
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
pub struct FileEditTool {
|
pub struct FileEditTool {
|
||||||
@ -20,30 +19,6 @@ impl FileEditTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
|
|
||||||
let p = Path::new(path);
|
|
||||||
let resolved = if p.is_absolute() {
|
|
||||||
p.to_path_buf()
|
|
||||||
} else {
|
|
||||||
std::env::current_dir()
|
|
||||||
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
|
||||||
.join(p)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check directory restriction
|
|
||||||
if let Some(ref allowed) = self.allowed_dir {
|
|
||||||
let allowed_path = Path::new(allowed);
|
|
||||||
if !resolved.starts_with(allowed_path) {
|
|
||||||
return Err(format!(
|
|
||||||
"Path '{}' is outside allowed directory '{}'",
|
|
||||||
path, allowed
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(resolved)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn find_match(&self, content: &str, old_text: &str) -> Option<(String, usize)> {
|
fn find_match(&self, content: &str, old_text: &str) -> Option<(String, usize)> {
|
||||||
// Try exact match first
|
// Try exact match first
|
||||||
if content.contains(old_text) {
|
if content.contains(old_text) {
|
||||||
@ -155,7 +130,7 @@ impl Tool for FileEditTool {
|
|||||||
.and_then(|v| v.as_bool())
|
.and_then(|v| v.as_bool())
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
let resolved = match self.resolve_path(path) {
|
let resolved = match path_utils::resolve_path(path, self.allowed_dir.as_deref()) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::path_utils;
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
const MAX_CHARS: usize = 128_000;
|
const MAX_CHARS: usize = 128_000;
|
||||||
@ -22,30 +21,6 @@ impl FileReadTool {
|
|||||||
allowed_dir: Some(dir),
|
allowed_dir: Some(dir),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
|
|
||||||
let p = Path::new(path);
|
|
||||||
let resolved = if p.is_absolute() {
|
|
||||||
p.to_path_buf()
|
|
||||||
} else {
|
|
||||||
std::env::current_dir()
|
|
||||||
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
|
||||||
.join(p)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check directory restriction
|
|
||||||
if let Some(ref allowed) = self.allowed_dir {
|
|
||||||
let allowed_path = Path::new(allowed);
|
|
||||||
if !resolved.starts_with(allowed_path) {
|
|
||||||
return Err(format!(
|
|
||||||
"Path '{}' is outside allowed directory '{}'",
|
|
||||||
path, allowed
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(resolved)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for FileReadTool {
|
impl Default for FileReadTool {
|
||||||
@ -115,7 +90,7 @@ impl Tool for FileReadTool {
|
|||||||
.map(|v| v as usize)
|
.map(|v| v as usize)
|
||||||
.unwrap_or(DEFAULT_LIMIT);
|
.unwrap_or(DEFAULT_LIMIT);
|
||||||
|
|
||||||
let resolved = match self.resolve_path(path) {
|
let resolved = match path_utils::resolve_path(path, self.allowed_dir.as_deref()) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
@ -179,6 +154,7 @@ impl Tool for FileReadTool {
|
|||||||
|
|
||||||
// Truncate if too long
|
// Truncate if too long
|
||||||
if result.len() > MAX_CHARS {
|
if result.len() > MAX_CHARS {
|
||||||
|
let original_len = result.len();
|
||||||
let mut truncated_chars = 0;
|
let mut truncated_chars = 0;
|
||||||
let mut end_idx = 0;
|
let mut end_idx = 0;
|
||||||
for (i, line) in lines.iter().enumerate() {
|
for (i, line) in lines.iter().enumerate() {
|
||||||
@ -190,9 +166,10 @@ impl Tool for FileReadTool {
|
|||||||
end_idx = i + 1;
|
end_idx = i + 1;
|
||||||
}
|
}
|
||||||
result = lines[..end_idx].join("\n");
|
result = lines[..end_idx].join("\n");
|
||||||
|
let truncated = original_len - result.len();
|
||||||
result.push_str(&format!(
|
result.push_str(&format!(
|
||||||
"\n\n... ({} chars truncated) ...",
|
"\n\n... ({} chars truncated) ...",
|
||||||
result.len() - MAX_CHARS
|
truncated
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -28,9 +28,10 @@ impl FileSearchTool {
|
|||||||
|
|
||||||
fn truncate_output(&self, lines: &[String]) -> String {
|
fn truncate_output(&self, lines: &[String]) -> String {
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
for line in lines {
|
for (i, line) in lines.iter().enumerate() {
|
||||||
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
||||||
output.push_str(&format!("\n... ({} chars truncated) ...", output.len()));
|
let omitted = lines.len() - i;
|
||||||
|
output.push_str(&format!("\n... ({} files omitted) ...", omitted));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
@ -195,15 +196,15 @@ impl FileSearchTool {
|
|||||||
dir: &str,
|
dir: &str,
|
||||||
max_results: usize,
|
max_results: usize,
|
||||||
) -> anyhow::Result<Vec<String>> {
|
) -> anyhow::Result<Vec<String>> {
|
||||||
let limit_str = max_results.to_string();
|
let mut cmd = Command::new("find");
|
||||||
let mut cmd = Command::new("sh");
|
cmd.arg(dir)
|
||||||
cmd.arg("-c")
|
.arg("-name")
|
||||||
.arg(format!(
|
.arg(pattern)
|
||||||
"find '{}' -name '{}' -not -path '*/.*' 2>/dev/null | head -n {}",
|
.arg("-not")
|
||||||
dir, pattern, limit_str
|
.arg("-path")
|
||||||
))
|
.arg("*/.*")
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped());
|
.stderr(Stdio::null());
|
||||||
|
|
||||||
let output = timeout(
|
let output = timeout(
|
||||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||||
@ -213,13 +214,16 @@ impl FileSearchTool {
|
|||||||
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text.lines()
|
let mut lines: Vec<String> = text.lines()
|
||||||
.filter(|l| !l.is_empty())
|
.filter(|l| !l.is_empty())
|
||||||
.map(|l| {
|
.map(|l| {
|
||||||
let p = Path::new(l);
|
let p = Path::new(l);
|
||||||
p.to_string_lossy().to_string()
|
p.to_string_lossy().to_string()
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
if lines.len() > max_results {
|
||||||
|
lines.truncate(max_results);
|
||||||
|
}
|
||||||
Ok(lines)
|
Ok(lines)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::tools::path_utils;
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
pub struct FileWriteTool {
|
pub struct FileWriteTool {
|
||||||
@ -19,30 +18,6 @@ impl FileWriteTool {
|
|||||||
allowed_dir: Some(dir),
|
allowed_dir: Some(dir),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_path(&self, path: &str) -> Result<std::path::PathBuf, String> {
|
|
||||||
let p = Path::new(path);
|
|
||||||
let resolved = if p.is_absolute() {
|
|
||||||
p.to_path_buf()
|
|
||||||
} else {
|
|
||||||
std::env::current_dir()
|
|
||||||
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
|
||||||
.join(p)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check directory restriction
|
|
||||||
if let Some(ref allowed) = self.allowed_dir {
|
|
||||||
let allowed_path = Path::new(allowed);
|
|
||||||
if !resolved.starts_with(allowed_path) {
|
|
||||||
return Err(format!(
|
|
||||||
"Path '{}' is outside allowed directory '{}'",
|
|
||||||
path, allowed
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(resolved)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for FileWriteTool {
|
impl Default for FileWriteTool {
|
||||||
@ -101,7 +76,7 @@ impl Tool for FileWriteTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let resolved = match self.resolve_path(path) {
|
let resolved = match path_utils::resolve_path(path, self.allowed_dir.as_deref()) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
|
|||||||
@ -10,6 +10,7 @@ pub mod file_write;
|
|||||||
pub mod get_skill;
|
pub mod get_skill;
|
||||||
pub mod http_request;
|
pub mod http_request;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
|
pub mod path_utils;
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod send_message;
|
pub mod send_message;
|
||||||
@ -28,7 +29,6 @@ pub use get_skill::GetSkillTool;
|
|||||||
pub use http_request::HttpRequestTool;
|
pub use http_request::HttpRequestTool;
|
||||||
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool};
|
||||||
pub use registry::ToolRegistry;
|
pub use registry::ToolRegistry;
|
||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
|
||||||
pub use send_message::SendMessageTool;
|
pub use send_message::SendMessageTool;
|
||||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|||||||
24
src/tools/path_utils.rs
Normal file
24
src/tools/path_utils.rs
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
pub fn resolve_path(path: &str, allowed_dir: Option<&str>) -> Result<PathBuf, String> {
|
||||||
|
let p = Path::new(path);
|
||||||
|
let resolved = if p.is_absolute() {
|
||||||
|
p.to_path_buf()
|
||||||
|
} else {
|
||||||
|
std::env::current_dir()
|
||||||
|
.map_err(|e| format!("Failed to get current directory: {}", e))?
|
||||||
|
.join(p)
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(allowed) = allowed_dir {
|
||||||
|
let allowed_path = Path::new(allowed);
|
||||||
|
if !resolved.starts_with(allowed_path) {
|
||||||
|
return Err(format!(
|
||||||
|
"Path '{}' is outside allowed directory '{}'",
|
||||||
|
path, allowed
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
@ -205,37 +205,6 @@ fn strip_all_tags(s: &str) -> String {
|
|||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_html_entity(s: &str) -> Option<(char, usize)> {
|
|
||||||
let s_lower = s.to_lowercase();
|
|
||||||
|
|
||||||
let entities = [
|
|
||||||
(" ", ' '),
|
|
||||||
("<", '<'),
|
|
||||||
(">", '>'),
|
|
||||||
("&", '&'),
|
|
||||||
(""", '"'),
|
|
||||||
("'", '\''),
|
|
||||||
("—", '—'),
|
|
||||||
("–", '–'),
|
|
||||||
("©", '©'),
|
|
||||||
("®", '®'),
|
|
||||||
("™", '™'),
|
|
||||||
];
|
|
||||||
|
|
||||||
for (entity, replacement) in entities {
|
|
||||||
if s_lower.starts_with(&entity.to_lowercase()) {
|
|
||||||
return Some((replacement, entity.len()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle numeric entities
|
|
||||||
if s_lower.starts_with("&#x") || s_lower.starts_with("&#") {
|
|
||||||
// Skip for now
|
|
||||||
}
|
|
||||||
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_host(url: &str) -> Result<String, String> {
|
fn extract_host(url: &str) -> Result<String, String> {
|
||||||
let rest = url
|
let rest = url
|
||||||
.strip_prefix("http://")
|
.strip_prefix("http://")
|
||||||
|
|||||||
5
src/util.rs
Normal file
5
src/util.rs
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
pub fn short_id() -> String {
|
||||||
|
Uuid::new_v4().to_string()[..8].to_string()
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user