diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index e6bd8b5..778b906 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -10,7 +10,6 @@ //! - USER.md — user preferences and profile use crate::tools::ToolRegistry; -use std::fmt::Write; use std::path::Path; /// Maximum characters per injected workspace file. @@ -101,27 +100,6 @@ impl PromptSection for ToolHonestySection { } } -/// List of available tools. -pub struct ToolsSection; - -impl PromptSection for ToolsSection { - fn name(&self) -> &str { - "tools" - } - - fn build(&self, ctx: &PromptContext<'_>) -> String { - if !ctx.tools.has_tools() { - return String::new(); - } - - let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n"); - for (name, tool) in ctx.tools.iter() { - let _ = writeln!(output, "- **{}**: {}", name, tool.description()); - } - output - } -} - /// Instructions for the task. pub struct YourTaskSection; diff --git a/src/channels/cli_chat.rs b/src/channels/cli_chat.rs index 80f7dd5..7f30ae6 100644 --- a/src/channels/cli_chat.rs +++ b/src/channels/cli_chat.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use async_trait::async_trait; use tokio::sync::{mpsc, Mutex}; -use uuid::Uuid; use crate::bus::{ControlMessage, InboundMessage, MessageBus, OutboundMessage}; use crate::session::{SessionCommand, SessionEvent, UnifiedSessionId}; @@ -9,11 +8,6 @@ use crate::protocol::{parse_inbound, WsInbound, WsOutbound, SlashCommandInfo}; use super::base::{Channel, ChannelError}; -/// Generate a short ID (8 characters) from a UUID -fn short_id() -> String { - Uuid::new_v4().to_string()[..8].to_string() -} - // ============================================================================ // Client - Connected CLI client // ============================================================================ @@ -49,7 +43,7 @@ impl CliChatChannel { /// Register a new client connection, returns (session_id, client) pub(crate) async fn register_client(&self, sender: mpsc::Sender) -> (String, Arc) { // Generate connection ID (used as chat_id) - use short ID - let connection_id = short_id(); + let connection_id = crate::util::short_id(); let client = Arc::new(Client { sender, @@ -122,7 +116,7 @@ impl CliChatChannel { let msg = InboundMessage { channel: self.name().to_string(), sender_id: "cli".to_string(), - chat_id: chat_id.unwrap_or_else(short_id), + chat_id: chat_id.unwrap_or_else(crate::util::short_id), content, timestamp: crate::bus::message::current_timestamp(), media: Vec::new(), @@ -166,7 +160,7 @@ impl CliChatChannel { WsInbound::CreateSession { title } => { // Use current session's chat_id if available, otherwise generate new one let chat_id = current_session_guard.clone() - .unwrap_or_else(short_id); + .unwrap_or_else(crate::util::short_id); let new_id = self.create_session_via_control(&chat_id, title.as_deref()).await?; *current_session_guard = Some(new_id.clone()); let _ = client @@ -491,7 +485,7 @@ impl Channel for CliChatChannel { } } else { WsOutbound::AssistantResponse { - id: short_id(), + id: crate::util::short_id(), content: msg.content.clone(), role: "assistant".to_string(), } diff --git a/src/lib.rs b/src/lib.rs index 21c069a..cebf6b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,3 +14,4 @@ pub mod scheduler; pub mod skills; pub mod storage; pub mod tools; +pub mod util; diff --git a/src/providers/openai.rs b/src/providers/openai.rs index d1160fe..882768f 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -133,8 +133,6 @@ struct OpenAIMessage { #[serde(default)] content: Option, #[serde(default)] - name: Option, - #[serde(default)] tool_calls: Vec, } @@ -143,8 +141,6 @@ struct OpenAIToolCall { id: String, #[serde(rename = "function")] function: OAIFunction, - #[serde(default)] - index: Option, } #[derive(Deserialize)] @@ -219,11 +215,13 @@ impl LLMProvider for OpenAIProvider { if !status.is_success() { let error = format!("API error {}: {}", status, text); if let Some(ref storage) = self.storage { - let _ = storage.append_llm_call( + if let Err(e) = storage.append_llm_call( &self.name, &self.model_id, &req_body_str, Some(&text), Some(&error), start.elapsed().as_millis() as u64, - ).await; + ).await { + tracing::warn!("failed to persist LLM call: {}", e); + } } return Err(error.into()); } @@ -240,20 +238,25 @@ impl LLMProvider for OpenAIProvider { let err = err_msg.clone(); let s = storage.clone(); tokio::spawn(async move { - let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await; + if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await { + tracing::warn!("failed to persist LLM call (decode error): {}", e); + } }); } err_msg })?; - let content = openai_resp.choices[0] + let first_choice = openai_resp.choices.into_iter().next() + .ok_or("no choices in response")?; + + let content = first_choice .message .content .as_ref() .unwrap_or(&String::new()) .clone(); - let tool_calls: Vec = openai_resp.choices[0] + let tool_calls: Vec = first_choice .message .tool_calls .iter() @@ -277,11 +280,13 @@ impl LLMProvider for OpenAIProvider { }; if let Some(ref storage) = self.storage { - let _ = storage.append_llm_call( + if let Err(e) = storage.append_llm_call( &self.name, &self.model_id, &req_body_str, Some(&text), None, start.elapsed().as_millis() as u64, - ).await; + ).await { + tracing::warn!("failed to persist LLM call: {}", e); + } } Ok(response) diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index 10602a7..f189641 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -144,7 +144,9 @@ impl Scheduler { media: vec![], metadata: std::collections::HashMap::new(), }; - let _ = self.bus.publish_outbound(outbound).await; + if let Err(e) = self.bus.publish_outbound(outbound).await { + tracing::warn!(job_id = %job.id, "scheduler: failed to publish outbound: {}", e); + } let output_truncated = if output.len() > 8000 { format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)]) @@ -186,7 +188,9 @@ impl Scheduler { media: vec![], metadata: std::collections::HashMap::new(), }; - let _ = self.bus.publish_outbound(outbound).await; + if let Err(e) = self.bus.publish_outbound(outbound).await { + tracing::warn!(job_id = %job.id, "scheduler: failed to publish outbound: {}", e); + } let run = JobRun { id: 0, diff --git a/src/session/session.rs b/src/session/session.rs index af9851f..4e10128 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; -use uuid::Uuid; use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; use crate::storage::{Storage, StorageError}; @@ -41,11 +40,6 @@ use crate::bus::MessageBus; use crate::tools::OutboundMessenger; use crate::tools::SendMessageTool; -/// Generate a short ID (8 characters) from a UUID -fn short_id() -> String { - Uuid::new_v4().to_string()[..8].to_string() -} - /// Session = 一个 dialog /// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history pub struct Session { @@ -988,7 +982,7 @@ impl SessionManager { title: Option<&str>, routing_info: String, ) -> Result<(UnifiedSessionId, String), AgentError> { - let dialog_id = short_id(); + let dialog_id = crate::util::short_id(); let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id); let session_id_str = unified_id.to_string(); diff --git a/src/skills/mod.rs b/src/skills/mod.rs index f8ebaa8..fafff88 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -233,55 +233,6 @@ impl SkillsLoader { .collect() } - /// Build XML summary of all skills (for progressive disclosure) (checks for changes first) - pub fn build_skills_summary(&self) -> String { - self.reload_if_changed(); - let state = self.state.lock().unwrap(); - - if state.loaded_skills.is_empty() { - return String::new(); - } - - let mut lines = vec!["".to_string()]; - - for skill in &state.loaded_skills { - if skill.always { - continue; - } - lines.push(" ".to_string()); - lines.push(format!(" {}", escape_xml(&skill.name))); - lines.push(format!( - " {}", - escape_xml(&skill.description) - )); - if let Some(path) = &skill.path { - lines.push(format!(" {}", escape_xml(&path.to_string_lossy()))); - } - lines.push(" ".to_string()); - } - - lines.push("".to_string()); - lines.join("\n") - } - - /// Build prompt for always-injected skills (checks for changes first) - pub fn build_always_skills_prompt(&self) -> String { - self.reload_if_changed(); - let state = self.state.lock().unwrap(); - - let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect(); - if always_skills.is_empty() { - return String::new(); - } - - let mut parts = Vec::new(); - for skill in always_skills { - parts.push(format!("## Skill: {}\n\n{}", skill.name, skill.content)); - } - - parts.join("\n\n---\n\n") - } - /// Build full skills prompt: directory conventions, always-skill summary, always-skill content pub fn build_skills_prompt(&self) -> String { self.reload_if_changed(); @@ -474,22 +425,6 @@ fn extract_description(content: &str) -> String { .unwrap_or_else(|| "No description".to_string()) } -/// Escape XML special characters -fn escape_xml(s: &str) -> String { - let mut result = String::with_capacity(s.len()); - for c in s.chars() { - match c { - '&' => result.push_str("&"), - '<' => result.push_str("<"), - '>' => result.push_str(">"), - '"' => result.push_str("""), - '\'' => result.push_str("'"), - _ => result.push(c), - } - } - result -} - #[cfg(test)] mod tests { use super::*; @@ -526,13 +461,6 @@ This is the content. assert!(body.contains("Test Skill")); } - #[test] - fn test_escape_xml() { - assert_eq!(escape_xml("a & b"), "a & b"); - assert_eq!(escape_xml(""), "<tag>"); - assert_eq!(escape_xml("\"quote\""), ""quote""); - } - #[test] fn test_extract_description() { assert_eq!( diff --git a/src/tools/content_search.rs b/src/tools/content_search.rs index 22b68e4..8ba81b9 100644 --- a/src/tools/content_search.rs +++ b/src/tools/content_search.rs @@ -28,12 +28,12 @@ impl ContentSearchTool { fn truncate_output(&self, lines: &[String]) -> String { let mut output = String::new(); - for line in lines { + for (i, line) in lines.iter().enumerate() { if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS { + let omitted = lines.len() - i; output.push_str(&format!( - "\n... ({} chars truncated, {} matches omitted) ...", - output.len(), - lines.len() + "\n... ({} matches omitted) ...", + omitted )); break; } diff --git a/src/tools/file_edit.rs b/src/tools/file_edit.rs index 78ab2f4..fa6b4fc 100644 --- a/src/tools/file_edit.rs +++ b/src/tools/file_edit.rs @@ -1,8 +1,7 @@ -use std::path::Path; - use async_trait::async_trait; use serde_json::json; +use crate::tools::path_utils; use crate::tools::traits::{Tool, ToolResult}; pub struct FileEditTool { @@ -20,30 +19,6 @@ impl FileEditTool { } } - fn resolve_path(&self, path: &str) -> Result { - let p = Path::new(path); - let resolved = if p.is_absolute() { - p.to_path_buf() - } else { - std::env::current_dir() - .map_err(|e| format!("Failed to get current directory: {}", e))? - .join(p) - }; - - // Check directory restriction - if let Some(ref allowed) = self.allowed_dir { - let allowed_path = Path::new(allowed); - if !resolved.starts_with(allowed_path) { - return Err(format!( - "Path '{}' is outside allowed directory '{}'", - path, allowed - )); - } - } - - Ok(resolved) - } - fn find_match(&self, content: &str, old_text: &str) -> Option<(String, usize)> { // Try exact match first if content.contains(old_text) { @@ -155,7 +130,7 @@ impl Tool for FileEditTool { .and_then(|v| v.as_bool()) .unwrap_or(false); - let resolved = match self.resolve_path(path) { + let resolved = match path_utils::resolve_path(path, self.allowed_dir.as_deref()) { Ok(p) => p, Err(e) => { return Ok(ToolResult { diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index 10725f1..1553ad3 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -1,8 +1,7 @@ -use std::path::Path; - use async_trait::async_trait; use serde_json::json; +use crate::tools::path_utils; use crate::tools::traits::{Tool, ToolResult}; const MAX_CHARS: usize = 128_000; @@ -22,30 +21,6 @@ impl FileReadTool { allowed_dir: Some(dir), } } - - fn resolve_path(&self, path: &str) -> Result { - let p = Path::new(path); - let resolved = if p.is_absolute() { - p.to_path_buf() - } else { - std::env::current_dir() - .map_err(|e| format!("Failed to get current directory: {}", e))? - .join(p) - }; - - // Check directory restriction - if let Some(ref allowed) = self.allowed_dir { - let allowed_path = Path::new(allowed); - if !resolved.starts_with(allowed_path) { - return Err(format!( - "Path '{}' is outside allowed directory '{}'", - path, allowed - )); - } - } - - Ok(resolved) - } } impl Default for FileReadTool { @@ -115,7 +90,7 @@ impl Tool for FileReadTool { .map(|v| v as usize) .unwrap_or(DEFAULT_LIMIT); - let resolved = match self.resolve_path(path) { + let resolved = match path_utils::resolve_path(path, self.allowed_dir.as_deref()) { Ok(p) => p, Err(e) => { return Ok(ToolResult { @@ -179,6 +154,7 @@ impl Tool for FileReadTool { // Truncate if too long if result.len() > MAX_CHARS { + let original_len = result.len(); let mut truncated_chars = 0; let mut end_idx = 0; for (i, line) in lines.iter().enumerate() { @@ -190,9 +166,10 @@ impl Tool for FileReadTool { end_idx = i + 1; } result = lines[..end_idx].join("\n"); + let truncated = original_len - result.len(); result.push_str(&format!( "\n\n... ({} chars truncated) ...", - result.len() - MAX_CHARS + truncated )); } diff --git a/src/tools/file_search.rs b/src/tools/file_search.rs index 1e90bbe..f547b91 100644 --- a/src/tools/file_search.rs +++ b/src/tools/file_search.rs @@ -28,9 +28,10 @@ impl FileSearchTool { fn truncate_output(&self, lines: &[String]) -> String { let mut output = String::new(); - for line in lines { + for (i, line) in lines.iter().enumerate() { if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS { - output.push_str(&format!("\n... ({} chars truncated) ...", output.len())); + let omitted = lines.len() - i; + output.push_str(&format!("\n... ({} files omitted) ...", omitted)); break; } if !output.is_empty() { @@ -195,15 +196,15 @@ impl FileSearchTool { dir: &str, max_results: usize, ) -> anyhow::Result> { - let limit_str = max_results.to_string(); - let mut cmd = Command::new("sh"); - cmd.arg("-c") - .arg(format!( - "find '{}' -name '{}' -not -path '*/.*' 2>/dev/null | head -n {}", - dir, pattern, limit_str - )) + let mut cmd = Command::new("find"); + cmd.arg(dir) + .arg("-name") + .arg(pattern) + .arg("-not") + .arg("-path") + .arg("*/.*") .stdout(Stdio::piped()) - .stderr(Stdio::piped()); + .stderr(Stdio::null()); let output = timeout( std::time::Duration::from_secs(TIMEOUT_SECS), @@ -213,13 +214,16 @@ impl FileSearchTool { .map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??; let text = String::from_utf8_lossy(&output.stdout); - let lines: Vec = text.lines() + let mut lines: Vec = text.lines() .filter(|l| !l.is_empty()) .map(|l| { let p = Path::new(l); p.to_string_lossy().to_string() }) .collect(); + if lines.len() > max_results { + lines.truncate(max_results); + } Ok(lines) } diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index 0086421..55235c9 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -1,8 +1,7 @@ -use std::path::Path; - use async_trait::async_trait; use serde_json::json; +use crate::tools::path_utils; use crate::tools::traits::{Tool, ToolResult}; pub struct FileWriteTool { @@ -19,30 +18,6 @@ impl FileWriteTool { allowed_dir: Some(dir), } } - - fn resolve_path(&self, path: &str) -> Result { - let p = Path::new(path); - let resolved = if p.is_absolute() { - p.to_path_buf() - } else { - std::env::current_dir() - .map_err(|e| format!("Failed to get current directory: {}", e))? - .join(p) - }; - - // Check directory restriction - if let Some(ref allowed) = self.allowed_dir { - let allowed_path = Path::new(allowed); - if !resolved.starts_with(allowed_path) { - return Err(format!( - "Path '{}' is outside allowed directory '{}'", - path, allowed - )); - } - } - - Ok(resolved) - } } impl Default for FileWriteTool { @@ -101,7 +76,7 @@ impl Tool for FileWriteTool { } }; - let resolved = match self.resolve_path(path) { + let resolved = match path_utils::resolve_path(path, self.allowed_dir.as_deref()) { Ok(p) => p, Err(e) => { return Ok(ToolResult { diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 91c2225..03a572d 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -10,6 +10,7 @@ pub mod file_write; pub mod get_skill; pub mod http_request; pub mod memory; +pub mod path_utils; pub mod registry; pub mod schema; pub mod send_message; @@ -28,7 +29,6 @@ pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; pub use memory::{MemoryForgetTool, MemoryRecallTool, MemoryStoreTool, TimelineRecallTool}; pub use registry::ToolRegistry; -pub use schema::{CleaningStrategy, SchemaCleanr}; pub use send_message::SendMessageTool; pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use web_fetch::WebFetchTool; diff --git a/src/tools/path_utils.rs b/src/tools/path_utils.rs new file mode 100644 index 0000000..49c4523 --- /dev/null +++ b/src/tools/path_utils.rs @@ -0,0 +1,24 @@ +use std::path::{Path, PathBuf}; + +pub fn resolve_path(path: &str, allowed_dir: Option<&str>) -> Result { + let p = Path::new(path); + let resolved = if p.is_absolute() { + p.to_path_buf() + } else { + std::env::current_dir() + .map_err(|e| format!("Failed to get current directory: {}", e))? + .join(p) + }; + + if let Some(allowed) = allowed_dir { + let allowed_path = Path::new(allowed); + if !resolved.starts_with(allowed_path) { + return Err(format!( + "Path '{}' is outside allowed directory '{}'", + path, allowed + )); + } + } + + Ok(resolved) +} diff --git a/src/tools/web_fetch.rs b/src/tools/web_fetch.rs index aad0987..e4717b5 100644 --- a/src/tools/web_fetch.rs +++ b/src/tools/web_fetch.rs @@ -205,37 +205,6 @@ fn strip_all_tags(s: &str) -> String { result } -fn extract_html_entity(s: &str) -> Option<(char, usize)> { - let s_lower = s.to_lowercase(); - - let entities = [ - (" ", ' '), - ("<", '<'), - (">", '>'), - ("&", '&'), - (""", '"'), - ("'", '\''), - ("—", '—'), - ("–", '–'), - ("©", '©'), - ("®", '®'), - ("™", '™'), - ]; - - for (entity, replacement) in entities { - if s_lower.starts_with(&entity.to_lowercase()) { - return Some((replacement, entity.len())); - } - } - - // Handle numeric entities - if s_lower.starts_with("&#x") || s_lower.starts_with("&#") { - // Skip for now - } - - None -} - fn extract_host(url: &str) -> Result { let rest = url .strip_prefix("http://") diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..f8294b2 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,5 @@ +use uuid::Uuid; + +pub fn short_id() -> String { + Uuid::new_v4().to_string()[..8].to_string() +}