From 8f4ee79d8d6f880824df1770c001a6554fe9d719 Mon Sep 17 00:00:00 2001 From: xiaoski Date: Mon, 15 Jun 2026 23:47:24 +0800 Subject: [PATCH] Format codebase with rustfmt --- src/agent/agent_loop.rs | 113 +++-- src/agent/context_compressor.rs | 283 +++++++---- src/agent/media_handler.rs | 2 +- src/agent/mod.rs | 13 +- src/agent/sub_agent.rs | 78 +-- src/agent/system_prompt.rs | 29 +- src/bus/dispatcher.rs | 2 +- src/bus/message.rs | 23 +- src/bus/mod.rs | 10 +- src/channels/feishu.rs | 562 +++++++++++++++------- src/channels/manager.rs | 26 +- src/channels/mod.rs | 8 +- src/channels/slash_command.rs | 9 +- src/client/mod.rs | 9 +- src/client/tui/components/chat_history.rs | 2 +- src/client/tui/components/command_menu.rs | 2 +- src/client/tui/components/help_popup.rs | 2 +- src/client/tui/components/input_area.rs | 2 +- src/client/tui/components/session_list.rs | 6 +- src/client/tui/components/title_bar.rs | 7 +- src/client/tui/event.rs | 19 +- src/client/tui/ui.rs | 2 +- src/config/mod.rs | 30 +- src/gateway/ws.rs | 25 +- src/lib.rs | 14 +- src/logging.rs | 29 +- src/main.rs | 2 +- src/mcp/mod.rs | 41 +- src/memory/mod.rs | 25 +- src/observability/mod.rs | 11 +- src/providers/anthropic.rs | 143 +++--- src/providers/mod.rs | 11 +- src/providers/openai.rs | 162 ++++--- src/providers/traits.rs | 8 +- src/scheduler/mod.rs | 47 +- src/session/commands.rs | 34 +- src/session/events.rs | 27 +- src/session/mod.rs | 8 +- src/session/session_id.rs | 7 +- src/skills/builtin.rs | 5 +- src/skills/mod.rs | 132 +++-- src/storage/memory.rs | 15 +- src/storage/scheduler.rs | 156 ++++-- src/tools/bash.rs | 15 +- src/tools/browser.rs | 129 +++-- src/tools/calculator.rs | 5 +- src/tools/content_search.rs | 147 ++++-- src/tools/cron.rs | 25 +- src/tools/delegate.rs | 61 ++- src/tools/file_edit.rs | 2 +- src/tools/file_read.rs | 19 +- src/tools/file_search.rs | 91 ++-- src/tools/file_write.rs | 20 +- src/tools/get_skill.rs | 6 +- src/tools/http_request.rs | 33 +- src/tools/memory.rs | 38 +- src/tools/mod.rs | 2 +- src/tools/registry.rs | 5 +- src/tools/schema.rs | 48 +- src/tools/send_message.rs | 25 +- src/tools/traits.rs | 2 +- src/tools/web_fetch.rs | 10 +- tests/test_integration.rs | 14 +- tests/test_scheduler.rs | 8 +- tests/test_tool_calling.rs | 22 +- 65 files changed, 1807 insertions(+), 1061 deletions(-) diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 61356b8..11c69c7 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt; use crate::bus::message::ContentBlock; use crate::bus::{ChatMessage, MediaRef}; use crate::config::LLMProviderConfig; -use crate::observability::{ - truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, -}; -use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; +use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider}; use crate::tools::ToolRegistry; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; @@ -256,7 +254,10 @@ impl AgentLoop { } /// Create a new AgentLoop with provider created from config and given tools. - pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { + pub fn with_tools( + provider_config: LLMProviderConfig, + tools: Arc, + ) -> Result { let max_iterations = provider_config.max_tool_iterations; let model_name = provider_config.model_id.clone(); let workspace_dir = provider_config.workspace_dir.clone(); @@ -279,7 +280,13 @@ impl AgentLoop { } /// Create a new AgentLoop with an existing shared provider. - pub fn with_provider(provider: Arc, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec) -> Self { + pub fn with_provider( + provider: Arc, + max_iterations: usize, + model_name: String, + workspace_dir: PathBuf, + input_types: Vec, + ) -> Self { Self { provider, tools: Arc::new(ToolRegistry::new()), @@ -379,7 +386,12 @@ impl AgentLoop { let content = if m.media_refs.is_empty() { vec![ContentBlock::text(&m.content)] } else { - build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry) + build_content_blocks( + &m.content, + &m.media_refs, + &self.input_types, + &self.media_registry, + ) }; Message { @@ -399,14 +411,28 @@ impl AgentLoop { /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached - pub async fn process(&self, mut messages: Vec) -> Result { + pub async fn process( + &self, + mut messages: Vec, + ) -> Result { #[cfg(debug_assertions)] - tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); + tracing::debug!( + history_len = messages.len(), + max_iterations = self.max_iterations, + "Starting agent process" + ); // Build and inject system prompt if not present let has_system = messages.first().is_some_and(|m| m.role == "system"); if !has_system { - let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false); + let system_prompt = build_system_prompt( + &self.workspace_dir, + &self.model_name, + &self.tools, + None, + None, + false, + ); #[cfg(debug_assertions)] tracing::debug!("System prompt injected:\n{}", system_prompt); messages.insert(0, ChatMessage::system(system_prompt)); @@ -427,9 +453,7 @@ impl AgentLoop { let estimated = estimate_tokens(&messages); let danger = (self.context_window as f64 * 0.8) as usize; if estimated > danger { - let trimmed = self.preemptive_trim_old_tool_results( - &mut messages, 2000, 4, - ); + let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4); if trimmed > 0 { #[cfg(debug_assertions)] tracing::debug!( @@ -463,11 +487,10 @@ impl AgentLoop { }; // Call LLM - let response = (*self.provider).chat(request).await - .map_err(|e| { - tracing::error!(error = %e, "LLM request failed"); - AgentError::LlmError(e.to_string()) - })?; + let response = (*self.provider).chat(request).await.map_err(|e| { + tracing::error!(error = %e, "LLM request failed"); + AgentError::LlmError(e.to_string()) + })?; accumulated_tokens += response.usage.total_tokens; @@ -493,7 +516,9 @@ impl AgentLoop { // Execute tool calls — log and notify immediately { - let tools_info: Vec = response.tool_calls.iter() + let tools_info: Vec = response + .tool_calls + .iter() .map(|tc| { let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); let s = format!("{}:{}", tc.name, args); @@ -522,7 +547,9 @@ impl AgentLoop { // Log function call with name and arguments let args_str = match &tool_call.arguments { serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), - other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), + other => { + serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()) + } }; tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); @@ -562,7 +589,11 @@ impl AgentLoop { // Loop continues to next iteration with updated messages #[cfg(debug_assertions)] - tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); + tracing::debug!( + iteration, + message_count = messages.len(), + "Tool execution complete, continuing to next iteration" + ); } // Max iterations reached - ask LLM for a summary based on completed work @@ -571,7 +602,7 @@ impl AgentLoop { // Add a message asking for summary let summary_request = ChatMessage::user( "You have reached the maximum number of tool call iterations. \ - Please provide your best answer based on the work completed so far." + Please provide your best answer based on the work completed so far.", ); messages.push(summary_request); @@ -603,14 +634,19 @@ impl AgentLoop { Err(e) => { // Fallback if summary call fails tracing::error!(error = %e, "Failed to get summary from LLM"); - let final_message = ChatMessage::assistant( - format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) - ); + let final_message = ChatMessage::assistant(format!( + "I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", + self.max_iterations + )); emitted_messages.push(final_message.clone()); Ok(AgentProcessResult { final_response: final_message, emitted_messages, - total_tokens: if accumulated_tokens > 0 { Some(accumulated_tokens) } else { None }, + total_tokens: if accumulated_tokens > 0 { + Some(accumulated_tokens) + } else { + None + }, }) } } @@ -698,10 +734,7 @@ impl AgentLoop { } // Apply duration - ToolExecutionOutcome { - duration, - ..result - } + ToolExecutionOutcome { duration, ..result } } /// Internal tool execution without event tracking. @@ -723,18 +756,12 @@ impl AgentLoop { ToolExecutionOutcome::success(result.output) } else { let error = result.error.unwrap_or_default(); - ToolExecutionOutcome::failure( - format!("Error: {}", error), - Some(error), - ) + ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error)) } } Err(e) => { tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed"); - ToolExecutionOutcome::failure( - format!("Error: {}", e), - Some(e.to_string()), - ) + ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string())) } } } @@ -822,8 +849,14 @@ mod tests { assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); - assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); - assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); + assert_eq!( + provider_message.tool_calls.as_ref().unwrap()[0].id, + "call_1" + ); + assert_eq!( + provider_message.tool_calls.as_ref().unwrap()[0].name, + "calculator" + ); } } diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index 6c20b61..a1b47c5 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -234,27 +234,33 @@ impl ContextCompressor { } } else if messages[i].role == "tool" && let Some(ref tid) = messages[i].tool_call_id - && !declared.contains(tid.as_str()) { - messages.remove(i); - continue; - } + && !declared.contains(tid.as_str()) + { + messages.remove(i); + continue; + } i += 1; } - let broken: Vec = messages.iter().enumerate() + let broken: Vec = messages + .iter() + .enumerate() .filter_map(|(idx, msg)| { if msg.role == "assistant" && let Some(ref tcs) = msg.tool_calls - && !tcs.is_empty() { - let all_present = tcs.iter().all(|tc| { - messages.iter().any(|m| { - m.role == "tool" - && m.tool_call_id.as_deref() == Some(tc.id.as_str()) - }) - }); - if !all_present { Some(idx) } else { None } - } else { None } - }).collect(); + && !tcs.is_empty() + { + let all_present = tcs.iter().all(|tc| { + messages.iter().any(|m| { + m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str()) + }) + }); + if !all_present { Some(idx) } else { None } + } else { + None + } + }) + .collect(); for idx in broken { let msg = &mut messages[idx]; @@ -262,7 +268,8 @@ impl ContextCompressor { let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); msg.content = format!( "{}\n\n[Tool calls ({}) — results are no longer available]", - msg.content, names.join(", ") + msg.content, + names.join(", ") ); } } @@ -275,7 +282,10 @@ impl ContextCompressor { // Check if compression is needed let tokens = self.token_estimate_with_history(&history); if tokens <= self.threshold() { - return Ok(CompressionResult { history, created_timelines: false }); + return Ok(CompressionResult { + history, + created_timelines: false, + }); } #[cfg(debug_assertions)] @@ -299,7 +309,10 @@ impl ContextCompressor { } if tokens_after <= self.threshold() { self.invalidate_token_cache(); - return Ok(CompressionResult { history, created_timelines: false }); + return Ok(CompressionResult { + history, + created_timelines: false, + }); } // LLM summarization pass @@ -312,11 +325,7 @@ impl ContextCompressor { } #[cfg(debug_assertions)] - tracing::debug!( - pass = pass + 1, - tokens = tokens, - "Compression pass" - ); + tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass"); match self.compress_once(¤t_history).await { Ok(Some(compressed)) => { @@ -352,18 +361,24 @@ impl ContextCompressor { let m = ¤t_history[scan]; if m.role == "assistant" { if let Some(tcs) = &m.tool_calls - && !tcs.is_empty() { - let has_post = current_history[scan + 1..] - .iter() - .filter(|r| r.role == "tool") - .any(|r| tcs.iter().any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))); - if has_post { - tail_start = scan; - } + && !tcs.is_empty() + { + let has_post = current_history[scan + 1..] + .iter() + .filter(|r| r.role == "tool") + .any(|r| { + tcs.iter() + .any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str())) + }); + if has_post { + tail_start = scan; } + } + break; + } + if scan == 0 { break; } - if scan == 0 { break; } scan -= 1; } } @@ -390,14 +405,16 @@ impl ContextCompressor { for msg in &mut truncated[..self.config.protect_first_n] { if msg.role == "assistant" { if let Some(ref tcs) = msg.tool_calls - && !tcs.is_empty() { - let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); - msg.content = format!( - "{}\n\n[Tool calls ({}) — results dropped during truncation]", - msg.content, names.join(", ") - ); - msg.tool_calls = None; - } + && !tcs.is_empty() + { + let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); + msg.content = format!( + "{}\n\n[Tool calls ({}) — results dropped during truncation]", + msg.content, + names.join(", ") + ); + msg.tool_calls = None; + } } } @@ -424,7 +441,10 @@ impl ContextCompressor { "Context compression completed" ); - Ok(CompressionResult { history: current_history, created_timelines }) + Ok(CompressionResult { + history: current_history, + created_timelines, + }) } /// Try to extract the actual context token limit from an LLM error message. @@ -447,20 +467,21 @@ impl ContextCompressor { // Look for a number in the vicinity (up to 10 chars after marker) if let Some(num_str) = find_number_nearby(after, 50) && let Ok(n) = num_str.parse::() - && (1024..=10_000_000).contains(&n) { - return Some(n); - } + && (1024..=10_000_000).contains(&n) + { + return Some(n); + } } } // Also try: "XXXX token context" or "XXXX limit" if let Some(num_str) = find_number_nearby(&lower, lower.len()) && let Ok(n) = num_str.parse::() - && (1024..=10_000_000).contains(&n) - && (lower.contains("token") || lower.contains("context") || lower.contains("limit")) - { - return Some(n); - } + && (1024..=10_000_000).contains(&n) + && (lower.contains("token") || lower.contains("context") || lower.contains("limit")) + { + return Some(n); + } None } @@ -509,19 +530,26 @@ impl ContextCompressor { // Persist compressed summary as timeline memory entry let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); - let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", - ts, between.len(), summary); + let timeline_content = format!( + "[{}] Compressed {} conversation segments:\n{}", + ts, + between.len(), + summary + ); let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); let mm = self.memory.clone(); let sid = self.session_id.clone(); tokio::spawn(async move { - if let Err(e) = mm.store( - &key, - &timeline_content, - crate::memory::MemoryCategory::Timeline, - sid.as_deref(), - Some(0.3), - ).await { + if let Err(e) = mm + .store( + &key, + &timeline_content, + crate::memory::MemoryCategory::Timeline, + sid.as_deref(), + Some(0.3), + ) + .await + { tracing::warn!(error = %e, "Failed to store compressed context as timeline"); } }); @@ -552,10 +580,7 @@ impl ContextCompressor { } /// Summarize a segment of messages using LLM. - async fn summarize_segment( - &self, - messages: &[ChatMessage], - ) -> Result { + async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result { if messages.is_empty() { return Ok(String::new()); } @@ -569,7 +594,8 @@ impl ContextCompressor { "tool" => "Tool", _ => m.role.as_str(), }; - let name = m.tool_name + let name = m + .tool_name .as_ref() .map(|n| format!(" ({})", n)) .unwrap_or_default(); @@ -614,7 +640,10 @@ Be concise, aim for {} characters or less. ); let request = ChatCompletionRequest { - messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], + messages: vec![ + Message::system("You are a helpful assistant."), + Message::user(&prompt), + ], temperature: Some(0.3), max_tokens: Some(1000), tools: None, @@ -686,13 +715,23 @@ mod tests { content: "[summarized]".into(), reasoning_content: None, tool_calls: vec![], - usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, }) } - fn ptype(&self) -> &str { "mock" } - fn name(&self) -> &str { "mock" } - fn model_id(&self) -> &str { "mock" } + fn ptype(&self) -> &str { + "mock" + } + fn name(&self) -> &str { + "mock" + } + fn model_id(&self) -> &str { + "mock" + } } fn mock_summarizer() -> Arc { @@ -704,11 +743,13 @@ mod tests { MM.get_or_init(|| { let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async { - let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id())); + let tmp = std::env::temp_dir() + .join(format!("picobot_ctx_test_{}.db", std::process::id())); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); Arc::new(MemoryManager::new(storage, "test".into(), "test".into())) }) - }).clone() + }) + .clone() } #[test] @@ -724,7 +765,11 @@ mod tests { // "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 // "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7 // raw = 19, with 1.2x = ~23 - assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens); + assert!( + tokens > 18 && tokens < 30, + "Expected ~23 tokens, got {}", + tokens + ); } #[test] @@ -733,7 +778,8 @@ mod tests { tool_result_trim_chars: 50, ..Default::default() }; - let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager()); + let compressor = + ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager()); let mut messages = vec![ ChatMessage::user("Hello"), @@ -774,7 +820,11 @@ mod tests { ChatMessage::tool("call1", "bash", &"x".repeat(3000)), ]; - let result = compressor.compress_if_needed(messages).await.unwrap().history; + let result = compressor + .compress_if_needed(messages) + .await + .unwrap() + .history; let tool_msg = result.iter().find(|m| m.role == "tool").unwrap(); assert!( @@ -798,13 +848,14 @@ mod tests { // - B2B (L275): last user message lost when it is the final history message // // context_window=200 → threshold=100. Large tool outputs force LLM summarization. - let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id())); + let tmp = + std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id())); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into())); let config = ContextCompressionConfig { tool_result_trim_chars: 2000, - protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated + protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated protect_last_n: 2, max_passes: 1, ..Default::default() @@ -818,25 +869,43 @@ mod tests { let big = "x".repeat(3000); let messages = vec![ ChatMessage::system("You are a helper."), // 0: protected - ChatMessage::user("Q1"), // 1: first user - ChatMessage::tool("t1", "bash", &big), // 2 - ChatMessage::user("Q2"), // 3 - ChatMessage::assistant("thinking"), // 4 - ChatMessage::tool("t2", "bash", &big), // 5 - ChatMessage::user("Q3"), // 6 - ChatMessage::assistant("thinking"), // 7 - ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers + ChatMessage::user("Q1"), // 1: first user + ChatMessage::tool("t1", "bash", &big), // 2 + ChatMessage::user("Q2"), // 3 + ChatMessage::assistant("thinking"), // 4 + ChatMessage::tool("t2", "bash", &big), // 5 + ChatMessage::user("Q3"), // 6 + ChatMessage::assistant("thinking"), // 7 + ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers ]; - let result = compressor.compress_if_needed(messages).await.unwrap().history; + let result = compressor + .compress_if_needed(messages) + .await + .unwrap() + .history; // B2A: "Q1" must appear exactly once - let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count(); - assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count); + let q1_count = result + .iter() + .filter(|m| m.role == "user" && m.content == "Q1") + .count(); + assert_eq!( + q1_count, 1, + "Q1 should appear exactly once, got {}", + q1_count + ); // B2B: "Q4" must NOT be lost - let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count(); - assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count); + let q4_count = result + .iter() + .filter(|m| m.role == "user" && m.content == "Q4") + .count(); + assert_eq!( + q4_count, 1, + "Q4 should appear exactly once (not lost), got {}", + q4_count + ); let _ = std::fs::remove_file(&tmp); } @@ -850,10 +919,10 @@ mod tests { let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into())); let config = ContextCompressionConfig { - tool_result_trim_chars: 500, // trim reduces but not enough + tool_result_trim_chars: 500, // trim reduces but not enough protect_first_n: 1, protect_last_n: 2, - max_passes: 0, // no LLM summarization → will exceed danger + max_passes: 0, // no LLM summarization → will exceed danger ..Default::default() }; // context_window=100, danger_threshold=90. @@ -872,13 +941,23 @@ mod tests { ChatMessage::tool("t3", "bash", &big), ]; - let result = compressor.compress_if_needed(messages).await.unwrap().history; + let result = compressor + .compress_if_needed(messages) + .await + .unwrap() + .history; // After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages - assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len()); + assert!( + result.len() < 7, + "expected truncation reduction, got {} messages", + result.len() + ); // Truncation notice should be present - let has_notice = result.iter().any(|m| m.content.contains("Context truncation")); + let has_notice = result + .iter() + .any(|m| m.content.contains("Context truncation")); assert!(has_notice, "hard truncation notice missing"); let _ = std::fs::remove_file(&tmp); @@ -893,9 +972,9 @@ mod tests { let mut messages = vec![ ChatMessage::user("Q1"), ChatMessage::user("[Context Summary]\n\nsummary of previous turn"), - ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared - ChatMessage::assistant("done"), // declares tc2 - ChatMessage::tool("tc2", "bash", "legitimate result"), // legit + ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared + ChatMessage::assistant("done"), // declares tc2 + ChatMessage::tool("tc2", "bash", "legitimate result"), // legit ]; // Set tool_call_id on tool messages and tool_calls on assistant messages[2].tool_call_id = Some("tc1".into()); @@ -910,8 +989,16 @@ mod tests { // orphan should be removed; legitimate should stay assert_eq!(messages.len(), 4); - assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into()))); - assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into()))); + assert!( + messages + .iter() + .all(|m| m.tool_call_id != Some("tc1".into())) + ); + assert!( + messages + .iter() + .any(|m| m.tool_call_id == Some("tc2".into())) + ); } #[test] diff --git a/src/agent/media_handler.rs b/src/agent/media_handler.rs index 6a93f8d..5b1b72c 100644 --- a/src/agent/media_handler.rs +++ b/src/agent/media_handler.rs @@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler { } fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { - use base64::{engine::general_purpose::STANDARD, Engine as _}; + use base64::{Engine as _, engine::general_purpose::STANDARD}; let mut file = std::fs::File::open(path)?; let mut buffer = Vec::new(); diff --git a/src/agent/mod.rs b/src/agent/mod.rs index c2b681a..64591be 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -4,10 +4,13 @@ pub mod media_handler; pub mod sub_agent; pub mod system_prompt; -pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; +pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult}; pub use context_compressor::{ContextCompressor, estimate_tokens}; -pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus}; -pub use system_prompt::{ - build_system_prompt, build_sub_agent_system_prompt, PromptContext, PromptSection, - SystemPromptBuilder, +pub use sub_agent::{ + DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, + TaskNotification, TaskStatus, +}; +pub use system_prompt::{ + PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt, + build_system_prompt, }; diff --git a/src/agent/sub_agent.rs b/src/agent/sub_agent.rs index 51d5a91..bc7c0e3 100644 --- a/src/agent/sub_agent.rs +++ b/src/agent/sub_agent.rs @@ -6,12 +6,12 @@ use dashmap::DashMap; use tokio_util::sync::CancellationToken; use uuid::Uuid; -use crate::agent::system_prompt::build_sub_agent_system_prompt; -use crate::agent::AgentLoop; use crate::agent::AgentError; +use crate::agent::AgentLoop; +use crate::agent::system_prompt::build_sub_agent_system_prompt; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; -use crate::providers::{create_provider, LLMProvider}; +use crate::providers::{LLMProvider, create_provider}; use crate::skills::SkillsLoader; use crate::tools::ToolRegistry; @@ -21,7 +21,8 @@ tokio::task_local! { /// Read the delegate context from the current task. Returns an error if not set. pub fn get_delegate_context() -> Result { - DELEGATE_CONTEXT.try_with(|ctx| ctx.clone()) + DELEGATE_CONTEXT + .try_with(|ctx| ctx.clone()) .map_err(|_| "DELEGATE_CONTEXT not set".to_string()) } @@ -207,7 +208,10 @@ impl SubAgentManager { let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS); let timeout_human = format_duration(timeout_secs); let http_get_only = config.allowed_tools.is_none() - || config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request")); + || config + .allowed_tools + .as_ref() + .is_some_and(|v| v.iter().any(|t| t == "http_request")); let skills_prompt = self.get_skills_prompt(&tools); let system_prompt = build_sub_agent_system_prompt( &config.prompt, @@ -219,7 +223,8 @@ impl SubAgentManager { http_get_only, ); - let agent = self.build_sub_agent(&config, tools) + let agent = self + .build_sub_agent(&config, tools) .map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?; let history = vec![ @@ -241,10 +246,14 @@ impl SubAgentManager { Ok(Ok(agent_result)) => { let (content, truncated) = truncate_sub_agent_result(&agent_result.final_response.content); - let tool_calls_count = agent_result.emitted_messages.iter() + let tool_calls_count = agent_result + .emitted_messages + .iter() .filter(|m| m.tool_calls.is_some()) .count(); - let iterations = agent_result.emitted_messages.iter() + let iterations = agent_result + .emitted_messages + .iter() .filter(|m| m.role == "assistant" && m.tool_calls.is_some()) .count(); Ok(SubAgentResult { @@ -343,7 +352,10 @@ impl SubAgentManager { let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS); let timeout_human = format_duration(timeout_secs); let http_get_only = config.allowed_tools.is_none() - || config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request")); + || config + .allowed_tools + .as_ref() + .is_some_and(|v| v.iter().any(|t| t == "http_request")); let skills_prompt = self.get_skills_prompt(&tools); let system_prompt = build_sub_agent_system_prompt( &config.prompt, @@ -372,8 +384,12 @@ impl SubAgentManager { if let Some(ref s) = storage { let _ = s .update_background_task_status( - &tid, "running", None, None, - Some(started_at), None, + &tid, + "running", + None, + None, + Some(started_at), + None, ) .await; } @@ -384,8 +400,7 @@ impl SubAgentManager { p.set_storage(s.clone()); } } - let provider_result: Option> = - provider.map(|p| Arc::from(p)); + let provider_result: Option> = provider.map(|p| Arc::from(p)); let result = match provider_result { Some(provider) => { @@ -474,9 +489,12 @@ impl SubAgentManager { if let Some(ref s) = storage { let _ = s .update_background_task_status( - &tid, &status_str, - Some(&result.content), error_val.as_deref(), - Some(started_at), Some(finished_at), + &tid, + &status_str, + Some(&result.content), + error_val.as_deref(), + Some(started_at), + Some(finished_at), ) .await; } @@ -514,15 +532,13 @@ impl SubAgentManager { Ok(true) } else if let Some(ref s) = self.storage { match s.get_background_task(task_id).await { - Ok(task) => { - match task.status.as_str() { - "pending" | "running" => { - tracing::warn!(task_id, "task in DB but not in active_tasks"); - Ok(false) - } - _ => Ok(false), + Ok(task) => match task.status.as_str() { + "pending" | "running" => { + tracing::warn!(task_id, "task in DB but not in active_tasks"); + Ok(false) } - } + _ => Ok(false), + }, Err(_) => Ok(false), } } else { @@ -530,10 +546,7 @@ impl SubAgentManager { } } - pub async fn check_task( - &self, - task_id: &str, - ) -> Option { + pub async fn check_task(&self, task_id: &str) -> Option { if let Some(ref s) = self.storage { s.get_background_task(task_id).await.ok() } else { @@ -541,12 +554,11 @@ impl SubAgentManager { } } - pub async fn list_tasks( - &self, - session_id: &str, - ) -> Vec { + pub async fn list_tasks(&self, session_id: &str) -> Vec { if let Some(ref s) = self.storage { - s.list_background_tasks(session_id).await.unwrap_or_default() + s.list_background_tasks(session_id) + .await + .unwrap_or_default() } else { vec![] } diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index b2855c4..373d332 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -196,10 +196,10 @@ impl PromptSection for UserProfileSection { if let Some(user_config_dir) = get_user_config_dir() && let Some(content) = load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS) - { - output.push_str(&content); - return output; - } + { + output.push_str(&content); + return output; + } // No USER.md found, return empty String::new() @@ -220,10 +220,10 @@ impl PromptSection for AgentProfileSection { if let Some(user_config_dir) = get_user_config_dir() && let Some(content) = load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS) - { - output.push_str(&content); - return output; - } + { + output.push_str(&content); + return output; + } String::new() } @@ -465,7 +465,9 @@ impl PromptSection for SubAgentToolsSection { let mut s = String::from("## 可用工具\n\n"); s.push_str(&ctx.tools.describe_for_prompt()); if self.http_get_only { - s.push_str("\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。"); + s.push_str( + "\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。", + ); } s } @@ -560,13 +562,8 @@ pub fn build_sub_agent_system_prompt( memory_context: None, has_compressed_history: false, }; - SystemPromptBuilder::with_sub_agent_defaults( - task, - timeout_human, - skills_prompt, - http_get_only, - ) - .build(&ctx) + SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only) + .build(&ctx) } #[cfg(test)] diff --git a/src/bus/dispatcher.rs b/src/bus/dispatcher.rs index 27543b1..b10099f 100644 --- a/src/bus/dispatcher.rs +++ b/src/bus/dispatcher.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use crate::bus::{MessageBus, OutboundMessage}; -use crate::channels::base::{Channel, ChannelError}; use crate::channels::ChannelManager; +use crate::channels::base::{Channel, ChannelError}; /// OutboundDispatcher consumes outbound messages from the MessageBus /// and dispatches them to the appropriate Channel diff --git a/src/bus/message.rs b/src/bus/message.rs index 2e937f6..93315b7 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -1,5 +1,5 @@ -use std::collections::HashMap; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use crate::providers::ToolCall; @@ -23,7 +23,9 @@ pub struct ImageUrlBlock { impl ContentBlock { pub fn text(content: impl Into) -> Self { - Self::Text { text: content.into() } + Self::Text { + text: content.into(), + } } pub fn image_url(url: impl Into) -> Self { @@ -49,10 +51,10 @@ pub struct MediaRef { #[derive(Debug, Clone)] pub struct MediaItem { - pub path: String, // Local file path - pub media_type: String, // "image", "audio", "file", "video" + pub path: String, // Local file path + pub media_type: String, // "image", "audio", "file", "video" pub mime_type: Option, - pub original_key: Option, // Feishu file_key for download + pub original_key: Option, // Feishu file_key for download } impl MediaItem { @@ -161,7 +163,10 @@ impl ChatMessage { } } - pub fn assistant_with_tool_calls(content: impl Into, tool_calls: Vec) -> Self { + pub fn assistant_with_tool_calls( + content: impl Into, + tool_calls: Vec, + ) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), role: "assistant".to_string(), @@ -206,7 +211,11 @@ impl ChatMessage { } } - pub fn tool(tool_call_id: impl Into, tool_name: impl Into, content: impl Into) -> Self { + pub fn tool( + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + ) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), role: "tool".to_string(), diff --git a/src/bus/mod.rs b/src/bus/mod.rs index 6c46feb..2f3011c 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -2,10 +2,13 @@ pub mod dispatcher; pub mod message; pub use dispatcher::OutboundDispatcher; -pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind}; +pub use message::{ + ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, + OutboundMessage, SourceKind, +}; use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{Mutex, mpsc}; // ============================================================================ // MessageBus - Async message queue for Channel <-> Agent communication @@ -49,7 +52,8 @@ impl MessageBus { /// Consume an inbound message (Agent -> Bus) pub async fn consume_inbound(&self) -> InboundMessage { - let msg = self.inbound_rx + let msg = self + .inbound_rx .lock() .await .recv() diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index fa9b555..3e24b63 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -8,9 +8,9 @@ use futures_util::{SinkExt, StreamExt}; use prost::{Message as ProstMessage, bytes::Bytes}; use regex::Regex; use serde::Deserialize; -use tokio::sync::{broadcast, RwLock}; +use tokio::sync::{RwLock, broadcast}; -use crate::bus::{MessageBus, MediaItem, OutboundMessage}; +use crate::bus::{MediaItem, MessageBus, OutboundMessage}; use crate::channels::base::{Channel, ChannelError}; use crate::config::FeishuChannelConfig; @@ -192,7 +192,10 @@ impl FeishuChannel { } /// Get WebSocket endpoint URL from Feishu API - async fn get_ws_endpoint(&self, client: &reqwest::Client) -> Result<(String, WsClientConfig), ChannelError> { + async fn get_ws_endpoint( + &self, + client: &reqwest::Client, + ) -> Result<(String, WsClientConfig), ChannelError> { let resp = client .post(format!("{}/callback/ws/endpoint", FEISHU_WS_BASE)) .header("locale", "zh") @@ -204,10 +207,9 @@ impl FeishuChannel { .await .map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?; - let endpoint_resp: WsEndpointResp = resp - .json() - .await - .map_err(|e| ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)))?; + let endpoint_resp: WsEndpointResp = resp.json().await.map_err(|e| { + ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)) + })?; if endpoint_resp.code != 0 { return Err(ChannelError::ConnectionError(format!( @@ -217,7 +219,8 @@ impl FeishuChannel { ))); } - let ep = endpoint_resp.data + let ep = endpoint_resp + .data .ok_or_else(|| ChannelError::ConnectionError("Empty endpoint data".to_string()))?; let client_config = ep.client_config.unwrap_or_default(); @@ -230,9 +233,10 @@ impl FeishuChannel { { let cached = self.tenant_token.read().await; if let Some(ref token) = *cached - && Instant::now() < token.refresh_after { - return Ok(token.value.clone()); - } + && Instant::now() < token.refresh_after + { + return Ok(token.value.clone()); + } } // 2. Fetch new token @@ -253,8 +257,12 @@ impl FeishuChannel { /// Fetch a new tenant access token from Feishu. async fn fetch_new_token(&self) -> Result<(String, Duration), ChannelError> { - let resp = self.http_client - .post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE)) + let resp = self + .http_client + .post(format!( + "{}/auth/v3/tenant_access_token/internal", + FEISHU_API_BASE + )) .header("Content-Type", "application/json") .json(&serde_json::json!({ "app_id": self.config.app_id, @@ -280,10 +288,12 @@ impl FeishuChannel { return Err(ChannelError::Other("Auth failed".to_string())); } - let token = token_resp.tenant_access_token + let token = token_resp + .tenant_access_token .ok_or_else(|| ChannelError::Other("No token in response".to_string()))?; - let ttl = token_resp.expire + let ttl = token_resp + .expire .and_then(|v| u64::try_from(v).ok()) .map(Duration::from_secs) .unwrap_or(DEFAULT_TOKEN_TTL); @@ -314,12 +324,19 @@ impl FeishuChannel { message_id: &str, ) -> Result<(String, Option), ChannelError> { let media_dir = Path::new(&self.config.media_dir); - tokio::fs::create_dir_all(media_dir).await + tokio::fs::create_dir_all(media_dir) + .await .map_err(|e| ChannelError::Other(format!("Failed to create media dir: {}", e)))?; match msg_type { - "image" => self.download_image(content_json, message_id, media_dir).await, - "audio" | "file" | "media" => self.download_file(content_json, message_id, media_dir, msg_type).await, + "image" => { + self.download_image(content_json, message_id, media_dir) + .await + } + "audio" | "file" | "media" => { + self.download_file(content_json, message_id, media_dir, msg_type) + .await + } _ => Ok((format!("[unsupported media type: {}]", msg_type), None)), } } @@ -331,24 +348,31 @@ impl FeishuChannel { message_id: &str, media_dir: &Path, ) -> Result<(String, Option), ChannelError> { - let image_key = content_json.get("image_key") + let image_key = content_json + .get("image_key") .and_then(|v| v.as_str()) .ok_or_else(|| ChannelError::Other("No image_key in message".to_string()))?; let token = self.get_tenant_access_token().await?; // Use message resource API for downloading message images - let url = format!("{}/im/v1/messages/{}/resources/{}?type=image", FEISHU_API_BASE, message_id, image_key); + let url = format!( + "{}/im/v1/messages/{}/resources/{}?type=image", + FEISHU_API_BASE, message_id, image_key + ); #[cfg(debug_assertions)] tracing::debug!(url = %url, image_key = %image_key, message_id = %message_id, "Downloading image from Feishu via message resource API"); - let resp = self.http_client + let resp = self + .http_client .get(&url) .header("Authorization", format!("Bearer {}", token)) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)) + })?; let status = resp.status(); #[cfg(debug_assertions)] @@ -356,7 +380,10 @@ impl FeishuChannel { if !status.is_success() { let error_text = resp.text().await.unwrap_or_default(); - return Err(ChannelError::Other(format!("Image download failed {}: {}", status, error_text))); + return Err(ChannelError::Other(format!( + "Image download failed {}: {}", + status, error_text + ))); } let content_type = resp @@ -368,23 +395,28 @@ impl FeishuChannel { let ext = resolve_image_ext(&content_type); - let data = resp.bytes().await + let data = resp + .bytes() + .await .map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))? .to_vec(); #[cfg(debug_assertions)] tracing::debug!(data_len = %data.len(), content_type = %content_type, "Downloaded image data"); - let filename = format!("{}_{}.{}", message_id, &image_key[..8.min(image_key.len())], ext); + let filename = format!( + "{}_{}.{}", + message_id, + &image_key[..8.min(image_key.len())], + ext + ); let file_path = resolve_unique_path(media_dir, &filename).await; - tokio::fs::write(&file_path, &data).await + tokio::fs::write(&file_path, &data) + .await .map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?; - let media_item = MediaItem::new( - file_path.to_string_lossy().to_string(), - "image", - ); + let media_item = MediaItem::new(file_path.to_string_lossy().to_string(), "image"); tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image"); @@ -399,32 +431,44 @@ impl FeishuChannel { media_dir: &Path, file_type: &str, ) -> Result<(String, Option), ChannelError> { - let file_key = content_json.get("file_key") + let file_key = content_json + .get("file_key") .and_then(|v| v.as_str()) .ok_or_else(|| ChannelError::Other("No file_key in message".to_string()))?; let token = self.get_tenant_access_token().await?; // Use message resource API for downloading message files - let url = format!("{}/im/v1/messages/{}/resources/{}?type=file", FEISHU_API_BASE, message_id, file_key); + let url = format!( + "{}/im/v1/messages/{}/resources/{}?type=file", + FEISHU_API_BASE, message_id, file_key + ); #[cfg(debug_assertions)] tracing::debug!(url = %url, file_key = %file_key, message_id = %message_id, "Downloading file from Feishu via message resource API"); - let resp = self.http_client + let resp = self + .http_client .get(&url) .header("Authorization", format!("Bearer {}", token)) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)) + })?; let status = resp.status(); if !status.is_success() { let error_text = resp.text().await.unwrap_or_default(); - return Err(ChannelError::Other(format!("File download failed {}: {}", status, error_text))); + return Err(ChannelError::Other(format!( + "File download failed {}: {}", + status, error_text + ))); } - let data = resp.bytes().await + let data = resp + .bytes() + .await .map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))? .to_vec(); @@ -437,18 +481,21 @@ impl FeishuChannel { if ext.is_empty() { format!("{}_{}", message_id, &file_key[..8.min(file_key.len())]) } else { - format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], ext) + format!( + "{}_{}.{}", + message_id, + &file_key[..8.min(file_key.len())], + ext + ) } }); let file_path = resolve_unique_path(media_dir, &filename).await; - tokio::fs::write(&file_path, &data).await + tokio::fs::write(&file_path, &data) + .await .map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?; - let media_item = MediaItem::new( - file_path.to_string_lossy().to_string(), - file_type, - ); + let media_item = MediaItem::new(file_path.to_string_lossy().to_string(), file_type); tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file"); @@ -468,7 +515,8 @@ impl FeishuChannel { .and_then(|n| n.to_str()) .unwrap_or("image.jpg"); - let file_data = tokio::fs::read(file_path).await + let file_data = tokio::fs::read(file_path) + .await .map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?; let part = reqwest::multipart::Part::bytes(file_data) @@ -480,16 +528,21 @@ impl FeishuChannel { .text("image_type", "message".to_string()) .part("image", part); - let resp = self.http_client + let resp = self + .http_client .post(format!("{}/im/v1/images", FEISHU_API_BASE)) .header("Authorization", format!("Bearer {}", token)) .multipart(form) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)) + })?; let status = resp.status(); - let body_text = resp.text().await + let body_text = resp + .text() + .await .map_err(|e| ChannelError::Other(format!("Failed to read upload response: {}", e)))?; tracing::debug!(status = %status, body = %body_text, "Feishu upload image"); @@ -504,8 +557,12 @@ impl FeishuChannel { image_key: String, } - let result: UploadResp = serde_json::from_str(&body_text) - .map_err(|e| ChannelError::Other(format!("Parse upload response error: {} | body: {}", e, &body_text)))?; + let result: UploadResp = serde_json::from_str(&body_text).map_err(|e| { + ChannelError::Other(format!( + "Parse upload response error: {} | body: {}", + e, &body_text + )) + })?; if result.code != 0 { return Err(ChannelError::Other(format!( @@ -515,7 +572,8 @@ impl FeishuChannel { ))); } - result.data + result + .data .map(|d| d.image_key) .ok_or_else(|| ChannelError::Other("No image_key in response".to_string())) } @@ -545,7 +603,8 @@ impl FeishuChannel { _ => "stream", }; - let file_data = tokio::fs::read(file_path).await + let file_data = tokio::fs::read(file_path) + .await .map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?; let part = reqwest::multipart::Part::bytes(file_data) @@ -558,7 +617,8 @@ impl FeishuChannel { .text("file_name", file_name.to_string()) .part("file", part); - let resp = self.http_client + let resp = self + .http_client .post(format!("{}/im/v1/files", FEISHU_API_BASE)) .header("Authorization", format!("Bearer {}", token)) .multipart(form) @@ -567,7 +627,9 @@ impl FeishuChannel { .map_err(|e| ChannelError::ConnectionError(format!("Upload file HTTP error: {}", e)))?; let status = resp.status(); - let body_text = resp.text().await + let body_text = resp + .text() + .await .map_err(|e| ChannelError::Other(format!("Failed to read upload response: {}", e)))?; tracing::debug!(status = %status, body = %body_text, "Feishu upload file"); @@ -582,8 +644,12 @@ impl FeishuChannel { file_key: String, } - let result: UploadResp = serde_json::from_str(&body_text) - .map_err(|e| ChannelError::Other(format!("Parse upload response error: {} | body: {}", e, &body_text)))?; + let result: UploadResp = serde_json::from_str(&body_text).map_err(|e| { + ChannelError::Other(format!( + "Parse upload response error: {} | body: {}", + e, &body_text + )) + })?; if result.code != 0 { return Err(ChannelError::Other(format!( @@ -593,7 +659,8 @@ impl FeishuChannel { ))); } - result.data + result + .data .map(|d| d.file_key) .ok_or_else(|| ChannelError::Other("No file_key in response".to_string())) } @@ -604,15 +671,21 @@ impl FeishuChannel { let emoji = self.config.reaction_emoji.as_str(); let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .post(format!("{}/im/v1/messages/{}/reactions", FEISHU_API_BASE, message_id)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages/{}/reactions", + FEISHU_API_BASE, message_id + )) .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ "reaction_type": { "emoji_type": emoji } })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Add reaction HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Add reaction HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct ReactionResp { @@ -625,7 +698,9 @@ impl FeishuChannel { reaction_id: Option, } - let result: ReactionResp = resp.json().await + let result: ReactionResp = resp + .json() + .await .map_err(|e| ChannelError::Other(format!("Parse reaction response error: {}", e)))?; if result.code != 0 { @@ -644,7 +719,10 @@ impl FeishuChannel { /// Remove reaction using feishu metadata propagated through OutboundMessage. /// Reads feishu.message_id and feishu.reaction_id from metadata. - async fn remove_reaction_from_metadata(&self, metadata: &std::collections::HashMap) { + async fn remove_reaction_from_metadata( + &self, + metadata: &std::collections::HashMap, + ) { let (message_id, reaction_id) = match ( metadata.get("feishu.message_id"), metadata.get("feishu.reaction_id"), @@ -658,15 +736,25 @@ impl FeishuChannel { } /// Remove a reaction emoji from a message. - async fn remove_reaction(&self, message_id: &str, reaction_id: &str) -> Result<(), ChannelError> { + async fn remove_reaction( + &self, + message_id: &str, + reaction_id: &str, + ) -> Result<(), ChannelError> { let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .delete(format!("{}/im/v1/messages/{}/reactions/{}", FEISHU_API_BASE, message_id, reaction_id)) + let resp = self + .http_client + .delete(format!( + "{}/im/v1/messages/{}/reactions/{}", + FEISHU_API_BASE, message_id, reaction_id + )) .header("Authorization", format!("Bearer {}", token)) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Remove reaction HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Remove reaction HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct ReactionResp { @@ -674,8 +762,9 @@ impl FeishuChannel { msg: Option, } - let result: ReactionResp = resp.json().await - .map_err(|e| ChannelError::Other(format!("Parse remove reaction response error: {}", e)))?; + let result: ReactionResp = resp.json().await.map_err(|e| { + ChannelError::Other(format!("Parse remove reaction response error: {}", e)) + })?; if result.code != 0 { tracing::debug!( @@ -702,7 +791,8 @@ impl FeishuChannel { } }; - let resp = self.http_client + let resp = self + .http_client .get(format!("{}/im/v1/messages/{}", FEISHU_API_BASE, message_id)) .header("Authorization", format!("Bearer {}", token)) .send() @@ -754,16 +844,12 @@ impl FeishuChannel { let msg_type = msg_obj.msg_type.as_str(); let text = match msg_type { - "text" => { - serde_json::from_str::(raw_content) - .ok()? - .get("text")? - .as_str()? - .to_string() - } - "post" => { - parse_post_content(raw_content) - } + "text" => serde_json::from_str::(raw_content) + .ok()? + .get("text")? + .as_str()? + .to_string(), + "post" => parse_post_content(raw_content), _ => String::new(), }; @@ -782,11 +868,21 @@ impl FeishuChannel { /// Send a message to Feishu chat with specified message type and content. /// Content is passed as-is (already a JSON string for file/media, or plain text for fallback). - async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, msg_type: &str, content: &str) -> Result<(), ChannelError> { + async fn send_message_to_feishu( + &self, + receive_id: &str, + receive_id_type: &str, + msg_type: &str, + content: &str, + ) -> Result<(), ChannelError> { let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages?receive_id_type={}", + FEISHU_API_BASE, receive_id_type + )) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ @@ -796,7 +892,9 @@ impl FeishuChannel { })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct SendResp { @@ -810,7 +908,10 @@ impl FeishuChannel { .map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?; if send_resp.code != 0 { - return Err(ChannelError::Other(format!("Send message failed: code={} msg={}", send_resp.code, send_resp.msg))); + return Err(ChannelError::Other(format!( + "Send message failed: code={} msg={}", + send_resp.code, send_resp.msg + ))); } Ok(()) @@ -841,7 +942,9 @@ impl FeishuChannel { return Ok(None); } - let payload = frame.payload.as_deref() + let payload = frame + .payload + .as_deref() .ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?; #[cfg(debug_assertions)] @@ -877,7 +980,10 @@ impl FeishuChannel { #[cfg(debug_assertions)] tracing::debug!(message_id = %message_id, "Received Feishu message"); - let open_id = payload_data.sender.sender_id.open_id + let open_id = payload_data + .sender + .sender_id + .open_id .ok_or_else(|| ChannelError::Other("No open_id".to_string()))?; let msg = payload_data.message; @@ -889,13 +995,16 @@ impl FeishuChannel { #[cfg(debug_assertions)] tracing::debug!(msg_type = %msg_type, chat_id = %chat_id, open_id = %open_id, "Parsing message content"); - let (mut content, media) = self.parse_and_download_message(msg_type, &raw_content, &message_id).await?; + let (mut content, media) = self + .parse_and_download_message(msg_type, &raw_content, &message_id) + .await?; // Fetch and prepend quoted message content if this is a reply if let Some(ref pid) = parent_id - && let Some(reply_ctx) = self.get_message_content(pid).await { - content = format!("{}\n{}", reply_ctx, content); - } + && let Some(reply_ctx) = self.get_message_content(pid).await + { + content = format!("{}\n{}", reply_ctx, content); + } #[cfg(debug_assertions)] if let Some(ref m) = media { @@ -922,18 +1031,23 @@ impl FeishuChannel { let (text, media) = match msg_type { "text" => { let text = if let Ok(parsed) = serde_json::from_str::(content) { - parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string() + parsed + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or(content) + .to_string() } else { content.to_string() }; (text, None) } - "post" => { - (parse_post_content(content), None) - } + "post" => (parse_post_content(content), None), "image" | "audio" | "file" | "media" => { if let Ok(content_json) = serde_json::from_str::(content) { - match self.download_media(msg_type, &content_json, message_id).await { + match self + .download_media(msg_type, &content_json, message_id) + .await + { Ok((text, media)) => (text, media), Err(_) => (format!("[{}: content unavailable]", msg_type), None), } @@ -944,7 +1058,10 @@ impl FeishuChannel { "share_chat" => { // Shared chat/cannel messages if let Ok(parsed) = serde_json::from_str::(content) { - let chat_id = parsed.get("chat_id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let chat_id = parsed + .get("chat_id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); (format!("[shared chat: {}]", chat_id), None) } else { ("[shared chat]".to_string(), None) @@ -953,7 +1070,10 @@ impl FeishuChannel { "share_user" => { // Shared user messages if let Ok(parsed) = serde_json::from_str::(content) { - let user_id = parsed.get("user_id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let user_id = parsed + .get("user_id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); (format!("[shared user: {}]", user_id), None) } else { ("[shared user]".to_string(), None) @@ -976,20 +1096,19 @@ impl FeishuChannel { Err(_) => (content.to_string(), None), } } - "merge_forward" => { - ("[merged forward messages]".to_string(), None) - } + "merge_forward" => ("[merged forward messages]".to_string(), None), "share_calendar_event" => { if let Ok(parsed) = serde_json::from_str::(content) { - let event_key = parsed.get("event_key").and_then(|v| v.as_str()).unwrap_or("unknown"); + let event_key = parsed + .get("event_key") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); (format!("[shared calendar event: {}]", event_key), None) } else { ("[shared calendar event]".to_string(), None) } } - "system" => { - ("[system message]".to_string(), None) - } + "system" => ("[system message]".to_string(), None), _ => (content.to_string(), None), }; @@ -999,20 +1118,35 @@ impl FeishuChannel { } /// Send acknowledgment for a message - async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> { + async fn send_ack( + frame: &PbFrame, + write: &mut futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + tokio_tungstenite::tungstenite::Message, + >, + ) -> Result<(), ChannelError> { let mut ack = frame.clone(); ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into(), }); - write.send(tokio_tungstenite::tungstenite::Message::Binary(ack.encode_to_vec().into())) + write + .send(tokio_tungstenite::tungstenite::Message::Binary( + ack.encode_to_vec().into(), + )) .await .map_err(|e| ChannelError::Other(format!("Failed to send ack: {}", e)))?; Ok(()) } - async fn run_ws_loop(&self, bus: Arc, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> { + async fn run_ws_loop( + &self, + bus: Arc, + mut shutdown_rx: broadcast::Receiver<()>, + ) -> Result<(), ChannelError> { let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?; let service_id = Self::extract_service_id(&wss_url); @@ -1020,7 +1154,9 @@ impl FeishuChannel { let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url) .await - .map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)) + })?; *self.connected.write().await = true; tracing::info!("Feishu WebSocket connected"); @@ -1039,12 +1175,18 @@ impl FeishuChannel { }], payload: None, }; - write.send(tokio_tungstenite::tungstenite::Message::Binary(ping_frame.encode_to_vec().into())) + write + .send(tokio_tungstenite::tungstenite::Message::Binary( + ping_frame.encode_to_vec().into(), + )) .await - .map_err(|e| ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)) + })?; let ping_interval = client_config.ping_interval.unwrap_or(120).max(10); - let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval)); + let mut ping_interval_tok = + tokio::time::interval(tokio::time::Duration::from_secs(ping_interval)); let mut timeout_check = tokio::time::interval(tokio::time::Duration::from_secs(10)); let mut seq: u64 = 1; let mut last_recv = Instant::now(); @@ -1201,7 +1343,8 @@ fn parse_post_content(content: &str) -> String { } } "a" => { - let link_text = el.get("text") + let link_text = el + .get("text") .and_then(|t| t.as_str()) .filter(|s| !s.is_empty()) .or_else(|| el.get("href").and_then(|h| h.as_str())) @@ -1209,7 +1352,8 @@ fn parse_post_content(content: &str) -> String { out.push(link_text.to_string()); } "at" => { - let name = el.get("user_name") + let name = el + .get("user_name") .and_then(|n| n.as_str()) .or_else(|| el.get("user_id").and_then(|i| i.as_str())) .unwrap_or("user"); @@ -1230,7 +1374,10 @@ fn parse_post_content(content: &str) -> String { /// Parse a single block {title, content: [[...]]} and append text to out. fn parse_block(block: &serde_json::Value, out: &mut Vec) { - let title = block.get("title").and_then(|t| t.as_str()).filter(|s| !s.is_empty()); + let title = block + .get("title") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()); if let Some(t) = title { out.push(t.to_string()); out.push("\n\n".to_string()); @@ -1287,14 +1434,15 @@ fn parse_post_content(content: &str) -> String { if let Some(root_obj) = root.as_object() { for (_key, val) in root_obj { if let Some(obj) = val.as_object() - && obj.get("content").and_then(|c| c.as_array()).is_some() { - parse_block(val, &mut texts); - let result = texts.join(""); - if !result.trim().is_empty() { - return result.trim().to_string(); - } - texts.clear(); + && obj.get("content").and_then(|c| c.as_array()).is_some() + { + parse_block(val, &mut texts); + let result = texts.join(""); + if !result.trim().is_empty() { + return result.trim().to_string(); } + texts.clear(); + } } } @@ -1319,18 +1467,20 @@ fn extract_interactive_content(content: &str) -> Result<(String, Option) } "div" => { if let Some(text_obj) = element.get("text").and_then(|t| t.as_object()) { - let content = text_obj.get("content") + let content = text_obj + .get("content") .and_then(|c| c.as_str()) .unwrap_or(""); texts.push(content.to_string()); @@ -1363,12 +1514,8 @@ fn extract_element_content(element: &serde_json::Value, texts: &mut Vec) texts.push("\n".to_string()); } "a" => { - let href = element.get("href") - .and_then(|h| h.as_str()) - .unwrap_or(""); - let text = element.get("text") - .and_then(|t| t.as_str()) - .unwrap_or(""); + let href = element.get("href").and_then(|h| h.as_str()).unwrap_or(""); + let text = element.get("text").and_then(|t| t.as_str()).unwrap_or(""); if !text.is_empty() { texts.push(text.to_string()); } else if !href.is_empty() { @@ -1377,8 +1524,13 @@ fn extract_element_content(element: &serde_json::Value, texts: &mut Vec) } "img" => { let alt = element.get("alt"); - let alt_text = alt.and_then(|a| a.as_str()) - .or_else(|| alt.and_then(|a| a.as_object()).and_then(|o| o.get("content")).and_then(|c| c.as_str())) + let alt_text = alt + .and_then(|a| a.as_str()) + .or_else(|| { + alt.and_then(|a| a.as_object()) + .and_then(|o| o.get("content")) + .and_then(|c| c.as_str()) + }) .unwrap_or("[image]"); texts.push(format!("{}\n", alt_text)); } @@ -1422,7 +1574,8 @@ fn parse_list_content(content: &str) -> Result<(String, Option), Chan Err(_) => return Ok((content.to_string(), None)), }; - let items = parsed.get("items") + let items = parsed + .get("items") .and_then(|i| i.as_array()) .or_else(|| parsed.get("content").and_then(|c| c.as_array())); @@ -1481,10 +1634,13 @@ fn collect_list_items(items: &[serde_json::Value], lines: &mut Vec, dept None } }) - }) - && let Some(children) = children_arr.as_object().and_then(|o| o.get("children")).and_then(|c| c.as_array()) { - collect_list_items(children, lines, depth + 1); - } + }) && let Some(children) = children_arr + .as_object() + .and_then(|o| o.get("children")) + .and_then(|c| c.as_array()) + { + collect_list_items(children, lines, depth + 1); + } } } @@ -1497,7 +1653,8 @@ fn extract_inline_text(el: &serde_json::Value, out: &mut String) { } } "a" => { - let text = el.get("text") + let text = el + .get("text") .and_then(|t| t.as_str()) .filter(|s| !s.is_empty()) .or_else(|| el.get("href").and_then(|h| h.as_str())) @@ -1505,7 +1662,8 @@ fn extract_inline_text(el: &serde_json::Value, out: &mut String) { out.push_str(text); } "at" => { - let name = el.get("user_name") + let name = el + .get("user_name") .and_then(|n| n.as_str()) .or_else(|| el.get("user_id").and_then(|i| i.as_str())) .unwrap_or("user"); @@ -1525,7 +1683,9 @@ fn strip_at_placeholders(text: &str) -> String { let rest: String = chars.clone().collect(); if let Some(after) = rest.strip_prefix("_user_") { // Skip until we hit a non-alphanumeric character - let placeholder_len = after.find(|c: char| !c.is_alphanumeric()).unwrap_or(after.len()); + let placeholder_len = after + .find(|c: char| !c.is_alphanumeric()) + .unwrap_or(after.len()); // Skip the placeholder for _ in 0..placeholder_len { chars.next(); @@ -1554,11 +1714,11 @@ fn resolve_image_ext(content_type: &str) -> &str { } fn resolve_file_ext(content_json: &serde_json::Value) -> String { - if let Some(name) = content_json - .get("file_name") - .and_then(|v| v.as_str()) - { - if let Some(ext) = std::path::Path::new(name).extension().and_then(|e| e.to_str()) { + if let Some(name) = content_json.get("file_name").and_then(|v| v.as_str()) { + if let Some(ext) = std::path::Path::new(name) + .extension() + .and_then(|e| e.to_str()) + { return ext.to_string(); } } @@ -1595,9 +1755,8 @@ async fn resolve_unique_path(dir: &Path, filename: &str) -> std::path::PathBuf { impl FeishuChannel { fn strip_thinking_tags(content: &str) -> String { use std::sync::LazyLock; - static THINK_RE: LazyLock = LazyLock::new(|| { - Regex::new(r"(?s).*?").unwrap() - }); + static THINK_RE: LazyLock = + LazyLock::new(|| Regex::new(r"(?s).*?").unwrap()); let stripped = THINK_RE.replace_all(content, ""); stripped.trim().to_string() } @@ -1676,8 +1835,12 @@ impl FeishuChannel { ) -> Result<(), ChannelError> { let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages?receive_id_type={}", + FEISHU_API_BASE, receive_id_type + )) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ @@ -1687,7 +1850,9 @@ impl FeishuChannel { })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Send interactive card HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Send interactive card HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct SendResp { @@ -1695,10 +1860,9 @@ impl FeishuChannel { msg: String, } - let send_resp: SendResp = resp - .json() - .await - .map_err(|e| ChannelError::Other(format!("Parse send interactive card response error: {}", e)))?; + let send_resp: SendResp = resp.json().await.map_err(|e| { + ChannelError::Other(format!("Parse send interactive card response error: {}", e)) + })?; if send_resp.code != 0 { return Err(ChannelError::Other(format!( @@ -1732,7 +1896,7 @@ impl Channel for FeishuChannel { async fn start(&self, bus: Arc) -> Result<(), ChannelError> { if self.config.app_id.is_empty() || self.config.app_secret.is_empty() { return Err(ChannelError::ConfigError( - "Feishu app_id or app_secret is not configured".to_string() + "Feishu app_id or app_secret is not configured".to_string(), )); } @@ -1802,8 +1966,16 @@ impl Channel for FeishuChannel { content: Self::strip_thinking_tags(&msg.content), ..msg }; - let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { msg.reply_to.as_ref().unwrap_or(&msg.chat_id) }; - let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" }; + let receive_id = if msg.chat_id.starts_with("oc_") { + &msg.chat_id + } else { + msg.reply_to.as_ref().unwrap_or(&msg.chat_id) + }; + let receive_id_type = if msg.chat_id.starts_with("oc_") { + "chat_id" + } else { + "open_id" + }; // If no media, send as interactive card with raw markdown if msg.media.is_empty() { @@ -1818,10 +1990,15 @@ impl Channel for FeishuChannel { let chunks = Self::split_markdown_chunks(content); for chunk in &chunks { let card = Self::build_card_content(chunk); - if let Err(e) = self.send_interactive_card(receive_id, receive_id_type, &card).await { + if let Err(e) = self + .send_interactive_card(receive_id, receive_id_type, &card) + .await + { tracing::warn!(error = %e, "Failed to send interactive card, falling back to text"); let text_content = serde_json::json!({ "text": chunk }).to_string(); - let result = self.send_message_to_feishu(receive_id, receive_id_type, "text", &text_content).await; + let result = self + .send_message_to_feishu(receive_id, receive_id_type, "text", &text_content) + .await; self.remove_reaction_from_metadata(&msg.metadata).await; return result; } @@ -1856,9 +2033,15 @@ impl Channel for FeishuChannel { _ => "file", }; let file_content = serde_json::json!({"file_key": file_key}).to_string(); - if let Err(e) = self.send_message_to_feishu( - receive_id, receive_id_type, file_msg_type, &file_content, - ).await { + if let Err(e) = self + .send_message_to_feishu( + receive_id, + receive_id_type, + file_msg_type, + &file_content, + ) + .await + { tracing::warn!(error = %e, msg_type = file_msg_type, "Failed to send file message"); } } @@ -1874,7 +2057,10 @@ impl Channel for FeishuChannel { if !msg.content.is_empty() { const MAX_TEXT_LENGTH: usize = 60_000; let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH { - format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..msg.content.ceil_char_boundary(MAX_TEXT_LENGTH)]) + format!( + "{}...\n\n[Content truncated due to length limit]", + &msg.content[..msg.content.ceil_char_boundary(MAX_TEXT_LENGTH)] + ) } else { msg.content.clone() }; @@ -1913,26 +2099,27 @@ impl Channel for FeishuChannel { let content = if msg_type == "image" { // Image-only: content is just {"image_key": "..."} - let image_key = content_parts[0]["image_key"] - .as_str() - .unwrap_or(""); + let image_key = content_parts[0]["image_key"].as_str().unwrap_or(""); serde_json::json!({"image_key": image_key}).to_string() } else { // Post with media: zh_cn wrapped post structure - let post_content: Vec> = content_parts - .into_iter() - .map(|part| vec![part]) - .collect(); + let post_content: Vec> = + content_parts.into_iter().map(|part| vec![part]).collect(); serde_json::json!({ "zh_cn": { "title": "", "content": post_content } - }).to_string() + }) + .to_string() }; - let resp = self.http_client - .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages?receive_id_type={}", + FEISHU_API_BASE, receive_id_type + )) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ @@ -1942,10 +2129,14 @@ impl Channel for FeishuChannel { })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)) + })?; let send_status = resp.status(); - let send_body = resp.text().await + let send_body = resp + .text() + .await .map_err(|e| ChannelError::Other(format!("Failed to read send response: {}", e)))?; tracing::debug!(status = %send_status, body = %send_body, msg_type = %msg_type, "Feishu send message"); @@ -1955,11 +2146,18 @@ impl Channel for FeishuChannel { msg: String, } - let send_resp: SendResp = serde_json::from_str(&send_body) - .map_err(|e| ChannelError::Other(format!("Parse send response error: {} | body: {}", e, &send_body)))?; + let send_resp: SendResp = serde_json::from_str(&send_body).map_err(|e| { + ChannelError::Other(format!( + "Parse send response error: {} | body: {}", + e, &send_body + )) + })?; if send_resp.code != 0 { - return Err(ChannelError::Other(format!("Send multimodal message failed: code={} msg={}", send_resp.code, send_resp.msg))); + return Err(ChannelError::Other(format!( + "Send multimodal message failed: code={} msg={}", + send_resp.code, send_resp.msg + ))); } // Remove pending reaction after successfully sending diff --git a/src/channels/manager.rs b/src/channels/manager.rs index 1e1c200..d089ea9 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -24,7 +24,10 @@ impl ChannelManager { } } - pub fn with_bus(cli_chat_channel: Arc, bus: Arc) -> Self { + pub fn with_bus( + cli_chat_channel: Arc, + bus: Arc, + ) -> Self { Self { channels: Arc::new(RwLock::new(HashMap::new())), cli_chat_channel, @@ -39,7 +42,10 @@ impl ChannelManager { /// Register a channel with the manager pub async fn register_channel(&self, name: &str, channel: Arc) { - self.channels.write().await.insert(name.to_string(), channel); + self.channels + .write() + .await + .insert(name.to_string(), channel); } /// Get CLI chat channel @@ -56,14 +62,19 @@ impl ChannelManager { // Initialize Feishu channel if enabled if let Some(feishu_config) = config.channels.get("feishu") { if feishu_config.enabled { - let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir) - .map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; + let channel = + FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| { + ChannelError::Other(format!("Failed to create Feishu channel: {}", e)) + })?; self.channels .write() .await .insert("feishu".to_string(), Arc::new(channel)); - tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display()); + tracing::info!( + "Feishu channel registered (media_dir: {}/media/feishu)", + workspace_dir.display() + ); } else { tracing::info!("Feishu channel disabled in config"); } @@ -118,7 +129,10 @@ impl ChannelManager { if let Some(channel) = self.get_channel(channel_name).await { channel.send(msg).await } else { - Err(ChannelError::Other(format!("Channel not found: {}", channel_name))) + Err(ChannelError::Other(format!( + "Channel not found: {}", + channel_name + ))) } } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 3503db1..9331bad 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1,11 +1,11 @@ pub mod base; -pub mod feishu; pub mod cli_chat; +pub mod feishu; pub mod manager; pub mod slash_command; pub use base::{Channel, ChannelError}; -pub use manager::ChannelManager; -pub use feishu::FeishuChannel; pub use cli_chat::CliChatChannel; -pub use slash_command::{parse_slash_command, command_matches}; +pub use feishu::FeishuChannel; +pub use manager::ChannelManager; +pub use slash_command::{command_matches, parse_slash_command}; diff --git a/src/channels/slash_command.rs b/src/channels/slash_command.rs index 7d2699a..f508c8a 100644 --- a/src/channels/slash_command.rs +++ b/src/channels/slash_command.rs @@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> { /// 检查内容是否匹配指定命令 pub fn command_matches(content: &str, aliases: &[&str]) -> bool { let trimmed = content.trim(); - aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) + aliases + .iter() + .any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) } #[cfg(test)] @@ -27,7 +29,10 @@ mod tests { fn test_parse_slash_command() { assert_eq!(parse_slash_command("/reset"), Some(("reset", ""))); assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg"))); - assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world"))); + assert_eq!( + parse_slash_command("/new hello world"), + Some(("new", "hello world")) + ); assert_eq!(parse_slash_command("/??"), Some(("??", ""))); assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg"))); assert_eq!(parse_slash_command("/?"), Some(("?", ""))); diff --git a/src/client/mod.rs b/src/client/mod.rs index f3e4f69..30a9f87 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui; use crossterm::{ event::{self, Event}, execute, - terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode}, }; use futures_util::{SinkExt, StreamExt}; -use ratatui::{prelude::CrosstermBackend, Terminal}; +use ratatui::{Terminal, prelude::CrosstermBackend}; use std::io; use tokio_tungstenite::{connect_async, tungstenite::Message}; @@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) { WsOutbound::SessionCreated { session_id, .. } => { app.set_current_session(Some(session_id)); } - WsOutbound::SessionList { sessions, current_session_id } => { + WsOutbound::SessionList { + sessions, + current_session_id, + } => { app.set_sessions(sessions); if let Some(id) = current_session_id { app.set_current_session(Some(id)); diff --git a/src/client/tui/components/chat_history.rs b/src/client/tui/components/chat_history.rs index 70186e3..4efd5aa 100644 --- a/src/client/tui/components/chat_history.rs +++ b/src/client/tui/components/chat_history.rs @@ -1,10 +1,10 @@ use crate::client::tui::app::{App, MessageRole}; use ratatui::{ + Frame, layout::Rect, style::{Color, Modifier, Style}, text::Line, widgets::{Block, Borders, List, ListItem}, - Frame, }; pub fn render(f: &mut Frame, area: Rect, app: &App) { diff --git a/src/client/tui/components/command_menu.rs b/src/client/tui/components/command_menu.rs index 45e031f..555c14b 100644 --- a/src/client/tui/components/command_menu.rs +++ b/src/client/tui/components/command_menu.rs @@ -1,10 +1,10 @@ use crate::client::tui::app::App; use ratatui::{ + Frame, layout::Rect, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, List, ListItem}, - Frame, }; pub fn render(f: &mut Frame, area: Rect, app: &App) { diff --git a/src/client/tui/components/help_popup.rs b/src/client/tui/components/help_popup.rs index 085c1af..111f56c 100644 --- a/src/client/tui/components/help_popup.rs +++ b/src/client/tui/components/help_popup.rs @@ -1,8 +1,8 @@ use ratatui::{ + Frame, layout::Rect, style::{Color, Modifier, Style}, widgets::{Block, Borders, Clear, List, ListItem}, - Frame, }; pub fn render(f: &mut Frame, area: Rect) { diff --git a/src/client/tui/components/input_area.rs b/src/client/tui/components/input_area.rs index 1b3f643..427e062 100644 --- a/src/client/tui/components/input_area.rs +++ b/src/client/tui/components/input_area.rs @@ -1,9 +1,9 @@ use crate::client::tui::app::App; use ratatui::{ + Frame, layout::Rect, style::{Color, Style}, widgets::{Block, Borders, Paragraph}, - Frame, }; pub fn render(f: &mut Frame, area: Rect, app: &App) { diff --git a/src/client/tui/components/session_list.rs b/src/client/tui/components/session_list.rs index 99d197d..c626f2f 100644 --- a/src/client/tui/components/session_list.rs +++ b/src/client/tui/components/session_list.rs @@ -1,9 +1,9 @@ use crate::client::tui::app::App; use ratatui::{ + Frame, layout::Rect, style::{Color, Modifier, Style}, widgets::{Block, Borders, List, ListItem}, - Frame, }; pub fn render(f: &mut Frame, area: Rect, app: &App) { @@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) { .sessions .iter() .map(|session| { - let is_current = app - .current_session_id - .as_ref() == Some(&session.session_id); + let is_current = app.current_session_id.as_ref() == Some(&session.session_id); let archived = session.archived_at.is_some(); let mut content = if is_current { diff --git a/src/client/tui/components/title_bar.rs b/src/client/tui/components/title_bar.rs index 9a92d11..be06a98 100644 --- a/src/client/tui/components/title_bar.rs +++ b/src/client/tui/components/title_bar.rs @@ -1,15 +1,18 @@ use crate::client::tui::app::App; use ratatui::{ + Frame, layout::Rect, style::{Color, Modifier, Style}, widgets::{Block, Borders, Paragraph}, - Frame, }; pub fn render(f: &mut Frame, area: Rect, app: &App) { let (title, style) = if app.pending_quit { let msg = if let Some(session_id) = &app.current_session_id { - format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id) + format!( + "PicoBot | Session: {} | Press Ctrl+C again to quit", + session_id + ) } else { "PicoBot | Press Ctrl+C again to quit".to_string() }; diff --git a/src/client/tui/event.rs b/src/client/tui/event.rs index 9c81110..4ac4ff6 100644 --- a/src/client/tui/event.rs +++ b/src/client/tui/event.rs @@ -1,6 +1,6 @@ use crate::client::tui::app::{App, MessageRole}; -use crate::protocol::serialize_inbound; use crate::protocol::WsInbound; +use crate::protocol::serialize_inbound; use crossterm::event::{KeyCode, KeyEvent}; use futures_util::SinkExt; @@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) { async fn handle_normal_input(app: &mut App, key: KeyEvent) { // Handle Ctrl+C for quit (double press to exit) - let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL); + let is_ctrl_c = key.code == KeyCode::Char('c') + && key + .modifiers + .contains(crossterm::event::KeyModifiers::CONTROL); if is_ctrl_c { if app.handle_ctrl_c_for_quit() { return; @@ -63,9 +66,11 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) { } KeyCode::Char(c) => { app.input_insert_char(c); - + // Show command menu when input starts with / - if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) { + if !app.show_command_menu + && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) + { app.show_command_menu = true; app.selected_command_idx = 0; } else if app.show_command_menu && !app.input.starts_with('/') { @@ -74,7 +79,7 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) { } KeyCode::Backspace => { app.input_delete_char(); - + // Hide menu if input no longer starts with / if app.show_command_menu && !app.input.starts_with('/') { app.show_command_menu = false; @@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) { sender_id: None, }; if let Ok(text) = serialize_inbound(&inbound) { - let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await; + let _ = sender + .send(tokio_tungstenite::tungstenite::Message::Text(text.into())) + .await; } } } diff --git a/src/client/tui/ui.rs b/src/client/tui/ui.rs index 525b345..02b7fd7 100644 --- a/src/client/tui/ui.rs +++ b/src/client/tui/ui.rs @@ -1,8 +1,8 @@ use crate::client::tui::app::App; use crate::client::tui::components::*; use ratatui::{ - layout::{Constraint, Direction, Layout, Rect}, Frame, + layout::{Constraint, Direction, Layout, Rect}, }; pub fn render_ui(f: &mut Frame, app: &App) { diff --git a/src/config/mod.rs b/src/config/mod.rs index 5c2f8e9..2d613b7 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -273,12 +273,16 @@ impl Default for MemoryConfig { impl MemoryConfig { /// Resolve consolidation provider name, falling back to the main agent's provider. pub fn resolve_consolidation_provider(&self, default: &str) -> String { - self.consolidation_provider.clone().unwrap_or_else(|| default.to_string()) + self.consolidation_provider + .clone() + .unwrap_or_else(|| default.to_string()) } /// Resolve consolidation model name, falling back to the main agent's model. pub fn resolve_consolidation_model(&self, default: &str) -> String { - self.consolidation_model.clone().unwrap_or_else(|| default.to_string()) + self.consolidation_model + .clone() + .unwrap_or_else(|| default.to_string()) } } @@ -366,10 +370,18 @@ impl Default for BrowserConfig { } } -fn default_recall_limit() -> usize { 5 } -fn default_idle_consolidation_minutes() -> u64 { 10 } -fn default_timeline_retention_days() -> u64 { 90 } -fn default_max_failures_before_degrade() -> usize { 3 } +fn default_recall_limit() -> usize { + 5 +} +fn default_idle_consolidation_minutes() -> u64 { + 10 +} +fn default_timeline_retention_days() -> u64 { + 90 +} +fn default_max_failures_before_degrade() -> usize { + 3 +} #[derive(Debug, Clone)] pub struct LLMProviderConfig { @@ -469,7 +481,11 @@ pub enum ConfigError { impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path), + ConfigError::ConfigNotFound(path) => write!( + f, + "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", + path + ), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 40cea34..eb85920 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,12 +1,12 @@ -use std::sync::Arc; -use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; +use super::GatewayState; +use crate::protocol::WsOutbound; +use crate::protocol::serialize_outbound; use axum::extract::State; +use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; +use std::sync::Arc; use tokio::sync::mpsc; -use crate::protocol::serialize_outbound; -use crate::protocol::WsOutbound; -use super::GatewayState; pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async move { @@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc) { let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await; // Send session established message - let _ = sender.send(WsOutbound::SessionEstablished { - session_id: session_id.clone(), - }).await; + let _ = sender + .send(WsOutbound::SessionEstablished { + session_id: session_id.clone(), + }) + .await; tracing::info!(session_id = %session_id, "CLI session established"); @@ -37,9 +39,10 @@ async fn handle_socket(ws: WebSocket, state: Arc) { tokio::spawn(async move { while let Some(msg) = receiver.recv().await { if let Ok(text) = serialize_outbound(&msg) - && ws_sender.send(WsMessage::Text(text.into())).await.is_err() { - break; - } + && ws_sender.send(WsMessage::Text(text.into())).await.is_err() + { + break; + } } }); diff --git a/src/lib.rs b/src/lib.rs index 6e651ba..5b8edd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,17 @@ -pub mod config; -pub mod providers; -pub mod bus; pub mod agent; -pub mod gateway; -pub mod session; -pub mod client; -pub mod protocol; +pub mod bus; pub mod channels; +pub mod client; +pub mod config; +pub mod gateway; pub mod logging; pub mod mcp; pub mod memory; pub mod observability; +pub mod protocol; +pub mod providers; pub mod scheduler; +pub mod session; pub mod skills; pub mod storage; pub mod tools; diff --git a/src/logging.rs b/src/logging.rs index 5aff5af..2cdf0cd 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,11 +1,7 @@ use std::path::PathBuf; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{ - fmt, - layer::SubscriberExt, - util::SubscriberInitExt, - fmt::time::LocalTime, - EnvFilter, + EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt, }; /// Get the default log directory path: ~/.picobot/logs @@ -27,20 +23,20 @@ pub fn init_logging() { // Create log directory if it doesn't exist if !log_dir.exists() - && let Err(e) = std::fs::create_dir_all(&log_dir) { - eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e); - } + && let Err(e) = std::fs::create_dir_all(&log_dir) + { + eprintln!( + "Warning: Failed to create log directory {}: {}", + log_dir.display(), + e + ); + } // Create file appender with daily rotation - let file_appender = RollingFileAppender::new( - Rotation::DAILY, - &log_dir, - "picobot.log", - ); + let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log"); // Build subscriber with both console and file output - let env_filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")); + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); let file_layer = fmt::layer() .with_writer(file_appender) @@ -66,8 +62,7 @@ pub fn init_logging() { /// Initialize logging without file output (console only) pub fn init_logging_console_only() { - let env_filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")); + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); let console_layer = fmt::layer() .with_timer(LocalTime::rfc_3339()) diff --git a/src/main.rs b/src/main.rs index 6a68b2f..23772b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use clap::{Parser, CommandFactory}; +use clap::{CommandFactory, Parser}; #[derive(Parser)] #[command(name = "picobot")] diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 1d2df63..d23cdff 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -92,24 +92,19 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String { parts.push(text.text.clone()); } RawContent::Image(image) => { - parts.push(format!( - "[image: {}]", - image.mime_type, - )); + parts.push(format!("[image: {}]", image.mime_type,)); } - RawContent::Resource(resource) => { - match &resource.resource { - rmcp::model::ResourceContents::TextResourceContents { text, .. } => { - parts.push(format!( - "[resource text: {}]", - text.chars().take(200).collect::(), - )); - } - rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => { - parts.push(format!("[resource blob: {}]", uri)); - } + RawContent::Resource(resource) => match &resource.resource { + rmcp::model::ResourceContents::TextResourceContents { text, .. } => { + parts.push(format!( + "[resource text: {}]", + text.chars().take(200).collect::(), + )); } - } + rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => { + parts.push(format!("[resource blob: {}]", uri)); + } + }, _ => { parts.push("[unsupported content]".to_string()); } @@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result anyhow::Result, duration: Duration) -> Self { + pub fn failure_with_duration( + output: String, + error_reason: Option, + duration: Duration, + ) -> Self { Self { output, success: false, diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 501587c..1a8a34c 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; -use crate::bus::message::ContentBlock; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::traits::Usage; -use std::sync::Arc; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; +use crate::bus::message::ContentBlock; use crate::storage::Storage; +use std::sync::Arc; const LLM_REQUEST_TIMEOUT_SECS: u64 = 300; fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec { - blocks.iter().map(|b| match b { - ContentBlock::Text { text } => { - serde_json::json!({ "type": "text", "text": text }) - } - ContentBlock::ImageUrl { image_url } => { - convert_image_url_to_anthropic(&image_url.url) - } - }).collect() + blocks + .iter() + .map(|b| match b { + ContentBlock::Text { text } => { + serde_json::json!({ "type": "text", "text": text }) + } + ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url), + }) + .collect() } fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value { @@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider { }; let content = if let Some(ref tc_id) = m.tool_call_id { // Tool result: wrap as tool_result content block - let output = m.content.iter() - .filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None }) + let output = m + .content + .iter() + .filter_map(|b| match b { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) .collect::>() .join(""); vec![serde_json::json!({ @@ -244,19 +250,18 @@ impl LLMProvider for AnthropicProvider { let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); tracing::debug!(req_body = %req_body_str, "LLM request"); - let resp = req_builder.json(&body).send().await - .inspect_err(|e| { - let is_timeout = e.is_timeout(); - tracing::error!( - provider = %self.name, - model = %self.model_id, - url = %url, - timeout = is_timeout, - error = %e, - elapsed_ms = %start.elapsed().as_millis(), - "LLM API request failed" - ); - })?; + let resp = req_builder.json(&body).send().await.inspect_err(|e| { + let is_timeout = e.is_timeout(); + tracing::error!( + provider = %self.name, + model = %self.model_id, + url = %url, + timeout = is_timeout, + error = %e, + elapsed_ms = %start.elapsed().as_millis(), + "LLM API request failed" + ); + })?; let status = resp.status(); let body_text = resp.text().await?; @@ -281,32 +286,38 @@ impl LLMProvider for AnthropicProvider { "LLM API returned error" ); if let Some(ref storage) = self.storage { - let _ = storage.append_llm_call( - &self.name, &self.model_id, &req_body_str, - Some(&body_text), Some(&error_msg), - start.elapsed().as_millis() as u64, - ).await; + let _ = storage + .append_llm_call( + &self.name, + &self.model_id, + &req_body_str, + Some(&body_text), + Some(&error_msg), + start.elapsed().as_millis() as u64, + ) + .await; } return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); } - let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) - .map_err(|e| { - let err_msg = format!("decode error: {} | body: {}", e, &body_text); - if let Some(ref storage) = self.storage { - let name = self.name.clone(); - let model = self.model_id.clone(); - let req = req_body_str.clone(); - let resp_body = body_text.clone(); - let dur = start.elapsed().as_millis() as u64; - let err = err_msg.clone(); - let s = storage.clone(); - tokio::spawn(async move { - let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await; - }); - } - err_msg - })?; + let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| { + let err_msg = format!("decode error: {} | body: {}", e, &body_text); + if let Some(ref storage) = self.storage { + let name = self.name.clone(); + let model = self.model_id.clone(); + let req = req_body_str.clone(); + let resp_body = body_text.clone(); + let dur = start.elapsed().as_millis() as u64; + let err = err_msg.clone(); + let s = storage.clone(); + tokio::spawn(async move { + let _ = s + .append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur) + .await; + }); + } + err_msg + })?; let mut content = String::new(); let mut reasoning = None; @@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider { reasoning_content: reasoning, tool_calls, usage: Usage { - prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0), - completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), - total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), + prompt_tokens: anthropic_resp + .usage + .as_ref() + .map(|u| u.input_tokens) + .unwrap_or(0), + completion_tokens: anthropic_resp + .usage + .as_ref() + .map(|u| u.output_tokens) + .unwrap_or(0), + total_tokens: anthropic_resp + .usage + .as_ref() + .map(|u| u.input_tokens + u.output_tokens) + .unwrap_or(0), }, }; if let Some(ref storage) = self.storage { - let _ = storage.append_llm_call( - &self.name, - &self.model_id, - &req_body_str, - Some(&body_text), - None, - start.elapsed().as_millis() as u64, - ).await; + let _ = storage + .append_llm_call( + &self.name, + &self.model_id, + &req_body_str, + Some(&body_text), + None, + start.elapsed().as_millis() as u64, + ) + .await; } Ok(response) diff --git a/src/providers/mod.rs b/src/providers/mod.rs index eedab44..57f70c7 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,12 +1,15 @@ -pub mod traits; -pub mod openai; pub mod anthropic; +pub mod openai; +pub mod traits; -pub use self::openai::OpenAIProvider; pub use self::anthropic::AnthropicProvider; +pub use self::openai::OpenAIProvider; use crate::config::LLMProviderConfig; -pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage}; +pub use traits::{ + ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, + ToolFunction, Usage, +}; pub fn create_provider(config: LLMProviderConfig) -> Result, ProviderError> { match config.provider_type.as_str() { diff --git a/src/providers/openai.rs b/src/providers/openai.rs index d5e60bf..b89255a 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,29 +1,35 @@ use async_trait::async_trait; use reqwest::Client; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::collections::HashMap; use std::time::Duration; -use crate::bus::message::ContentBlock; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::traits::Usage; -use std::sync::Arc; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; +use crate::bus::message::ContentBlock; use crate::storage::Storage; +use std::sync::Arc; const LLM_REQUEST_TIMEOUT_SECS: u64 = 300; fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { if blocks.len() == 1 - && let ContentBlock::Text { text } = &blocks[0] { - return Value::String(text.clone()); - } - Value::Array(blocks.iter().map(|b| match b { - ContentBlock::Text { text } => json!({ "type": "text", "text": text }), - ContentBlock::ImageUrl { image_url } => { - json!({ "type": "image_url", "image_url": { "url": image_url.url } }) - } - }).collect()) + && let ContentBlock::Text { text } = &blocks[0] + { + return Value::String(text.clone()); + } + Value::Array( + blocks + .iter() + .map(|b| match b { + ContentBlock::Text { text } => json!({ "type": "text", "text": text }), + ContentBlock::ImageUrl { image_url } => { + json!({ "type": "image_url", "image_url": { "url": image_url.url } }) + } + }) + .collect(), + ) } pub struct OpenAIProvider { @@ -201,10 +207,14 @@ impl LLMProvider for OpenAIProvider { if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { for (j, item) in content.iter().enumerate() { if item.get("type").and_then(|t| t.as_str()) == Some("image_url") - && let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) { - let prefix: String = url_str.chars().take(20).collect(); - tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); - } + && let Some(url_str) = item + .get("image_url") + .and_then(|u| u.get("url")) + .and_then(|v| v.as_str()) + { + let prefix: String = url_str.chars().take(20).collect(); + tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); + } } } } @@ -224,19 +234,18 @@ impl LLMProvider for OpenAIProvider { let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); tracing::debug!(req_body = %req_body_str, "LLM request"); - let resp = req_builder.json(&body).send().await - .inspect_err(|e| { - let is_timeout = e.is_timeout(); - tracing::error!( - provider = %self.name, - model = %self.model_id, - url = %url, - timeout = is_timeout, - error = %e, - elapsed_ms = %start.elapsed().as_millis(), - "LLM API request failed" - ); - })?; + let resp = req_builder.json(&body).send().await.inspect_err(|e| { + let is_timeout = e.is_timeout(); + tracing::error!( + provider = %self.name, + model = %self.model_id, + url = %url, + timeout = is_timeout, + error = %e, + elapsed_ms = %start.elapsed().as_millis(), + "LLM API request failed" + ); + })?; let status = resp.status(); let text = resp.text().await?; @@ -253,37 +262,48 @@ impl LLMProvider for OpenAIProvider { "LLM API returned error" ); if let Some(ref storage) = self.storage - && let Err(e) = storage.append_llm_call( - &self.name, &self.model_id, &req_body_str, - Some(&text), Some(&error), - start.elapsed().as_millis() as u64, - ).await { - tracing::warn!("failed to persist LLM call: {}", e); - } + && let Err(e) = storage + .append_llm_call( + &self.name, + &self.model_id, + &req_body_str, + Some(&text), + Some(&error), + start.elapsed().as_millis() as u64, + ) + .await + { + tracing::warn!("failed to persist LLM call: {}", e); + } return Err(error.into()); } - let openai_resp: OpenAIResponse = serde_json::from_str(&text) - .map_err(|e| { - let err_msg = format!("decode error: {} | body: {}", e, &text); - if let Some(ref storage) = self.storage { - let name = self.name.clone(); - let model = self.model_id.clone(); - let req = req_body_str.clone(); - let resp = text.clone(); - let dur = start.elapsed().as_millis() as u64; - let err = err_msg.clone(); - let s = storage.clone(); - tokio::spawn(async move { - if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await { - tracing::warn!("failed to persist LLM call (decode error): {}", e); - } - }); - } - err_msg - })?; + let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| { + let err_msg = format!("decode error: {} | body: {}", e, &text); + if let Some(ref storage) = self.storage { + let name = self.name.clone(); + let model = self.model_id.clone(); + let req = req_body_str.clone(); + let resp = text.clone(); + let dur = start.elapsed().as_millis() as u64; + let err = err_msg.clone(); + let s = storage.clone(); + tokio::spawn(async move { + if let Err(e) = s + .append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur) + .await + { + tracing::warn!("failed to persist LLM call (decode error): {}", e); + } + }); + } + err_msg + })?; - let first_choice = openai_resp.choices.into_iter().next() + let first_choice = openai_resp + .choices + .into_iter() + .next() .ok_or("no choices in response")?; let content = first_choice @@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider { .map(|tc| ToolCall { id: tc.id.clone(), name: tc.function.name.clone(), - arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), + arguments: serde_json::from_str(&tc.function.arguments) + .unwrap_or(serde_json::Value::Null), }) .collect(); @@ -318,13 +339,19 @@ impl LLMProvider for OpenAIProvider { }; if let Some(ref storage) = self.storage - && let Err(e) = storage.append_llm_call( - &self.name, &self.model_id, &req_body_str, - Some(&text), None, - start.elapsed().as_millis() as u64, - ).await { - tracing::warn!("failed to persist LLM call: {}", e); - } + && let Err(e) = storage + .append_llm_call( + &self.name, + &self.model_id, + &req_body_str, + Some(&text), + None, + start.elapsed().as_millis() as u64, + ) + .await + { + tracing::warn!("failed to persist LLM call: {}", e); + } Ok(response) } @@ -386,6 +413,9 @@ mod tests { assert_eq!(tool_calls[0]["id"], "call_1"); assert_eq!(tool_calls[0]["type"], "function"); assert_eq!(tool_calls[0]["function"]["name"], "calculator"); - assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); + assert_eq!( + tool_calls[0]["function"]["arguments"], + "{\"expression\":\"1+1\"}" + ); } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index b2f8a1a..5e74e8a 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,6 +1,6 @@ +use crate::bus::message::ContentBlock; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use crate::bus::message::ContentBlock; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { @@ -61,7 +61,11 @@ impl Message { } } - pub fn tool(tool_call_id: impl Into, tool_name: impl Into, content: impl Into) -> Self { + pub fn tool( + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + ) -> Self { Self { role: "tool".to_string(), content: vec![ContentBlock::text(content)], diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index 8cbf82f..cf0a70c 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -5,11 +5,11 @@ use std::time::Instant; use tokio::time; use crate::config::SchedulerConfig; -use crate::session::session::HandleResult; use crate::session::SessionManager; +use crate::session::session::HandleResult; +use crate::storage::JobRun; use crate::storage::ScheduledJob; use crate::storage::Storage; -use crate::storage::JobRun; pub use types::Schedule; @@ -89,7 +89,11 @@ impl Scheduler { let now = now_ms(); - let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await { + let due = match self + .storage + .due_scheduled_jobs(now, self.config.max_concurrent) + .await + { Ok(jobs) => jobs, Err(e) => { tracing::error!("scheduler: failed to query due jobs: {}", e); @@ -107,7 +111,11 @@ impl Scheduler { let start = Instant::now(); let started_at = now_ms(); - if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await { + if let Err(e) = self + .storage + .touch_scheduled_job_last_run(&job.id, started_at) + .await + { tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e); continue; } @@ -135,7 +143,10 @@ impl Scheduler { match result { Ok(HandleResult::AgentResponse(output)) => { let output_truncated = if output.len() > 8000 { - format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)]) + format!( + "{}...[truncated]", + &output[..output.ceil_char_boundary(8000)] + ) } else { output.clone() }; @@ -155,7 +166,11 @@ impl Scheduler { tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e); } - if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await { + if let Err(e) = self + .storage + .set_scheduled_job_last_status(&job.id, "ok", None) + .await + { tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e); } @@ -199,9 +214,11 @@ impl Scheduler { tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2); } - if let Err(e2) = self.storage.set_scheduled_job_last_status( - &job.id, "error", Some(&error_str), - ).await { + if let Err(e2) = self + .storage + .set_scheduled_job_last_status(&job.id, "error", Some(&error_str)) + .await + { tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2); } @@ -231,17 +248,23 @@ impl Scheduler { self.storage.remove_scheduled_job(&job.id).await?; tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run"); } else { - self.storage.set_scheduled_job_enabled(&job.id, false).await?; + self.storage + .set_scheduled_job_enabled(&job.id, false) + .await?; tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run"); } } Schedule::Every { .. } | Schedule::Cron { .. } => { if let Some(next) = next_run_for_schedule(&job.schedule, now) { - self.storage.set_scheduled_job_next_run(&job.id, next).await?; + self.storage + .set_scheduled_job_next_run(&job.id, next) + .await?; tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled"); } else { tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job"); - self.storage.set_scheduled_job_enabled(&job.id, false).await?; + self.storage + .set_scheduled_job_enabled(&job.id, false) + .await?; } } } diff --git a/src/session/commands.rs b/src/session/commands.rs index 324f845..bc16e4a 100644 --- a/src/session/commands.rs +++ b/src/session/commands.rs @@ -22,32 +22,20 @@ pub enum SessionCommand { dialog_id: String, }, /// Get the current dialog for a chat - GetCurrentDialog { - channel: String, - chat_id: String, - }, + GetCurrentDialog { channel: String, chat_id: String }, /// Rename a dialog RenameDialog { session_id: UnifiedSessionId, title: String, }, /// Archive a dialog - ArchiveDialog { - session_id: UnifiedSessionId, - }, + ArchiveDialog { session_id: UnifiedSessionId }, /// Delete a dialog - DeleteDialog { - session_id: UnifiedSessionId, - }, + DeleteDialog { session_id: UnifiedSessionId }, /// Clear dialog history - ClearHistory { - session_id: UnifiedSessionId, - }, + ClearHistory { session_id: UnifiedSessionId }, /// Get list of available slash commands - GetSlashCommands { - channel: String, - chat_id: String, - }, + GetSlashCommands { channel: String, chat_id: String }, /// Execute a slash command ExecuteSlashCommand { command: String, @@ -60,7 +48,11 @@ pub enum SessionCommand { impl SessionCommand { /// Create a CreateDialog command - pub fn create_dialog(channel: impl Into, chat_id: impl Into, title: Option) -> Self { + pub fn create_dialog( + channel: impl Into, + chat_id: impl Into, + title: Option, + ) -> Self { Self::CreateDialog { channel: channel.into(), chat_id: chat_id.into(), @@ -69,7 +61,11 @@ impl SessionCommand { } /// Create a ListDialogs command - pub fn list_dialogs(channel: impl Into, chat_id: impl Into, include_archived: bool) -> Self { + pub fn list_dialogs( + channel: impl Into, + chat_id: impl Into, + include_archived: bool, + ) -> Self { Self::ListDialogs { channel: channel.into(), chat_id: chat_id.into(), diff --git a/src/session/events.rs b/src/session/events.rs index afb3e2a..8fdf6ce 100644 --- a/src/session/events.rs +++ b/src/session/events.rs @@ -1,5 +1,5 @@ -use super::session_id::UnifiedSessionId; use super::session::SlashCommand; +use super::session_id::UnifiedSessionId; /// Dialog information returned by SessionManager #[derive(Debug, Clone)] @@ -30,30 +30,20 @@ pub enum SessionEvent { session_id: Option, }, /// Dialog switched successfully - DialogSwitched { - session_id: UnifiedSessionId, - }, + DialogSwitched { session_id: UnifiedSessionId }, /// Dialog renamed DialogRenamed { session_id: UnifiedSessionId, title: String, }, /// Dialog archived - DialogArchived { - session_id: UnifiedSessionId, - }, + DialogArchived { session_id: UnifiedSessionId }, /// Dialog deleted - DialogDeleted { - session_id: UnifiedSessionId, - }, + DialogDeleted { session_id: UnifiedSessionId }, /// Dialog history cleared - HistoryCleared { - session_id: UnifiedSessionId, - }, + HistoryCleared { session_id: UnifiedSessionId }, /// List of available slash commands - SlashCommandsList { - commands: Vec, - }, + SlashCommandsList { commands: Vec }, /// Slash command executed successfully SlashCommandExecuted { new_session_id: Option, @@ -70,8 +60,5 @@ pub enum SessionEvent { message_count: usize, }, /// Error occurred - Error { - code: String, - message: String, - }, + Error { code: String, message: String }, } diff --git a/src/session/mod.rs b/src/session/mod.rs index cdb268a..bcc099b 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -1,11 +1,11 @@ -pub mod error; pub mod commands; +pub mod error; pub mod events; pub mod session; pub mod session_id; -pub use error::SessionError; pub use commands::SessionCommand; -pub use events::{SessionEvent, DialogInfo}; -pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS}; +pub use error::SessionError; +pub use events::{DialogInfo, SessionEvent}; +pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand}; pub use session_id::UnifiedSessionId; diff --git a/src/session/session_id.rs b/src/session/session_id.rs index f88d22d..0db2c9c 100644 --- a/src/session/session_id.rs +++ b/src/session/session_id.rs @@ -8,7 +8,6 @@ /// /// For simple cases where only one dialog exists per chat: /// - `dialog_id` defaults to `"default"` - use serde::{Deserialize, Serialize}; pub const DEFAULT_DIALOG_ID: &str = "default"; @@ -22,7 +21,11 @@ pub struct UnifiedSessionId { impl UnifiedSessionId { /// Create a new UnifiedSessionId - pub fn new(channel: impl Into, chat_id: impl Into, dialog_id: impl Into) -> Self { + pub fn new( + channel: impl Into, + chat_id: impl Into, + dialog_id: impl Into, + ) -> Self { Self { channel: channel.into(), chat_id: chat_id.into(), diff --git a/src/skills/builtin.rs b/src/skills/builtin.rs index 05feedd..701a612 100644 --- a/src/skills/builtin.rs +++ b/src/skills/builtin.rs @@ -1,6 +1,6 @@ use std::path::Path; -use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS}; +use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill}; pub fn install_builtin_skills(target_dir: &Path) { for skill in EMBEDDED_SKILLS { @@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) { } fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> { - let decompressed = zstd::decode_all(skill.data) - .map_err(|e| format!("zstd decode: {}", e))?; + let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?; let mut archive = tar::Archive::new(decompressed.as_slice()); archive diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 7affd8e..3d59cbf 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -120,7 +120,11 @@ impl SkillsLoader { let count = loaded.len(); let mut replaced = 0usize; for skill in loaded { - if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) { + if let Some(existing) = state + .loaded_skills + .iter_mut() + .find(|s| s.name == skill.name) + { *existing = skill; replaced += 1; } else { @@ -138,33 +142,42 @@ impl SkillsLoader { // Load from workspace skills dir (highest priority) — replace same-name skills if let Some(ref ws_dir) = self.workspace_skills_dir - && ws_dir.exists() { - let loaded = self.load_skills_from_dir(ws_dir); - let count = loaded.len(); - let mut replaced = 0usize; - for skill in loaded { - if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) { - *existing = skill; - replaced += 1; - } else { - state.loaded_skills.push(skill); - } + && ws_dir.exists() + { + let loaded = self.load_skills_from_dir(ws_dir); + let count = loaded.len(); + let mut replaced = 0usize; + for skill in loaded { + if let Some(existing) = state + .loaded_skills + .iter_mut() + .find(|s| s.name == skill.name) + { + *existing = skill; + replaced += 1; + } else { + state.loaded_skills.push(skill); } - tracing::debug!( - dir = %ws_dir.display(), - count = count, - replaced = replaced, - "Loaded skills from workspace directory" - ); - state.last_workspace_mtime = Self::get_dir_mtime(ws_dir); } + tracing::debug!( + dir = %ws_dir.display(), + count = count, + replaced = replaced, + "Loaded skills from workspace directory" + ); + state.last_workspace_mtime = Self::get_dir_mtime(ws_dir); + } state.last_load_time = SystemTime::now(); if state.loaded_skills.is_empty() { tracing::debug!("No skills found in any skills directory"); } else { - tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len()); + tracing::info!( + count = state.loaded_skills.len(), + "Loaded {} skills total", + state.loaded_skills.len() + ); } } @@ -215,18 +228,20 @@ impl SkillsLoader { let mut max_mtime = None; if let Ok(metadata) = std::fs::metadata(dir) - && let Ok(mtime) = metadata.modified() { - max_mtime = Some(mtime); - } + && let Ok(mtime) = metadata.modified() + { + max_mtime = Some(mtime); + } if let Ok(entries) = std::fs::read_dir(dir) { for entry in entries.flatten() { let path = entry.path(); if let Ok(metadata) = std::fs::metadata(&path) && let Ok(mtime) = metadata.modified() - && max_mtime.is_none_or(|current| mtime > current) { - max_mtime = Some(mtime); - } + && max_mtime.is_none_or(|current| mtime > current) + { + max_mtime = Some(mtime); + } } } @@ -244,7 +259,12 @@ impl SkillsLoader { pub fn get_always_skills(&self) -> Vec { self.reload_if_changed(); let state = self.state.lock().unwrap(); - state.loaded_skills.iter().filter(|s| s.always).cloned().collect() + state + .loaded_skills + .iter() + .filter(|s| s.always) + .cloned() + .collect() } /// Get a specific skill by name (checks for changes first) @@ -258,7 +278,8 @@ impl SkillsLoader { pub fn list_skills(&self) -> Vec<(String, String)> { self.reload_if_changed(); let state = self.state.lock().unwrap(); - state.loaded_skills + state + .loaded_skills .iter() .map(|s| (s.name.clone(), s.description.clone())) .collect() @@ -279,15 +300,21 @@ impl SkillsLoader { prompt.push_str("### 目录说明\n\n"); prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n"); prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n"); - prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n"); - prompt.push_str("安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n"); + prompt.push_str( + "- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n", + ); + prompt.push_str( + "安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n", + ); // Always skills summary let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect(); if !always_skills.is_empty() { prompt.push_str("### 常用技能\n\n"); for skill in &always_skills { - let path_str = skill.path.as_ref() + let path_str = skill + .path + .as_ref() .map(|p| p.to_string_lossy().to_string()) .unwrap_or_else(|| "—".to_string()); prompt.push_str(&format!( @@ -300,8 +327,12 @@ impl SkillsLoader { // Usage instructions prompt.push_str("### 使用方法\n\n"); - prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n"); - prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n"); + prompt.push_str( + "- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n", + ); + prompt.push_str( + "- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n", + ); prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n"); // Always skills full content @@ -338,25 +369,23 @@ impl SkillsLoader { } match std::fs::read_to_string(&skill_file) { - Ok(content) => { - match self.parse_skill(&path, &content) { - Some(skill) => { - tracing::debug!( - skill = %skill.name, - path = %skill_file.display(), - always = skill.always, - "Loaded skill" - ); - skills.push(skill); - } - None => { - tracing::warn!( - path = %skill_file.display(), - "Failed to parse skill" - ); - } + Ok(content) => match self.parse_skill(&path, &content) { + Some(skill) => { + tracing::debug!( + skill = %skill.name, + path = %skill_file.display(), + always = skill.always, + "Loaded skill" + ); + skills.push(skill); } - } + None => { + tracing::warn!( + path = %skill_file.display(), + "Failed to parse skill" + ); + } + }, Err(e) => { tracing::warn!( path = %skill_file.display(), @@ -447,7 +476,6 @@ impl Default for SkillsLoader { } } - /// Extract first non-empty, non-heading line as description fn extract_description(content: &str) -> String { content diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 12d35d4..4010c00 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -241,12 +241,11 @@ impl super::Storage { let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64); let cutoff_str = cutoff.to_rfc3339(); - let result = sqlx::query( - "DELETE FROM memories WHERE category = 'timeline' AND created_at < ?", - ) - .bind(&cutoff_str) - .execute(self.pool()) - .await?; + let result = + sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?") + .bind(&cutoff_str) + .execute(self.pool()) + .await?; Ok(result.rows_affected()) } @@ -276,9 +275,7 @@ impl super::Storage { } } -fn parse_memory_rows( - rows: &[sqlx::sqlite::SqliteRow], -) -> Result, StorageError> { +fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result, StorageError> { rows.iter() .map(|row| { Ok(MemoryEntry { diff --git a/src/storage/scheduler.rs b/src/storage/scheduler.rs index 999b294..3a2eff6 100644 --- a/src/storage/scheduler.rs +++ b/src/storage/scheduler.rs @@ -165,7 +165,11 @@ impl crate::storage::Storage { } /// Update next_run_at and last_run_at for a job. - pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> { + pub async fn set_scheduled_job_next_run( + &self, + id: &str, + next_run_at: i64, + ) -> anyhow::Result<()> { let now = now_ms(); sqlx::query( "UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?", @@ -331,7 +335,9 @@ mod tests { async fn setup_storage() -> Storage { let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); let storage = Storage { pool }; - Storage::init_scheduler_schema(storage.pool()).await.unwrap(); + Storage::init_scheduler_schema(storage.pool()) + .await + .unwrap(); storage } @@ -450,7 +456,10 @@ mod tests { updated_at: t, }; storage.add_scheduled_job(&job).await.unwrap(); - storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap(); + storage + .set_scheduled_job_enabled("job-toggle", false) + .await + .unwrap(); let got = storage.get_scheduled_job("job-toggle").await.unwrap(); assert!(!got.enabled); } @@ -461,31 +470,55 @@ mod tests { let t = now(); let jobs = vec![ ScheduledJob { - id: "due".into(), name: "due".into(), - schedule: Schedule::At { at: t }, prompt: "1".into(), - channel: "cli_chat".into(), chat_id: "c".into(), - model: None, enabled: true, delete_after_run: false, - next_run_at: t - 1000, last_run_at: None, - last_status: None, last_error: None, - created_at: t, updated_at: t, + id: "due".into(), + name: "due".into(), + schedule: Schedule::At { at: t }, + prompt: "1".into(), + channel: "cli_chat".into(), + chat_id: "c".into(), + model: None, + enabled: true, + delete_after_run: false, + next_run_at: t - 1000, + last_run_at: None, + last_status: None, + last_error: None, + created_at: t, + updated_at: t, }, ScheduledJob { - id: "future".into(), name: "future".into(), - schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(), - channel: "cli_chat".into(), chat_id: "c".into(), - model: None, enabled: true, delete_after_run: false, - next_run_at: t + 99999999, last_run_at: None, - last_status: None, last_error: None, - created_at: t, updated_at: t, + id: "future".into(), + name: "future".into(), + schedule: Schedule::At { at: t + 99999999 }, + prompt: "2".into(), + channel: "cli_chat".into(), + chat_id: "c".into(), + model: None, + enabled: true, + delete_after_run: false, + next_run_at: t + 99999999, + last_run_at: None, + last_status: None, + last_error: None, + created_at: t, + updated_at: t, }, ScheduledJob { - id: "disabled-due".into(), name: "disabled due".into(), - schedule: Schedule::At { at: t }, prompt: "3".into(), - channel: "cli_chat".into(), chat_id: "c".into(), - model: None, enabled: false, delete_after_run: false, - next_run_at: t - 1000, last_run_at: None, - last_status: None, last_error: None, - created_at: t, updated_at: t, + id: "disabled-due".into(), + name: "disabled due".into(), + schedule: Schedule::At { at: t }, + prompt: "3".into(), + channel: "cli_chat".into(), + chat_id: "c".into(), + model: None, + enabled: false, + delete_after_run: false, + next_run_at: t - 1000, + last_run_at: None, + last_status: None, + last_error: None, + created_at: t, + updated_at: t, }, ]; for j in &jobs { @@ -501,24 +534,39 @@ mod tests { let storage = setup_storage().await; let t = now(); let job = ScheduledJob { - id: "job-run".into(), name: "run test".into(), + id: "job-run".into(), + name: "run test".into(), schedule: Schedule::Every { every_ms: 1000 }, - prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(), - model: None, enabled: true, delete_after_run: false, - next_run_at: t, last_run_at: None, - last_status: None, last_error: None, - created_at: t, updated_at: t, + prompt: "hi".into(), + channel: "cli_chat".into(), + chat_id: "c".into(), + model: None, + enabled: true, + delete_after_run: false, + next_run_at: t, + last_run_at: None, + last_status: None, + last_error: None, + created_at: t, + updated_at: t, }; storage.add_scheduled_job(&job).await.unwrap(); let run = super::JobRun { - id: 0, job_id: "job-run".into(), - started_at: t, finished_at: t + 500, - status: "ok".into(), output: Some("hello".into()), - error: None, duration_ms: 500, + id: 0, + job_id: "job-run".into(), + started_at: t, + finished_at: t + 500, + status: "ok".into(), + output: Some("hello".into()), + error: None, + duration_ms: 500, }; storage.record_scheduled_job_run(&run).await.unwrap(); - let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap(); + let runs = storage + .list_scheduled_job_runs("job-run", 10) + .await + .unwrap(); assert_eq!(runs.len(), 1); assert_eq!(runs[0].status, "ok"); assert_eq!(runs[0].output.as_deref(), Some("hello")); @@ -529,22 +577,34 @@ mod tests { let storage = setup_storage().await; let t = now(); let job = ScheduledJob { - id: "job-update".into(), name: "old name".into(), + id: "job-update".into(), + name: "old name".into(), schedule: Schedule::Every { every_ms: 1000 }, - prompt: "old prompt".into(), channel: "feishu".into(), - chat_id: "oc_1".into(), model: None, - enabled: true, delete_after_run: false, - next_run_at: t, last_run_at: None, - last_status: None, last_error: None, - created_at: t, updated_at: t, + prompt: "old prompt".into(), + channel: "feishu".into(), + chat_id: "oc_1".into(), + model: None, + enabled: true, + delete_after_run: false, + next_run_at: t, + last_run_at: None, + last_status: None, + last_error: None, + created_at: t, + updated_at: t, }; storage.add_scheduled_job(&job).await.unwrap(); - storage.update_scheduled_job( - "job-update", - Some("new prompt".into()), - Some(Schedule::Every { every_ms: 60000 }), - None, None, None, - ).await.unwrap(); + storage + .update_scheduled_job( + "job-update", + Some("new prompt".into()), + Some(Schedule::Every { every_ms: 60000 }), + None, + None, + None, + ) + .await + .unwrap(); let got = storage.get_scheduled_job("job-update").await.unwrap(); assert_eq!(got.prompt, "new prompt"); } diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 32edae7..714cbc4 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -167,10 +167,7 @@ impl Tool for BashTool { Err(_) => Ok(ToolResult { success: false, output: String::new(), - error: Some(format!( - "Command timed out after {} seconds", - timeout_secs - )), + error: Some(format!("Command timed out after {} seconds", timeout_secs)), }), } } @@ -249,10 +246,7 @@ mod tests { #[tokio::test] async fn test_pwd_command() { let tool = BashTool::new(); - let result = tool - .execute(json!({ "command": "pwd" })) - .await - .unwrap(); + let result = tool.execute(json!({ "command": "pwd" })).await.unwrap(); assert!(result.success); } @@ -260,7 +254,10 @@ mod tests { #[tokio::test] async fn test_ls_command() { let tool = BashTool::new(); - let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap(); + let result = tool + .execute(json!({ "command": "ls -la /tmp" })) + .await + .unwrap(); assert!(result.success); } diff --git a/src/tools/browser.rs b/src/tools/browser.rs index d037370..a247a1b 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -5,7 +5,7 @@ use std::time::Duration; use anyhow::Context; use async_trait::async_trait; use base64::Engine; -use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT}; +use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction}; use fantoccini::key::Key; use fantoccini::{Client, ClientBuilder, Locator}; use serde::{Deserialize, Serialize}; @@ -63,7 +63,9 @@ impl BrowserTool { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum BrowserAction { - Open { url: String }, + Open { + url: String, + }, Snapshot { #[serde(default)] interactive_only: bool, @@ -72,10 +74,20 @@ pub enum BrowserAction { #[serde(default)] depth: Option, }, - Click { selector: String }, - Fill { selector: String, value: String }, - Type { selector: Option, text: String }, - GetText { selector: String }, + Click { + selector: String, + }, + Fill { + selector: String, + value: String, + }, + Type { + selector: Option, + text: String, + }, + GetText { + selector: String, + }, GetTitle, GetUrl, Screenshot { @@ -84,7 +96,9 @@ pub enum BrowserAction { #[serde(default)] return_base64: bool, }, - Focus { selector: String }, + Focus { + selector: String, + }, Wait { #[serde(default)] selector: Option, @@ -93,9 +107,16 @@ pub enum BrowserAction { #[serde(default)] text: Option, }, - Press { key: String }, - Hover { selector: String }, - ClickAt { x: u32, y: u32 }, + Press { + key: String, + }, + Hover { + selector: String, + }, + ClickAt { + x: u32, + y: u32, + }, Scroll { direction: String, #[serde(default)] @@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result { let selector = args @@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result { let key = args @@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result anyhow::bail!("Unsupported browser action: {}", other), @@ -488,7 +503,11 @@ impl BrowserState { } Err(e) => return Err(e.into()), } - tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed"); + tracing::debug!( + action = "fill", + output_len = value.len(), + "Browser action completed" + ); Ok(ToolResult { success: true, output: format!("Filled {} with {}", selector, value), @@ -573,7 +592,10 @@ impl BrowserState { error: None, }) } - BrowserAction::Screenshot { path, return_base64 } => { + BrowserAction::Screenshot { + path, + return_base64, + } => { let client = self.active_client()?; let png = client.screenshot().await?; let save_path = path.unwrap_or_else(|| { @@ -588,14 +610,25 @@ impl BrowserState { tokio::fs::write(&save_path, &png).await?; if return_base64 { let b64 = base64::engine::general_purpose::STANDARD.encode(&png); - tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed"); + tracing::debug!( + action = "screenshot", + output_len = b64.len(), + "Browser action completed" + ); return Ok(ToolResult { success: true, - output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64), + output: format!( + "Screenshot saved to {}. Base64: data:image/png;base64,{}", + save_path, b64 + ), error: None, }); } - tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed"); + tracing::debug!( + action = "screenshot", + output_len = save_path.len(), + "Browser action completed" + ); Ok(ToolResult { success: true, output: format!("Screenshot saved to {}", save_path), @@ -611,18 +644,18 @@ impl BrowserState { vec![serde_json::to_value(el)?], ) .await?; - tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed"); + tracing::debug!( + action = "focus", + output_len = selector.len(), + "Browser action completed" + ); Ok(ToolResult { success: true, output: format!("Focused {}", selector), error: None, }) } - BrowserAction::Wait { - selector, - ms, - text, - } => { + BrowserAction::Wait { selector, ms, text } => { if let Some(sel) = selector { let client = self.active_client()?; wait_for_selector(client, &sel).await?; @@ -719,9 +752,21 @@ impl BrowserState { let id = info.get("id").and_then(|v| v.as_str()).unwrap_or(""); let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or(""); let text = info.get("text").and_then(|v| v.as_str()).unwrap_or(""); - let id_str = if id.is_empty() { String::new() } else { format!("#{id}") }; - let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") }; - let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") }; + let id_str = if id.is_empty() { + String::new() + } else { + format!("#{id}") + }; + let type_str = if el_type.is_empty() { + String::new() + } else { + format!("[type={el_type}]") + }; + let text_str = if text.is_empty() { + String::new() + } else { + format!(" ({text})") + }; format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}") } None => format!("Clicked at ({}, {})", x, y), @@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String { } fn xpath_contains_text(text: &str) -> String { - format!( - "//*[contains(normalize-space(.), {})]", - xpath_literal(text) - ) + format!("//*[contains(normalize-space(.), {})]", xpath_literal(text)) } fn xpath_literal(input: &str) -> String { @@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String { "pagedown" => Key::PageDown.to_string(), "space" => " ".to_string(), other => { - tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other); + tracing::warn!( + "Unrecognized key '{}', this will have no effect (press only supports single named keys)", + other + ); other.to_string() } } diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index 8f033dc..2b42e8d 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -659,10 +659,7 @@ mod tests { #[tokio::test] async fn test_evaluate_missing_expression() { let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "evaluate"})) - .await - .unwrap(); + let result = tool.execute(json!({"function": "evaluate"})).await.unwrap(); assert!(!result.success); } diff --git a/src/tools/content_search.rs b/src/tools/content_search.rs index 96823aa..86152b6 100644 --- a/src/tools/content_search.rs +++ b/src/tools/content_search.rs @@ -31,10 +31,7 @@ impl ContentSearchTool { for (i, line) in lines.iter().enumerate() { if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS { let omitted = lines.len() - i; - output.push_str(&format!( - "\n... ({} matches omitted) ...", - omitted - )); + output.push_str(&format!("\n... ({} matches omitted) ...", omitted)); break; } if !output.is_empty() { @@ -113,18 +110,40 @@ impl Tool for ContentSearchTool { let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str())); let file_pattern = args.get("file_pattern").and_then(|v| v.as_str()); - let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false); - let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize; - let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize; + let case_sensitive = args + .get("case_sensitive") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let context_lines = args + .get("context_lines") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let max_results = args + .get("max_results") + .and_then(|v| v.as_u64()) + .unwrap_or(MAX_RESULTS as u64) as usize; - let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await; + let result = self + .run_search( + pattern, + &dir, + file_pattern, + case_sensitive, + context_lines, + max_results, + ) + .await; match result { Ok(lines) => { let count = lines.len(); let mut output = self.truncate_output(&lines); output.push_str(&format!("\n\n---\n共 {} 条匹配", count)); - Ok(ToolResult { success: true, output, error: None }) + Ok(ToolResult { + success: true, + output, + error: None, + }) } Err(e) => Ok(ToolResult { success: false, @@ -146,22 +165,52 @@ impl ContentSearchTool { max_results: usize, ) -> anyhow::Result> { if which::which("rg").is_ok() { - match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await { + match self + .search_with_rg( + pattern, + dir, + file_pattern, + case_sensitive, + context_lines, + max_results, + ) + .await + { Ok(lines) => return Ok(lines), Err(e) => tracing::warn!("rg failed: {}, falling back", e), } } if which::which("grep").is_ok() { - match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await { + match self + .search_with_grep( + pattern, + dir, + file_pattern, + case_sensitive, + context_lines, + max_results, + ) + .await + { Ok(lines) if !lines.is_empty() => return Ok(lines), - Ok(_) => {}, + Ok(_) => {} Err(e) => tracing::warn!("grep failed: {}, falling back", e), } } - tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."); - self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await + tracing::warn!( + "No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance." + ); + self.search_with_rust( + pattern, + dir, + file_pattern, + case_sensitive, + context_lines, + max_results, + ) + .await } async fn search_with_rg( @@ -176,8 +225,10 @@ impl ContentSearchTool { let mut cmd = Command::new("rg"); cmd.arg("-n") .arg("--no-heading") - .arg("--color").arg("never") - .arg("--max-count").arg(max_results.to_string()) + .arg("--color") + .arg("never") + .arg("--max-count") + .arg(max_results.to_string()) .arg(pattern) .arg(dir) .stdout(Stdio::piped()) @@ -193,12 +244,9 @@ impl ContentSearchTool { cmd.arg("--glob").arg(fp); } - let output = timeout( - std::time::Duration::from_secs(TIMEOUT_SECS), - cmd.output(), - ) - .await - .map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??; + let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output()) + .await + .map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??; if !output.status.success() && output.status.code() != Some(1) { let stderr = String::from_utf8_lossy(&output.stderr); @@ -206,7 +254,8 @@ impl ContentSearchTool { } let text = String::from_utf8_lossy(&output.stdout); - let lines: Vec = text.lines() + let lines: Vec = text + .lines() .take(max_results) .map(|l| l.to_string()) .collect(); @@ -242,15 +291,13 @@ impl ContentSearchTool { cmd.arg("--include").arg(fp); } - let output = timeout( - std::time::Duration::from_secs(TIMEOUT_SECS), - cmd.output(), - ) - .await - .map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??; + let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output()) + .await + .map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??; let text = String::from_utf8_lossy(&output.stdout); - let lines: Vec = text.lines() + let lines: Vec = text + .lines() .take(max_results) .map(|l| l.to_string()) .collect(); @@ -280,7 +327,9 @@ impl ContentSearchTool { if case_sensitive { regex::Regex::new(&re_str) } else { - regex::RegexBuilder::new(&re_str).case_insensitive(true).build() + regex::RegexBuilder::new(&re_str) + .case_insensitive(true) + .build() } }); @@ -291,7 +340,14 @@ impl ContentSearchTool { }; let mut results = Vec::new(); - grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?; + grep_dir( + Path::new(dir), + Path::new(dir), + &re, + file_re.as_ref(), + &mut results, + max_results, + )?; Ok(results) } @@ -350,16 +406,19 @@ fn grep_dir( if path.is_dir() { if let Some(name) = rel.file_name().and_then(|n| n.to_str()) - && name.starts_with('.') && name.len() > 1 { - continue; - } + && name.starts_with('.') + && name.len() > 1 + { + continue; + } grep_dir(base, &path, re, file_re, results, max)?; } else if path.is_file() { if let Some(file_re) = file_re && let Some(name) = rel.file_name().and_then(|n| n.to_str()) - && !file_re.is_match(name) { - continue; - } + && !file_re.is_match(name) + { + continue; + } if let Ok(content) = std::fs::read_to_string(&path) { for (line_num, line) in content.lines().enumerate() { @@ -391,8 +450,16 @@ mod tests { #[tokio::test] async fn test_content_search_rust_fallback() { let dir = TempDir::new().unwrap(); - fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap(); - fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap(); + fs::write( + dir.path().join("main.rs"), + "fn main() {\n let x = 42;\n println!(\"hello\");\n}", + ) + .unwrap(); + fs::write( + dir.path().join("lib.rs"), + "pub fn foo() -> u32 {\n let y = 42;\n y\n}", + ) + .unwrap(); fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap(); let tool = ContentSearchTool::new(); diff --git a/src/tools/cron.rs b/src/tools/cron.rs index ac5f5d7..d96cb43 100644 --- a/src/tools/cron.rs +++ b/src/tools/cron.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use uuid::Uuid; -use crate::scheduler::{next_run_for_schedule, Schedule}; +use crate::scheduler::{Schedule, next_run_for_schedule}; use crate::storage::{ScheduledJob, Storage}; use crate::tools::traits::{Tool, ToolResult}; @@ -229,10 +229,7 @@ impl Tool for CronListTool { } async fn execute(&self, args: Value) -> anyhow::Result { - let filter = args - .get("status") - .and_then(|v| v.as_str()) - .unwrap_or("all"); + let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all"); let jobs = self.storage.list_scheduled_jobs().await?; let filtered: Vec<&ScheduledJob> = match filter { @@ -397,7 +394,9 @@ impl Tool for CronEnableTool { .map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?; let next = next_run_for_schedule(&job.schedule, now_ms()); - self.storage.set_scheduled_job_enabled(&job_id, true).await?; + self.storage + .set_scheduled_job_enabled(&job_id, true) + .await?; if let Some(n) = next { self.storage.set_scheduled_job_next_run(&job_id, n).await?; } @@ -464,7 +463,9 @@ impl Tool for CronDisableTool { .get_scheduled_job(&job_id) .await .map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?; - self.storage.set_scheduled_job_enabled(&job_id, false).await?; + self.storage + .set_scheduled_job_enabled(&job_id, false) + .await?; Ok(ToolResult { success: true, @@ -580,7 +581,9 @@ impl Tool for CronUpdateTool { if args.get("schedule").is_some() { let job = self.storage.get_scheduled_job(&job_id).await?; if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) { - self.storage.set_scheduled_job_next_run(&job_id, next).await?; + self.storage + .set_scheduled_job_next_run(&job_id, next) + .await?; } } @@ -765,9 +768,7 @@ mod tests { let job = ScheduledJob { id: "job-update-tool".into(), name: "old".into(), - schedule: Schedule::Every { - every_ms: 3600000, - }, + schedule: Schedule::Every { every_ms: 3600000 }, prompt: "old prompt".into(), channel: "feishu".into(), chat_id: "oc_1".into(), diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index 2300cce..7e3f521 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -102,7 +102,10 @@ impl Tool for DelegateTool { _ => Ok(ToolResult { success: false, output: String::new(), - error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)), + error: Some(format!( + "Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", + action + )), }), } } @@ -115,9 +118,11 @@ impl DelegateTool { .ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))? .to_string(); - let allowed_tools: Option> = args["allowed_tools"] - .as_array() - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); + let allowed_tools: Option> = args["allowed_tools"].as_array().map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }); let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize); let timeout_secs = args["timeout_secs"].as_u64(); @@ -141,15 +146,21 @@ impl DelegateTool { return Ok(ToolResult { success: false, output: String::new(), - error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)), - }) + error: Some(format!( + "unknown mode: {}. Supported: inline, background, parallel", + mode_str + )), + }); } }; match mode { ExecutionMode::Inline => { let config = self.parse_config_from_args(args)?; - let result = self.sub_agent_manager.run_inline(config).await + let result = self + .sub_agent_manager + .run_inline(config) + .await .map_err(|e| anyhow::anyhow!("{}", e))?; match result.status { @@ -177,10 +188,14 @@ impl DelegateTool { } ExecutionMode::Background => { let config = self.parse_config_from_args(args)?; - let ctx = crate::agent::sub_agent::get_delegate_context() - .map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?; + let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| { + anyhow::anyhow!("delegate context not available: not in an agent worker") + })?; - let task_id = self.sub_agent_manager.run_background(config, ctx).await + let task_id = self + .sub_agent_manager + .run_background(config, ctx) + .await .map_err(|e| anyhow::anyhow!("{}", e))?; Ok(ToolResult { @@ -200,9 +215,12 @@ impl DelegateTool { .as_str() .ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))? .to_string(); - let allowed_tools: Option> = task["allowed_tools"] - .as_array() - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); + let allowed_tools: Option> = + task["allowed_tools"].as_array().map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }); configs.push(SubAgentConfig { prompt, @@ -216,13 +234,18 @@ impl DelegateTool { let has_args_allowed = args["allowed_tools"].as_array().is_some(); for c in &mut configs { if c.allowed_tools.is_none() && has_args_allowed { - c.allowed_tools = args["allowed_tools"] - .as_array() - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); + c.allowed_tools = args["allowed_tools"].as_array().map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }); } } - let results = self.sub_agent_manager.run_parallel(configs).await + let results = self + .sub_agent_manager + .run_parallel(configs) + .await .map_err(|e| anyhow::anyhow!("{}", e))?; let mut output = String::new(); @@ -243,7 +266,9 @@ impl DelegateTool { } } - let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed)); + let all_success = results + .iter() + .all(|r| matches!(r.status, TaskStatus::Completed)); Ok(ToolResult { success: all_success, output: output.trim().to_string(), diff --git a/src/tools/file_edit.rs b/src/tools/file_edit.rs index fa6b4fc..c4c87c1 100644 --- a/src/tools/file_edit.rs +++ b/src/tools/file_edit.rs @@ -243,8 +243,8 @@ impl Tool for FileEditTool { #[cfg(test)] mod tests { use super::*; - use tempfile::NamedTempFile; use std::io::Write; + use tempfile::NamedTempFile; #[tokio::test] async fn test_edit_simple() { diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index 4daecbd..3115de2 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -181,10 +181,7 @@ impl Tool for FileReadTool { } result = lines[..end_idx].join("\n"); let truncated = original_len - result.len(); - result.push_str(&format!( - "\n\n... ({} chars truncated) ...", - truncated - )); + result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated)); } if end < total { @@ -196,10 +193,7 @@ impl Tool for FileReadTool { end + 1 )); } else { - result.push_str(&format!( - "\n\n(End of file — {} lines total)", - total - )); + result.push_str(&format!("\n\n(End of file — {} lines total)", total)); } if let Some(label) = encoding_label { @@ -214,7 +208,7 @@ impl Tool for FileReadTool { } None => { // Truly binary file — base64 encode - use base64::{engine::general_purpose::STANDARD, Engine}; + use base64::{Engine, engine::general_purpose::STANDARD}; let encoded = STANDARD.encode(&bytes); let mime = mime_guess::from_path(&resolved) .first_or_octet_stream() @@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option, Option<&'static str>) { #[cfg(test)] mod tests { use super::*; - use tempfile::NamedTempFile; use std::io::Write; + use tempfile::NamedTempFile; #[tokio::test] async fn test_read_simple_file() { @@ -338,10 +332,7 @@ mod tests { #[tokio::test] async fn test_is_directory() { let tool = FileReadTool::new(); - let result = tool - .execute(json!({ "path": "." })) - .await - .unwrap(); + let result = tool.execute(json!({ "path": "." })).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("Not a file")); diff --git a/src/tools/file_search.rs b/src/tools/file_search.rs index 0399a6d..8ebff77 100644 --- a/src/tools/file_search.rs +++ b/src/tools/file_search.rs @@ -101,17 +101,29 @@ impl Tool for FileSearchTool { }; let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str())); - let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true); - let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize; + let case_sensitive = args + .get("case_sensitive") + .and_then(|v| v.as_bool()) + .unwrap_or(true); + let max_results = args + .get("max_results") + .and_then(|v| v.as_u64()) + .unwrap_or(MAX_RESULTS as u64) as usize; - let result = self.run_search(pattern, &dir, case_sensitive, max_results).await; + let result = self + .run_search(pattern, &dir, case_sensitive, max_results) + .await; match result { Ok(lines) => { let count = lines.len(); let mut output = self.truncate_output(&lines); output.push_str(&format!("\n\n---\n共 {} 个文件", count)); - Ok(ToolResult { success: true, output, error: None }) + Ok(ToolResult { + success: true, + output, + error: None, + }) } Err(e) => Ok(ToolResult { success: false, @@ -139,9 +151,12 @@ impl FileSearchTool { }; if !fd_cmd.is_empty() { - match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await { + match self + .search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd) + .await + { Ok(lines) if !lines.is_empty() => return Ok(lines), - Ok(_) => {}, + Ok(_) => {} Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e), } } @@ -149,13 +164,14 @@ impl FileSearchTool { if which::which("find").is_ok() { match self.search_with_find(pattern, dir, max_results).await { Ok(lines) if !lines.is_empty() => return Ok(lines), - Ok(_) => {}, + Ok(_) => {} Err(e) => tracing::warn!("find failed: {}, falling back", e), } } tracing::warn!("No fd/find available, using built-in file search (slower)"); - self.search_with_rust(pattern, dir, case_sensitive, max_results).await + self.search_with_rust(pattern, dir, case_sensitive, max_results) + .await } async fn search_with_fd( @@ -167,11 +183,15 @@ impl FileSearchTool { fd_cmd: &str, ) -> anyhow::Result> { let mut cmd = Command::new(fd_cmd); - cmd.arg("--search-path").arg(dir) - .arg("--glob").arg(pattern) - .arg("--color").arg("never") + cmd.arg("--search-path") + .arg(dir) + .arg("--glob") + .arg(pattern) + .arg("--color") + .arg("never") .arg("--strip-cwd-prefix") - .arg("--max-results").arg(max_results.to_string()) + .arg("--max-results") + .arg(max_results.to_string()) .stdout(Stdio::piped()) .stderr(Stdio::piped()); @@ -179,12 +199,9 @@ impl FileSearchTool { cmd.arg("--ignore-case"); } - let output = timeout( - std::time::Duration::from_secs(TIMEOUT_SECS), - cmd.output(), - ) - .await - .map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??; + let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output()) + .await + .map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); @@ -192,7 +209,8 @@ impl FileSearchTool { } let text = String::from_utf8_lossy(&output.stdout); - let lines: Vec = text.lines() + let lines: Vec = text + .lines() .filter(|l| !l.is_empty()) .map(|l| l.to_string()) .collect(); @@ -215,15 +233,13 @@ impl FileSearchTool { .stdout(Stdio::piped()) .stderr(Stdio::null()); - let output = timeout( - std::time::Duration::from_secs(TIMEOUT_SECS), - cmd.output(), - ) - .await - .map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??; + let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output()) + .await + .map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??; let text = String::from_utf8_lossy(&output.stdout); - let mut lines: Vec = text.lines() + let mut lines: Vec = text + .lines() .filter(|l| !l.is_empty()) .map(|l| { let p = Path::new(l); @@ -254,7 +270,13 @@ impl FileSearchTool { .map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?; let mut results = Vec::new(); - walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?; + walk_dir( + Path::new(dir), + Path::new(dir), + &re, + &mut results, + max_results, + )?; Ok(results) } } @@ -311,15 +333,18 @@ fn walk_dir( if path.is_dir() { if let Some(name) = rel.file_name().and_then(|n| n.to_str()) - && name.starts_with('.') && name.len() > 1 { - continue; - } + && name.starts_with('.') + && name.len() > 1 + { + continue; + } walk_dir(base, &path, re, results, max)?; } else if path.is_file() { if let Some(name) = rel.file_name().and_then(|n| n.to_str()) - && re.is_match(name) { - results.push(rel.to_string_lossy().to_string()); - } + && re.is_match(name) + { + results.push(rel.to_string_lossy().to_string()); + } if results.len() >= max { return Ok(()); } diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index 55235c9..7d11195 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -90,13 +90,14 @@ impl Tool for FileWriteTool { // Create parent directories if needed if let Some(parent) = resolved.parent() && !parent.exists() - && let Err(e) = std::fs::create_dir_all(parent) { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Failed to create parent directory: {}", e)), - }); - } + && let Err(e) = std::fs::create_dir_all(parent) + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to create parent directory: {}", e)), + }); + } match std::fs::write(&resolved, content) { Ok(_) => Ok(ToolResult { @@ -168,10 +169,7 @@ mod tests { #[tokio::test] async fn test_write_missing_path() { let tool = FileWriteTool::new(); - let result = tool - .execute(json!({ "content": "Hello" })) - .await - .unwrap(); + let result = tool.execute(json!({ "content": "Hello" })).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("path")); diff --git a/src/tools/get_skill.rs b/src/tools/get_skill.rs index 3250ec0..6076dca 100644 --- a/src/tools/get_skill.rs +++ b/src/tools/get_skill.rs @@ -129,7 +129,9 @@ impl GetSkillTool { let mut output = format!("可用 skill (共 {} 个):\n", skills.len()); for s in &skills { let always_mark = if s.always { " [常驻]" } else { "" }; - let path_str = s.path.as_ref() + let path_str = s + .path + .as_ref() .map(|p| p.to_string_lossy().to_string()) .unwrap_or_else(|| "—".to_string()); output.push_str(&format!( @@ -148,10 +150,10 @@ impl GetSkillTool { #[cfg(test)] mod tests { use super::*; - use tempfile::tempdir; use std::fs::File; use std::io::Write; use std::path::PathBuf; + use tempfile::tempdir; #[tokio::test] async fn test_get_existing_skill() { diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index 37a02ae..96b039b 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -50,10 +50,7 @@ impl HttpRequestTool { } if !host_matches_allowlist(&host, &self.allowed_domains) { - return Err(format!( - "Host '{}' is not in allowed_domains", - host - )); + return Err(format!("Host '{}' is not in allowed_domains", host)); } Ok(url.to_string()) @@ -80,11 +77,10 @@ impl HttpRequestTool { for (key, value) in obj { if let Some(str_val) = value.as_str() && let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) - && let Ok(val) = - reqwest::header::HeaderValue::from_str(str_val) - { - header_map.insert(name, val); - } + && let Ok(val) = reqwest::header::HeaderValue::from_str(str_val) + { + header_map.insert(name, val); + } } } @@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { allowed_domains.iter().any(|domain| { host == domain - || host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.')) + || host + .strip_suffix(domain) + .is_some_and(|prefix| prefix.ends_with('.')) }) } @@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool { } // Check .local TLD - if host.rsplit('.').next().is_some_and(|label| label == "local") { + if host + .rsplit('.') + .next() + .is_some_and(|label| label == "local") + { return true; } @@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool { || v4.is_broadcast() || v4.is_multicast() } - std::net::IpAddr::V6(v6) => { - v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() - } + std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), } } @@ -278,10 +278,7 @@ impl Tool for HttpRequestTool { } }; - let method_str = args - .get("method") - .and_then(|v| v.as_str()) - .unwrap_or("GET"); + let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET"); let headers_val = args.get("headers").cloned().unwrap_or(json!({})); let body = args.get("body").and_then(|v| v.as_str()); diff --git a/src/tools/memory.rs b/src/tools/memory.rs index 0966c77..b4602da 100644 --- a/src/tools/memory.rs +++ b/src/tools/memory.rs @@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool { .and_then(|v| v.as_i64()) .unwrap_or(chrono::Utc::now().timestamp_millis()); self.memory - .recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None) + .recall_by_time( + since, + until, + Some(query), + limit, + Some(MemoryCategory::Knowledge), + None, + ) .await? } else { - self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await? + self.memory + .recall(query, limit, Some(MemoryCategory::Knowledge), None) + .await? }; if entries.is_empty() { @@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool { let formatted = entries .iter() .map(|e| { - let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); + let session = e + .session_id + .as_deref() + .map(|s| format!(" [session: {}]", s)) + .unwrap_or_default(); format!( "- {} [{}]{} [importance: {:.1}]: {}", e.key, @@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool { .and_then(|v| v.as_i64()) .unwrap_or(chrono::Utc::now().timestamp_millis()); self.memory - .recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id) + .recall_by_time( + since, + until, + Some(query), + limit, + Some(MemoryCategory::Timeline), + session_id, + ) .await? } else { - self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await? + self.memory + .recall(query, limit, Some(MemoryCategory::Timeline), session_id) + .await? }; if entries.is_empty() { @@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool { let formatted = entries .iter() .map(|e| { - let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); + let session = e + .session_id + .as_deref() + .map(|s| format!(" [session: {}]", s)) + .unwrap_or_default(); format!( "- {} [{}]{} [importance: {:.1}]: {}", e.key, diff --git a/src/tools/mod.rs b/src/tools/mod.rs index b94fec7..1d3dcf0 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -37,11 +37,11 @@ pub use send_message::SendMessageTool; pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use web_fetch::WebFetchTool; -use std::sync::Arc; use crate::agent::SubAgentManager; use crate::config::BrowserConfig; use crate::memory::MemoryManager; use crate::skills::SkillsLoader; +use std::sync::Arc; /// Create the base tool registry (without send_message). /// `send_message` tool is registered later via `SessionManager::register_outbound_tool()` diff --git a/src/tools/registry.rs b/src/tools/registry.rs index 84df8b7..e93e47d 100644 --- a/src/tools/registry.rs +++ b/src/tools/registry.rs @@ -17,7 +17,10 @@ impl ToolRegistry { } pub fn register(&self, tool: T) { - self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool)); + self.tools + .lock() + .unwrap() + .insert(tool.name().to_string(), Arc::new(tool)); } /// Register an existing Arc-wrapped tool by name diff --git a/src/tools/schema.rs b/src/tools/schema.rs index 137dc46..015d7eb 100644 --- a/src/tools/schema.rs +++ b/src/tools/schema.rs @@ -115,9 +115,11 @@ impl SchemaCleanr { } if let Some(Value::String(t)) = obj.get("type") - && t == "object" && !obj.contains_key("properties") { - tracing::warn!("Object schema without 'properties' field may cause issues"); - } + && t == "object" + && !obj.contains_key("properties") + { + tracing::warn!("Object schema without 'properties' field may cause issues"); + } Ok(()) } @@ -173,9 +175,10 @@ impl SchemaCleanr { // Handle anyOf/oneOf simplification if (obj.contains_key("anyOf") || obj.contains_key("oneOf")) - && let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) { - return simplified; - } + && let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) + { + return simplified; + } // Build cleaned object let mut cleaned = Map::new(); @@ -243,12 +246,13 @@ impl SchemaCleanr { } if let Some(def_name) = Self::parse_local_ref(ref_value) - && let Some(definition) = defs.get(def_name.as_str()) { - ref_stack.insert(ref_value.to_string()); - let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack); - ref_stack.remove(ref_value); - return Self::preserve_meta(obj, cleaned); - } + && let Some(definition) = defs.get(def_name.as_str()) + { + ref_stack.insert(ref_value.to_string()); + let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack); + ref_stack.remove(ref_value); + return Self::preserve_meta(obj, cleaned); + } tracing::warn!("Cannot resolve $ref: {}", ref_value); Self::preserve_meta(obj, Value::Object(Map::new())) @@ -340,13 +344,16 @@ impl SchemaCleanr { return true; } if let Some(Value::Array(arr)) = obj.get("enum") - && arr.len() == 1 && matches!(arr[0], Value::Null) { - return true; - } + && arr.len() == 1 + && matches!(arr[0], Value::Null) + { + return true; + } if let Some(Value::String(t)) = obj.get("type") - && t == "null" { - return true; - } + && t == "null" + { + return true; + } } false } @@ -403,7 +410,10 @@ impl SchemaCleanr { match non_null.len() { 0 => Value::String("null".to_string()), - 1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())), + 1 => non_null + .into_iter() + .next() + .unwrap_or(Value::String("null".to_string())), _ => Value::Array(non_null), } } else { diff --git a/src/tools/send_message.rs b/src/tools/send_message.rs index f16e82b..6cb0de9 100644 --- a/src/tools/send_message.rs +++ b/src/tools/send_message.rs @@ -1,5 +1,5 @@ -use std::sync::Arc; use std::collections::HashSet; +use std::sync::Arc; use async_trait::async_trait; use mime_guess::mime; @@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String> match parts.len() { 2 => { if parts[0].is_empty() || parts[1].is_empty() { - Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw)) + Err(format!( + "Invalid target_chat_id format '{}': channel and chat_id must not be empty", + raw + )) } else { Ok((parts[0], parts[1], None)) } } 3 => { if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() { - Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw)) + Err(format!( + "Invalid target_chat_id format '{}': all three parts must not be empty", + raw + )) } else { Ok((parts[0], parts[1], Some(parts[2]))) } @@ -98,8 +104,8 @@ target_chat_id 支持两种格式::(发送到该聊天下 .ok_or_else(|| anyhow::anyhow!("missing content"))?; // 1. Parse target_chat_id - let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id) - .map_err(|e| anyhow::anyhow!(e))?; + let (channel, chat_id, dialog_id) = + parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?; // 2. Validate channel if !self.available_channels.contains(channel) { @@ -109,7 +115,11 @@ target_chat_id 支持两种格式::(发送到该聊天下 error: Some(format!( "Channel '{}' is not available. Available channels: {}", channel, - self.available_channels.iter().cloned().collect::>().join(", ") + self.available_channels + .iter() + .cloned() + .collect::>() + .join(", ") )), }); } @@ -129,7 +139,8 @@ target_chat_id 支持两种格式::(发送到该聊天下 let media = parse_files_arg(&args); // 4. Send via messenger - match self.messenger + match self + .messenger .send_message(channel, chat_id, dialog_id, content, source, media) .await { diff --git a/src/tools/traits.rs b/src/tools/traits.rs index 584054c..b2213a5 100644 --- a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -1,5 +1,5 @@ -use async_trait::async_trait; use crate::bus::{MediaItem, MessageSource}; +use async_trait::async_trait; #[derive(Debug, Clone)] pub struct ToolResult { diff --git a/src/tools/web_fetch.rs b/src/tools/web_fetch.rs index e4717b5..5f2f650 100644 --- a/src/tools/web_fetch.rs +++ b/src/tools/web_fetch.rs @@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool { return true; } - if host.rsplit('.').next().is_some_and(|label| label == "local") { + if host + .rsplit('.') + .next() + .is_some_and(|label| label == "local") + { return true; } @@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool { std::net::IpAddr::V4(v4) => { v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() } - std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), + std::net::IpAddr::V6(v6) => { + v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() + } }; } diff --git a/tests/test_integration.rs b/tests/test_integration.rs index dc0ecd9..7f81dca 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; -use picobot::providers::{create_provider, ChatCompletionRequest, Message}; use picobot::config::{Config, LLMProviderConfig}; +use picobot::providers::{ChatCompletionRequest, Message, create_provider}; +use std::collections::HashMap; fn load_config() -> Option { dotenv::from_filename("tests/test.env").ok()?; @@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest { #[tokio::test] #[ignore] async fn test_openai_simple_completion() { - let config = load_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); let response = provider.chat(create_request("Say 'ok'")).await.unwrap(); @@ -57,8 +56,7 @@ async fn test_openai_simple_completion() { #[tokio::test] #[ignore] async fn test_openai_conversation() { - let config = load_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); @@ -82,7 +80,9 @@ async fn test_openai_conversation() { async fn test_config_load() { // Test that config.json can be loaded and provider config created let config = Config::load("config.json").expect("Failed to load config.json"); - let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); + let provider_config = config + .get_provider_config("default") + .expect("Failed to get provider config"); assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.name, "aliyun"); diff --git a/tests/test_scheduler.rs b/tests/test_scheduler.rs index 44fefe9..21c4c6e 100644 --- a/tests/test_scheduler.rs +++ b/tests/test_scheduler.rs @@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() { /// Verify that next_run_for_schedule produces valid future timestamps. #[test] fn test_next_run_always_future() { - use picobot::scheduler::{next_run_for_schedule, Schedule}; + use picobot::scheduler::{Schedule, next_run_for_schedule}; let now = 1700000000000_i64; @@ -56,6 +56,10 @@ fn test_next_run_always_future() { for s in &schedules { let next = next_run_for_schedule(s, now); assert!(next.is_some(), "expected next run for {:?}", s); - assert!(next.unwrap() > now, "next run should be after now for {:?}", s); + assert!( + next.unwrap() > now, + "next run should be after now for {:?}", + s + ); } } diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 86e565d..bdf1086 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; -use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction}; use picobot::config::LLMProviderConfig; +use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider}; +use std::collections::HashMap; fn load_openai_config() -> Option { dotenv::from_filename("tests/test.env").ok()?; @@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool { #[tokio::test] #[ignore] async fn test_openai_tool_call() { - let config = load_openai_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); @@ -68,7 +67,11 @@ async fn test_openai_tool_call() { let response = provider.chat(request).await.unwrap(); // Should have tool calls - assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content); + assert!( + !response.tool_calls.is_empty(), + "Expected tool call, got: {}", + response.content + ); let tool_call = &response.tool_calls[0]; assert_eq!(tool_call.name, "get_weather"); @@ -78,8 +81,7 @@ async fn test_openai_tool_call() { #[tokio::test] #[ignore] async fn test_openai_tool_call_with_manual_execution() { - let config = load_openai_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); @@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() { }; let response1 = provider.chat(request1).await.unwrap(); - let tool_call = response1.tool_calls.first() - .expect("Expected tool call"); + let tool_call = response1.tool_calls.first().expect("Expected tool call"); assert_eq!(tool_call.name, "get_weather"); // Second request with tool result @@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() { #[tokio::test] #[ignore] async fn test_openai_no_tool_when_not_provided() { - let config = load_openai_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider");