From 73dab09bfeeb88a27383bd525cc008e7ce7ba408 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Tue, 28 Apr 2026 10:33:31 +0800 Subject: [PATCH] Refactor code for improved readability and consistency - Adjusted formatting and indentation in various files for better clarity. - Consolidated multi-line statements into single lines where appropriate. - Enhanced error handling messages for better debugging. - Added a new InboundProcessor struct to handle inbound messages more effectively. - Updated test cases to ensure they align with the new code structure. --- src/agent/agent_loop.rs | 178 ++++++----- src/agent/context_compressor.rs | 106 +++++-- src/agent/mod.rs | 2 +- src/bus/dispatcher.rs | 7 +- src/bus/message.rs | 145 +++++---- src/bus/mod.rs | 14 +- src/channels/feishu.rs | 539 ++++++++++++++++++++++---------- src/channels/manager.rs | 29 +- src/channels/mod.rs | 2 +- src/cli/channel.rs | 2 +- src/cli/input.rs | 45 ++- src/client/mod.rs | 11 +- src/config/mod.rs | 168 ++++++---- src/gateway/mod.rs | 83 +---- src/gateway/processor.rs | 77 +++++ src/gateway/session.rs | 460 +++++++++++++++++---------- src/gateway/ws.rs | 124 +++++--- src/lib.rs | 20 +- src/logging.rs | 39 +-- src/main.rs | 2 +- src/observability/mod.rs | 11 +- src/protocol.rs | 10 +- src/providers/anthropic.rs | 34 +- src/providers/mod.rs | 11 +- src/providers/openai.rs | 68 ++-- src/providers/traits.rs | 8 +- src/scheduler/mod.rs | 229 ++++++++++---- src/skills/mod.rs | 106 +++++-- src/storage/mod.rs | 338 ++++++++++++-------- src/text.rs | 2 +- src/tools/bash.rs | 18 +- src/tools/calculator.rs | 5 +- src/tools/file_edit.rs | 2 +- src/tools/file_read.rs | 9 +- src/tools/file_write.rs | 5 +- src/tools/http_request.rs | 28 +- src/tools/memory_manage.rs | 13 +- src/tools/memory_search.rs | 13 +- src/tools/scheduler_manage.rs | 70 +++-- src/tools/schema.rs | 5 +- src/tools/skill_manage.rs | 20 +- src/tools/time.rs | 37 ++- src/tools/web_fetch.rs | 10 +- tests/test_integration.rs | 14 +- tests/test_request_format.rs | 25 +- tests/test_tool_calling.rs | 22 +- 46 files changed, 2068 insertions(+), 1098 deletions(-) create mode 100644 src/gateway/processor.rs diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 4a64e90..3b31ac1 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,16 +1,16 @@ -use async_trait::async_trait; -use crate::bus::message::ContentBlock; use crate::bus::ChatMessage; +use crate::bus::message::ContentBlock; use crate::bus::message::ToolMessageState; use crate::config::LLMProviderConfig; use crate::observability::{ - truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, + Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args, }; -use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider}; use crate::skills::SkillRuntime; use crate::storage::SessionStore; -use crate::tools::{ToolContext, ToolRegistry}; use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; +use crate::tools::{ToolContext, ToolRegistry}; +use async_trait::async_trait; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; @@ -19,18 +19,13 @@ use std::time::Instant; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; -const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = - include_str!("memory_tool_usage_system_prompt.md"); +const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = include_str!("memory_tool_usage_system_prompt.md"); const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; -const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; +const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = + "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。"; -const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[ - "image/jpeg", - "image/png", - "image/gif", - "image/webp", -]; +const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"]; /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { @@ -115,14 +110,15 @@ fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String { let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len)); format!( "...\n\n[Output truncated - {} characters removed]\n\n{}", - truncated_start_len, - tail + truncated_start_len, tail ) } } fn parse_pending_tool_output(output: &str) -> Option { - output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string()) + output + .strip_prefix(PENDING_USER_ACTION_MARKER) + .map(|rest| rest.trim().to_string()) } fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value { @@ -341,7 +337,10 @@ impl AgentLoop { }) } - pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { + pub fn with_tools( + provider_config: LLMProviderConfig, + tools: Arc, + ) -> Result { let max_iterations = provider_config.max_tool_iterations; let provider = create_provider(provider_config.clone()) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; @@ -416,9 +415,16 @@ impl AgentLoop { /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached - pub async fn process(&self, mut messages: Vec) -> Result { + pub async fn process( + &self, + mut messages: Vec, + ) -> Result { #[cfg(debug_assertions)] - tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); + tracing::debug!( + history_len = messages.len(), + max_iterations = self.max_iterations, + "Starting agent process" + ); // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); @@ -441,7 +447,11 @@ impl AgentLoop { if let Some(skill_tool) = self.skills.skill_tool_definition() { tool_defs.push(skill_tool); } - let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) }; + let tools = if tool_defs.is_empty() { + None + } else { + Some(tool_defs) + }; let request = ChatCompletionRequest { messages: messages_for_llm, @@ -461,7 +471,8 @@ impl AgentLoop { error_details = %format_error_chain(e.as_ref()), "LLM request failed" ); - let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); + let assistant_message = + ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, @@ -480,7 +491,8 @@ impl AgentLoop { // If no tool calls, this is the final response if response.tool_calls.is_empty() { - let assistant_message = if let Some(reasoning_content) = response.reasoning_content { + let assistant_message = if let Some(reasoning_content) = response.reasoning_content + { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) @@ -493,24 +505,35 @@ impl AgentLoop { } // Execute tool calls - tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); + tracing::info!( + iteration, + count = response.tool_calls.len(), + "Tool calls detected, executing tools" + ); // Add assistant message with tool calls - let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() { - ChatMessage::assistant_with_tool_calls_and_reasoning( - response.content.clone(), - response.tool_calls.clone(), - reasoning_content, - ) - } else { - ChatMessage::assistant_with_tool_calls( - response.content.clone(), - response.tool_calls.clone(), - ) - }; + let assistant_message = + if let Some(reasoning_content) = response.reasoning_content.clone() { + ChatMessage::assistant_with_tool_calls_and_reasoning( + response.content.clone(), + response.tool_calls.clone(), + reasoning_content, + ) + } else { + ChatMessage::assistant_with_tool_calls( + response.content.clone(), + response.tool_calls.clone(), + ) + }; messages.push(assistant_message.clone()); emitted_messages.push(assistant_message); - self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await; + self.emit_live_tool_call_message( + emitted_messages + .last() + .expect("assistant message just pushed") + .clone(), + ) + .await; // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; @@ -519,7 +542,9 @@ impl AgentLoop { // Log function call with name and arguments let args_str = match &tool_call.arguments { serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), - other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), + other => { + serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()) + } }; tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); @@ -595,7 +620,11 @@ impl AgentLoop { // Loop continues to next iteration with updated messages #[cfg(debug_assertions)] - tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); + tracing::debug!( + iteration, + message_count = messages.len(), + "Tool execution complete, continuing to next iteration" + ); } // Max iterations reached - ask LLM for a summary based on completed work @@ -604,7 +633,7 @@ impl AgentLoop { // Add a message asking for summary let summary_request = ChatMessage::user( "You have reached the maximum number of tool call iterations. \ - Please provide your best answer based on the work completed so far." + Please provide your best answer based on the work completed so far.", ); messages.push(summary_request); @@ -624,7 +653,8 @@ impl AgentLoop { match (*self.provider).chat(request).await { Ok(response) => { - let assistant_message = if let Some(reasoning_content) = response.reasoning_content { + let assistant_message = if let Some(reasoning_content) = response.reasoning_content + { ChatMessage::assistant_with_reasoning(response.content, reasoning_content) } else { ChatMessage::assistant(response.content) @@ -745,10 +775,7 @@ impl AgentLoop { } // Apply duration - ToolExecutionOutcome { - duration, - ..result - } + ToolExecutionOutcome { duration, ..result } } /// Internal tool execution without event tracking. @@ -790,10 +817,7 @@ impl AgentLoop { "arguments": normalized_arguments, }), ); - ToolExecutionOutcome::failure( - format!("Error: {}", err), - Some(err), - ) + ToolExecutionOutcome::failure(format!("Error: {}", err), Some(err)) } }; } @@ -809,7 +833,10 @@ impl AgentLoop { } }; - match tool.execute_with_context(&self.tool_context, normalized_arguments.clone()).await { + match tool + .execute_with_context(&self.tool_context, normalized_arguments.clone()) + .await + { Ok(result) => { if result.success { if let Some(pending_output) = parse_pending_tool_output(&result.output) { @@ -827,10 +854,7 @@ impl AgentLoop { output = %result.output, "Tool returned an error result" ); - ToolExecutionOutcome::failure( - format!("Error: {}", error), - Some(error), - ) + ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error)) } } Err(e) => { @@ -842,10 +866,7 @@ impl AgentLoop { error_details = %format!("{:#}", e), "Tool execution failed" ); - ToolExecutionOutcome::failure( - format!("Error: {}", e), - Some(e.to_string()), - ) + ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string())) } } } @@ -863,7 +884,9 @@ impl AgentLoop { return; }; - if let Err(err) = store.append_skill_event(Some(session_id), event_type, skill_name, &payload) { + if let Err(err) = + store.append_skill_event(Some(session_id), event_type, skill_name, &payload) + { tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event"); } } @@ -942,28 +965,37 @@ mod tests { assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); - assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); - assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); + assert_eq!( + provider_message.tool_calls.as_ref().unwrap()[0].id, + "call_1" + ); + assert_eq!( + provider_message.tool_calls.as_ref().unwrap()[0].name, + "calculator" + ); } #[test] fn test_chat_message_to_llm_message_preserves_reasoning_content() { - let chat_message = ChatMessage::assistant_with_reasoning( - "final answer", - "hidden chain of thought", - ); + let chat_message = + ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought"); let provider_message = chat_message_to_llm_message(&chat_message); assert_eq!(provider_message.role, "assistant"); - assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought")); + assert_eq!( + provider_message.reasoning_content.as_deref(), + Some("hidden chain of thought") + ); } #[test] fn test_memory_prompt_requires_proactive_memory_search() { assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时")); assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search")); - assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索")); + assert!( + MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索") + ); } #[test] @@ -1001,9 +1033,13 @@ mod tests { #[test] fn test_normalize_tool_arguments_keeps_plain_string() { - let normalized = normalize_tool_arguments(&serde_json::Value::String("plain text".to_string())); + let normalized = + normalize_tool_arguments(&serde_json::Value::String("plain text".to_string())); - assert_eq!(normalized, serde_json::Value::String("plain text".to_string())); + assert_eq!( + normalized, + serde_json::Value::String("plain text".to_string()) + ); } #[test] @@ -1028,7 +1064,9 @@ mod tests { assert_eq!(blocks.len(), 2); assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello")); - assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))); + assert!( + matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")) + ); } } diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index a4e79cb..a495ed8 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -1,11 +1,9 @@ use crate::bus::{ - ChatMessage, - SYSTEM_CONTEXT_AGENT_PROMPT, - SYSTEM_CONTEXT_HISTORY_COMPACTION, + ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION, SYSTEM_CONTEXT_SCHEDULED_PROMPT, }; use crate::config::LLMProviderConfig; -use crate::providers::{create_provider, ChatCompletionRequest, LLMProvider, Message}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider}; use crate::text::{char_count, take_prefix_chars}; use crate::agent::AgentError; @@ -221,7 +219,9 @@ Be concise, aim for {} characters or less. .await; } - let per_chunk_target = (target_chars / layer.len().max(1)).max(500).min(target_chars); + let per_chunk_target = (target_chars / layer.len().max(1)) + .max(500) + .min(target_chars); let mut summaries = Vec::with_capacity(layer.len()); for chunk in &layer { summaries.push( @@ -241,7 +241,9 @@ Be concise, aim for {} characters or less. let merged = summaries.join("\n\n"); if char_count(&merged) <= target_chars { - return self.summarize_transcript(provider, &merged, target_chars).await; + return self + .summarize_transcript(provider, &merged, target_chars) + .await; } layer = Self::split_text_chunks(&merged, target_chars); @@ -314,7 +316,10 @@ Be concise, aim for {} characters or less. || message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT)) } - fn split_prefix_messages(&self, history: &[ChatMessage]) -> (Vec, Vec) { + fn split_prefix_messages( + &self, + history: &[ChatMessage], + ) -> (Vec, Vec) { let preserved_system_messages = history .iter() .filter(|message| self.should_preserve_system_message(message)) @@ -343,7 +348,8 @@ Be concise, aim for {} characters or less. return Ok(None); } - let preserved_turn_start = turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start; + let preserved_turn_start = + turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start; if preserved_turn_start == 0 { return Ok(None); } @@ -357,10 +363,10 @@ Be concise, aim for {} characters or less. Ok(Some(HistoryCompactionPlan { preserved_system_messages, - summary_message: ChatMessage::system_with_context(format!( - "[Compressed History]\n\n{}", - summary - ), Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string())), + summary_message: ChatMessage::system_with_context( + format!("[Compressed History]\n\n{}", summary), + Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string()), + ), preserved_messages: history[preserved_turn_start..].to_vec(), compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns, preserved_turns: self.config.retain_last_user_turns, @@ -392,7 +398,10 @@ Be concise, aim for {} characters or less. "Starting context compression" ); - let current_history = match self.build_compaction_plan(&history, provider_config).await? { + let current_history = match self + .build_compaction_plan(&history, provider_config) + .await? + { Some(plan) => { let mut compressed = Vec::with_capacity( plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1, @@ -429,8 +438,12 @@ Be concise, aim for {} characters or less. let transcript = Self::build_transcript(messages); let result = if char_count(&transcript) <= self.config.summary_max_chars { - self.summarize_transcript(provider.as_ref(), &transcript, self.config.summary_max_chars) - .await + self.summarize_transcript( + provider.as_ref(), + &transcript, + self.config.summary_max_chars, + ) + .await } else { self.summarize_chunked_transcript(provider.as_ref(), messages, &transcript) .await @@ -440,7 +453,10 @@ Be concise, aim for {} characters or less. Ok(summary) => Ok(summary), Err(e) => { tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript"); - Ok(take_prefix_chars(&transcript, self.config.summary_max_chars)) + Ok(take_prefix_chars( + &transcript, + self.config.summary_max_chars, + )) } } } @@ -463,7 +479,11 @@ mod tests { // "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 // "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7 // raw = 19, with 1.2x = ~23 - assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens); + assert!( + tokens > 18 && tokens < 30, + "Expected ~23 tokens, got {}", + tokens + ); } #[test] @@ -487,21 +507,39 @@ mod tests { ]; let turns = compressor.user_turn_ranges(&history); - assert_eq!(turns, vec![ - UserTurnRange { start: 1, end_exclusive: 4 }, - UserTurnRange { start: 4, end_exclusive: 6 }, - UserTurnRange { start: 6, end_exclusive: 7 }, - ]); + assert_eq!( + turns, + vec![ + UserTurnRange { + start: 1, + end_exclusive: 4 + }, + UserTurnRange { + start: 4, + end_exclusive: 6 + }, + UserTurnRange { + start: 6, + end_exclusive: 7 + }, + ] + ); } #[test] fn test_split_prefix_messages_preserves_key_system_messages() { let compressor = ContextCompressor::new(50); let prefix = vec![ - ChatMessage::system_with_context("agent prompt", Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string())), + ChatMessage::system_with_context( + "agent prompt", + Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()), + ), ChatMessage::user("u1"), ChatMessage::assistant("a1"), - ChatMessage::system_with_context("scheduled prompt", Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string())), + ChatMessage::system_with_context( + "scheduled prompt", + Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()), + ), ]; let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix); @@ -519,10 +557,22 @@ mod tests { #[test] fn test_summary_char_budget_for_context_window_scales_and_clamps() { - assert_eq!(ContextCompressor::summary_char_budget_for_context_window(4_096), 1_500); - assert_eq!(ContextCompressor::summary_char_budget_for_context_window(65_536), 16_384); - assert_eq!(ContextCompressor::summary_char_budget_for_context_window(128_000), 32_000); - assert_eq!(ContextCompressor::summary_char_budget_for_context_window(400_000), 50_000); + assert_eq!( + ContextCompressor::summary_char_budget_for_context_window(4_096), + 1_500 + ); + assert_eq!( + ContextCompressor::summary_char_budget_for_context_window(65_536), + 16_384 + ); + assert_eq!( + ContextCompressor::summary_char_budget_for_context_window(128_000), + 32_000 + ); + assert_eq!( + ContextCompressor::summary_char_budget_for_context_window(400_000), + 50_000 + ); } #[test] diff --git a/src/agent/mod.rs b/src/agent/mod.rs index cbd5daf..1e91055 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,5 @@ pub mod agent_loop; pub mod context_compressor; -pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler}; +pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler}; pub use context_compressor::ContextCompressor; diff --git a/src/bus/dispatcher.rs b/src/bus/dispatcher.rs index 2555c00..9678b9f 100644 --- a/src/bus/dispatcher.rs +++ b/src/bus/dispatcher.rs @@ -1,6 +1,6 @@ +use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use std::collections::HashMap; use crate::bus::{MessageBus, OutboundMessage}; use crate::channels::base::{Channel, ChannelError}; @@ -22,7 +22,10 @@ impl OutboundDispatcher { /// Register a channel with the dispatcher pub async fn register_channel(&self, name: &str, channel: Arc) { - self.channels.write().await.insert(name.to_string(), channel); + self.channels + .write() + .await + .insert(name.to_string(), channel); } /// Run the dispatcher loop - consumes from bus and dispatches to channels diff --git a/src/bus/message.rs b/src/bus/message.rs index 10db019..75c03d8 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -1,5 +1,5 @@ -use std::collections::HashMap; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use crate::providers::ToolCall; @@ -34,7 +34,9 @@ pub struct ImageUrlBlock { impl ContentBlock { pub fn text(content: impl Into) -> Self { - Self::Text { text: content.into() } + Self::Text { + text: content.into(), + } } pub fn image_url(url: impl Into) -> Self { @@ -50,10 +52,10 @@ impl ContentBlock { #[derive(Debug, Clone)] pub struct MediaItem { - pub path: String, // Local file path - pub media_type: String, // "image", "audio", "file", "video" + pub path: String, // Local file path + pub media_type: String, // "image", "audio", "file", "video" pub mime_type: Option, - pub original_key: Option, // Feishu file_key for download + pub original_key: Option, // Feishu file_key for download } impl MediaItem { @@ -76,7 +78,7 @@ pub struct ChatMessage { pub id: String, pub role: String, pub content: String, - pub media_refs: Vec, // Paths to media files for context + pub media_refs: Vec, // Paths to media files for context pub timestamp: i64, #[serde(skip_serializing_if = "Option::is_none")] pub system_context: Option, @@ -150,7 +152,10 @@ impl ChatMessage { message } - pub fn assistant_with_tool_calls(content: impl Into, tool_calls: Vec) -> Self { + pub fn assistant_with_tool_calls( + content: impl Into, + tool_calls: Vec, + ) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), role: "assistant".to_string(), @@ -199,8 +204,17 @@ impl ChatMessage { } } - pub fn tool(tool_call_id: impl Into, tool_name: impl Into, content: impl Into) -> Self { - Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed) + pub fn tool( + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + ) -> Self { + Self::tool_with_state( + tool_call_id, + tool_name, + content, + ToolMessageState::Completed, + ) } pub fn tool_with_state( @@ -287,6 +301,8 @@ pub enum OutboundEventKind { ToolCall, ToolResult, ToolPending, + SchedulerNotification, + ErrorNotification, } impl OutboundMessage { @@ -316,6 +332,30 @@ impl OutboundMessage { } } + pub fn scheduler_notification( + channel: impl Into, + chat_id: impl Into, + content: impl Into, + reply_to: Option, + metadata: HashMap, + ) -> Self { + let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata); + message.event_kind = OutboundEventKind::SchedulerNotification; + message + } + + pub fn error_notification( + channel: impl Into, + chat_id: impl Into, + content: impl Into, + reply_to: Option, + metadata: HashMap, + ) -> Self { + let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata); + message.event_kind = OutboundEventKind::ErrorNotification; + message + } + pub fn tool_call( channel: impl Into, chat_id: impl Into, @@ -417,20 +457,17 @@ impl OutboundMessage { )); } - outbound.extend(tool_calls - .iter() - .map(|tool_call| { - Self::tool_call( - channel.to_string(), - chat_id.to_string(), - tool_call.id.clone(), - tool_call.name.clone(), - tool_call.arguments.clone(), - reply_to.clone(), - metadata.clone(), - ) - }) - ); + outbound.extend(tool_calls.iter().map(|tool_call| { + Self::tool_call( + channel.to_string(), + chat_id.to_string(), + tool_call.id.clone(), + tool_call.name.clone(), + tool_call.arguments.clone(), + reply_to.clone(), + metadata.clone(), + ) + })); outbound } else { vec![Self::assistant( @@ -442,7 +479,11 @@ impl OutboundMessage { )] } } - "tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) { + "tool" => match message + .tool_state + .as_ref() + .unwrap_or(&ToolMessageState::Completed) + { ToolMessageState::Completed => vec![Self::tool_result( channel.to_string(), chat_id.to_string(), @@ -467,7 +508,10 @@ impl OutboundMessage { } } -pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String { +pub(crate) fn format_tool_call_content( + tool_name: &str, + tool_arguments: &serde_json::Value, +) -> String { match tool_arguments { serde_json::Value::Object(map) if map.is_empty() => tool_name.to_string(), other => format!("{}\nargs: {}", tool_name, format_tool_arguments_json(other)), @@ -544,21 +588,25 @@ mod tests { ], ); - let outbound = OutboundMessage::from_chat_message( - "feishu", - "chat-1", - None, - &HashMap::new(), - &message, - ); + let outbound = + OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); assert_eq!(outbound.len(), 2); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall); assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator")); - assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1"); - assert_eq!(outbound[0].content, "calculator\nargs: {\"expression\":\"1 + 1\"}"); + assert_eq!( + outbound[0].tool_arguments.as_ref().unwrap()["expression"], + "1 + 1" + ); + assert_eq!( + outbound[0].content, + "calculator\nargs: {\"expression\":\"1 + 1\"}" + ); assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read")); - assert_eq!(outbound[1].content, "file_read\nargs: {\"path\":\"README.md\"}"); + assert_eq!( + outbound[1].content, + "file_read\nargs: {\"path\":\"README.md\"}" + ); } #[test] @@ -572,13 +620,8 @@ mod tests { }], ); - let outbound = OutboundMessage::from_chat_message( - "feishu", - "chat-1", - None, - &HashMap::new(), - &message, - ); + let outbound = + OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); assert_eq!(outbound.len(), 2); assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse); @@ -591,13 +634,8 @@ mod tests { fn test_from_chat_message_includes_tool_result() { let message = ChatMessage::tool("call-9", "calculator", "2"); - let outbound = OutboundMessage::from_chat_message( - "feishu", - "chat-1", - None, - &HashMap::new(), - &message, - ); + let outbound = + OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); assert_eq!(outbound.len(), 1); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult); @@ -612,13 +650,8 @@ mod tests { ToolMessageState::PendingUserAction, ); - let outbound = OutboundMessage::from_chat_message( - "feishu", - "chat-1", - None, - &HashMap::new(), - &message, - ); + let outbound = + OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message); assert_eq!(outbound.len(), 1); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending); diff --git a/src/bus/mod.rs b/src/bus/mod.rs index b10c5af..77e642c 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -3,18 +3,13 @@ pub mod message; pub use dispatcher::OutboundDispatcher; pub use message::{ - ChatMessage, - ContentBlock, - InboundMessage, - MediaItem, - OutboundMessage, - SYSTEM_CONTEXT_AGENT_PROMPT, - SYSTEM_CONTEXT_HISTORY_COMPACTION, + ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage, + SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION, SYSTEM_CONTEXT_SCHEDULED_PROMPT, }; use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{Mutex, mpsc}; // ============================================================================ // MessageBus - Async message queue for Channel <-> Agent communication @@ -52,7 +47,8 @@ impl MessageBus { /// Consume an inbound message (Agent -> Bus) pub async fn consume_inbound(&self) -> InboundMessage { - let msg = self.inbound_rx + let msg = self + .inbound_rx .lock() .await .recv() diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index ff2e46b..9357714 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -8,9 +8,9 @@ use futures_util::{SinkExt, StreamExt}; use prost::{Message as ProstMessage, bytes::Bytes}; use regex::Regex; use serde::Deserialize; -use tokio::sync::{broadcast, RwLock}; +use tokio::sync::{RwLock, broadcast}; -use crate::bus::{MessageBus, MediaItem, OutboundMessage}; +use crate::bus::{MediaItem, MessageBus, OutboundMessage}; use crate::channels::base::{Channel, ChannelError}; use crate::config::{FeishuChannelConfig, LLMProviderConfig}; use crate::text::{char_count, truncate_with_ellipsis}; @@ -189,7 +189,10 @@ impl FeishuChannel { } /// Get WebSocket endpoint URL from Feishu API - async fn get_ws_endpoint(&self, client: &reqwest::Client) -> Result<(String, WsClientConfig), ChannelError> { + async fn get_ws_endpoint( + &self, + client: &reqwest::Client, + ) -> Result<(String, WsClientConfig), ChannelError> { let resp = client .post(format!("{}/callback/ws/endpoint", FEISHU_WS_BASE)) .header("locale", "zh") @@ -201,10 +204,9 @@ impl FeishuChannel { .await .map_err(|e| ChannelError::ConnectionError(format!("HTTP error: {}", e)))?; - let endpoint_resp: WsEndpointResp = resp - .json() - .await - .map_err(|e| ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)))?; + let endpoint_resp: WsEndpointResp = resp.json().await.map_err(|e| { + ChannelError::ConnectionError(format!("Failed to parse endpoint response: {}", e)) + })?; if endpoint_resp.code != 0 { return Err(ChannelError::ConnectionError(format!( @@ -214,7 +216,8 @@ impl FeishuChannel { ))); } - let ep = endpoint_resp.data + let ep = endpoint_resp + .data .ok_or_else(|| ChannelError::ConnectionError("Empty endpoint data".to_string()))?; let client_config = ep.client_config.unwrap_or_default(); @@ -251,8 +254,12 @@ impl FeishuChannel { /// Fetch a new tenant access token from Feishu. async fn fetch_new_token(&self) -> Result<(String, Duration), ChannelError> { - let resp = self.http_client - .post(format!("{}/auth/v3/tenant_access_token/internal", FEISHU_API_BASE)) + let resp = self + .http_client + .post(format!( + "{}/auth/v3/tenant_access_token/internal", + FEISHU_API_BASE + )) .header("Content-Type", "application/json") .json(&serde_json::json!({ "app_id": self.config.app_id, @@ -278,10 +285,12 @@ impl FeishuChannel { return Err(ChannelError::Other("Auth failed".to_string())); } - let token = token_resp.tenant_access_token + let token = token_resp + .tenant_access_token .ok_or_else(|| ChannelError::Other("No token in response".to_string()))?; - let ttl = token_resp.expire + let ttl = token_resp + .expire .and_then(|v| u64::try_from(v).ok()) .map(Duration::from_secs) .unwrap_or(DEFAULT_TOKEN_TTL); @@ -312,12 +321,19 @@ impl FeishuChannel { message_id: &str, ) -> Result<(String, Option), ChannelError> { let media_dir = Path::new(&self.config.media_dir); - tokio::fs::create_dir_all(media_dir).await + tokio::fs::create_dir_all(media_dir) + .await .map_err(|e| ChannelError::Other(format!("Failed to create media dir: {}", e)))?; match msg_type { - "image" => self.download_image(content_json, message_id, media_dir).await, - "audio" | "file" | "media" => self.download_file(content_json, message_id, media_dir, msg_type).await, + "image" => { + self.download_image(content_json, message_id, media_dir) + .await + } + "audio" | "file" | "media" => { + self.download_file(content_json, message_id, media_dir, msg_type) + .await + } _ => Ok((format!("[unsupported media type: {}]", msg_type), None)), } } @@ -329,24 +345,31 @@ impl FeishuChannel { message_id: &str, media_dir: &Path, ) -> Result<(String, Option), ChannelError> { - let image_key = content_json.get("image_key") + let image_key = content_json + .get("image_key") .and_then(|v| v.as_str()) .ok_or_else(|| ChannelError::Other("No image_key in message".to_string()))?; let token = self.get_tenant_access_token().await?; // Use message resource API for downloading message images - let url = format!("{}/im/v1/messages/{}/resources/{}?type=image", FEISHU_API_BASE, message_id, image_key); + let url = format!( + "{}/im/v1/messages/{}/resources/{}?type=image", + FEISHU_API_BASE, message_id, image_key + ); #[cfg(debug_assertions)] tracing::debug!(url = %url, image_key = %image_key, message_id = %message_id, "Downloading image from Feishu via message resource API"); - let resp = self.http_client + let resp = self + .http_client .get(&url) .header("Authorization", format!("Bearer {}", token)) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Download image HTTP error: {}", e)) + })?; let status = resp.status(); #[cfg(debug_assertions)] @@ -354,26 +377,33 @@ impl FeishuChannel { if !status.is_success() { let error_text = resp.text().await.unwrap_or_default(); - return Err(ChannelError::Other(format!("Image download failed {}: {}", status, error_text))); + return Err(ChannelError::Other(format!( + "Image download failed {}: {}", + status, error_text + ))); } - let data = resp.bytes().await + let data = resp + .bytes() + .await .map_err(|e| ChannelError::Other(format!("Failed to read image data: {}", e)))? .to_vec(); #[cfg(debug_assertions)] tracing::debug!(data_len = %data.len(), "Downloaded image data"); - let filename = format!("{}_{}.jpg", message_id, &image_key[..8.min(image_key.len())]); + let filename = format!( + "{}_{}.jpg", + message_id, + &image_key[..8.min(image_key.len())] + ); let file_path = media_dir.join(&filename); - tokio::fs::write(&file_path, &data).await + tokio::fs::write(&file_path, &data) + .await .map_err(|e| ChannelError::Other(format!("Failed to write image: {}", e)))?; - let media_item = MediaItem::new( - file_path.to_string_lossy().to_string(), - "image", - ); + let media_item = MediaItem::new(file_path.to_string_lossy().to_string(), "image"); tracing::info!(message_id = %message_id, filename = %filename, "Downloaded image"); @@ -388,34 +418,46 @@ impl FeishuChannel { media_dir: &Path, file_type: &str, ) -> Result<(String, Option), ChannelError> { - let file_key = content_json.get("file_key") + let file_key = content_json + .get("file_key") .and_then(|v| v.as_str()) .ok_or_else(|| ChannelError::Other("No file_key in message".to_string()))?; let token = self.get_tenant_access_token().await?; // Use message resource API for downloading message files - let url = format!("{}/im/v1/messages/{}/resources/{}?type=file", FEISHU_API_BASE, message_id, file_key); + let url = format!( + "{}/im/v1/messages/{}/resources/{}?type=file", + FEISHU_API_BASE, message_id, file_key + ); #[cfg(debug_assertions)] tracing::debug!(url = %url, file_key = %file_key, message_id = %message_id, "Downloading file from Feishu via message resource API"); - let resp = self.http_client + let resp = self + .http_client .get(&url) .header("Authorization", format!("Bearer {}", token)) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Download file HTTP error: {}", e)) + })?; let status = resp.status(); if !status.is_success() { let error_text = resp.text().await.unwrap_or_default(); - return Err(ChannelError::Other(format!("File download failed {}: {}", status, error_text))); + return Err(ChannelError::Other(format!( + "File download failed {}: {}", + status, error_text + ))); } let response_headers = resp.headers().clone(); - let data = resp.bytes().await + let data = resp + .bytes() + .await .map_err(|e| ChannelError::Other(format!("Failed to read file data: {}", e)))? .to_vec(); @@ -428,13 +470,11 @@ impl FeishuChannel { ); let file_path = media_dir.join(&filename); - tokio::fs::write(&file_path, &data).await + tokio::fs::write(&file_path, &data) + .await .map_err(|e| ChannelError::Other(format!("Failed to write file: {}", e)))?; - let media_item = MediaItem::new( - file_path.to_string_lossy().to_string(), - file_type, - ); + let media_item = MediaItem::new(file_path.to_string_lossy().to_string(), file_type); tracing::info!(message_id = %message_id, filename = %filename, file_type = %file_type, "Downloaded file"); @@ -447,7 +487,12 @@ impl FeishuChannel { "video" => "mp4", _ => "bin", }; - format!("{}_{}.{}", message_id, &file_key[..8.min(file_key.len())], extension) + format!( + "{}_{}.{}", + message_id, + &file_key[..8.min(file_key.len())], + extension + ) } /// Upload image to Feishu and return the image_key @@ -463,7 +508,8 @@ impl FeishuChannel { .and_then(|n| n.to_str()) .unwrap_or("image.jpg"); - let file_data = tokio::fs::read(file_path).await + let file_data = tokio::fs::read(file_path) + .await .map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?; let part = reqwest::multipart::Part::bytes(file_data) @@ -475,13 +521,16 @@ impl FeishuChannel { .text("image_type", "message".to_string()) .part("image", part); - let resp = self.http_client + let resp = self + .http_client .post(format!("{}/im/v1/images/upload", FEISHU_API_BASE)) .header("Authorization", format!("Bearer {}", token)) .multipart(form) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Upload image HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct UploadResp { @@ -494,7 +543,9 @@ impl FeishuChannel { image_key: String, } - let result: UploadResp = resp.json().await + let result: UploadResp = resp + .json() + .await .map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?; if result.code != 0 { @@ -505,7 +556,8 @@ impl FeishuChannel { ))); } - result.data + result + .data .map(|d| d.image_key) .ok_or_else(|| ChannelError::Other("No image_key in response".to_string())) } @@ -532,7 +584,8 @@ impl FeishuChannel { _ => "file", }; - let file_data = tokio::fs::read(file_path).await + let file_data = tokio::fs::read(file_path) + .await .map_err(|e| ChannelError::Other(format!("Failed to read file: {}", e)))?; let part = reqwest::multipart::Part::bytes(file_data) @@ -545,7 +598,8 @@ impl FeishuChannel { .text("file_name", file_name.to_string()) .part("file", part); - let resp = self.http_client + let resp = self + .http_client .post(format!("{}/im/v1/files", FEISHU_API_BASE)) .header("Authorization", format!("Bearer {}", token)) .multipart(form) @@ -564,7 +618,9 @@ impl FeishuChannel { file_key: String, } - let result: UploadResp = resp.json().await + let result: UploadResp = resp + .json() + .await .map_err(|e| ChannelError::Other(format!("Parse upload response error: {}", e)))?; if result.code != 0 { @@ -575,7 +631,8 @@ impl FeishuChannel { ))); } - result.data + result + .data .map(|d| d.file_key) .ok_or_else(|| ChannelError::Other("No file_key in response".to_string())) } @@ -586,15 +643,21 @@ impl FeishuChannel { let emoji = self.config.reaction_emoji.as_str(); let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .post(format!("{}/im/v1/messages/{}/reactions", FEISHU_API_BASE, message_id)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages/{}/reactions", + FEISHU_API_BASE, message_id + )) .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ "reaction_type": { "emoji_type": emoji } })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Add reaction HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Add reaction HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct ReactionResp { @@ -607,7 +670,9 @@ impl FeishuChannel { reaction_id: Option, } - let result: ReactionResp = resp.json().await + let result: ReactionResp = resp + .json() + .await .map_err(|e| ChannelError::Other(format!("Parse reaction response error: {}", e)))?; if result.code != 0 { @@ -626,7 +691,10 @@ impl FeishuChannel { /// Remove reaction using feishu metadata propagated through OutboundMessage. /// Reads feishu.message_id and feishu.reaction_id from metadata. - async fn remove_reaction_from_metadata(&self, metadata: &std::collections::HashMap) { + async fn remove_reaction_from_metadata( + &self, + metadata: &std::collections::HashMap, + ) { let (message_id, reaction_id) = match ( metadata.get("feishu.message_id"), metadata.get("feishu.reaction_id"), @@ -640,15 +708,25 @@ impl FeishuChannel { } /// Remove a reaction emoji from a message. - async fn remove_reaction(&self, message_id: &str, reaction_id: &str) -> Result<(), ChannelError> { + async fn remove_reaction( + &self, + message_id: &str, + reaction_id: &str, + ) -> Result<(), ChannelError> { let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .delete(format!("{}/im/v1/messages/{}/reactions/{}", FEISHU_API_BASE, message_id, reaction_id)) + let resp = self + .http_client + .delete(format!( + "{}/im/v1/messages/{}/reactions/{}", + FEISHU_API_BASE, message_id, reaction_id + )) .header("Authorization", format!("Bearer {}", token)) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Remove reaction HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Remove reaction HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct ReactionResp { @@ -656,8 +734,9 @@ impl FeishuChannel { msg: Option, } - let result: ReactionResp = resp.json().await - .map_err(|e| ChannelError::Other(format!("Parse remove reaction response error: {}", e)))?; + let result: ReactionResp = resp.json().await.map_err(|e| { + ChannelError::Other(format!("Parse remove reaction response error: {}", e)) + })?; if result.code != 0 { tracing::debug!( @@ -682,7 +761,8 @@ impl FeishuChannel { } }; - let resp = self.http_client + let resp = self + .http_client .get(format!("{}/im/v1/messages/{}", FEISHU_API_BASE, message_id)) .header("Authorization", format!("Bearer {}", token)) .send() @@ -734,16 +814,12 @@ impl FeishuChannel { let msg_type = msg_obj.msg_type.as_str(); let text = match msg_type { - "text" => { - serde_json::from_str::(raw_content) - .ok()? - .get("text")? - .as_str()? - .to_string() - } - "post" => { - parse_post_content(raw_content) - } + "text" => serde_json::from_str::(raw_content) + .ok()? + .get("text")? + .as_str()? + .to_string(), + "post" => parse_post_content(raw_content), _ => String::new(), }; @@ -761,7 +837,13 @@ impl FeishuChannel { } /// Send a message to Feishu chat with specified message type - async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, msg_type: &str, content: &str) -> Result<(), ChannelError> { + async fn send_message_to_feishu( + &self, + receive_id: &str, + receive_id_type: &str, + msg_type: &str, + content: &str, + ) -> Result<(), ChannelError> { let token = self.get_tenant_access_token().await?; let payload_content = if msg_type == "text" { @@ -791,8 +873,12 @@ impl FeishuChannel { } }; - let resp = self.http_client - .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages?receive_id_type={}", + FEISHU_API_BASE, receive_id_type + )) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ @@ -802,7 +888,9 @@ impl FeishuChannel { })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Send message HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct SendResp { @@ -816,7 +904,10 @@ impl FeishuChannel { .map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?; if send_resp.code != 0 { - return Err(ChannelError::Other(format!("Send message failed: code={} msg={}", send_resp.code, send_resp.msg))); + return Err(ChannelError::Other(format!( + "Send message failed: code={} msg={}", + send_resp.code, send_resp.msg + ))); } Ok(()) @@ -847,7 +938,9 @@ impl FeishuChannel { return Ok(None); } - let payload = frame.payload.as_deref() + let payload = frame + .payload + .as_deref() .ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?; #[cfg(debug_assertions)] @@ -883,7 +976,10 @@ impl FeishuChannel { #[cfg(debug_assertions)] tracing::debug!(message_id = %message_id, "Received Feishu message"); - let open_id = payload_data.sender.sender_id.open_id + let open_id = payload_data + .sender + .sender_id + .open_id .ok_or_else(|| ChannelError::Other("No open_id".to_string()))?; let msg = payload_data.message; @@ -895,7 +991,9 @@ impl FeishuChannel { #[cfg(debug_assertions)] tracing::debug!(msg_type = %msg_type, chat_id = %chat_id, open_id = %open_id, "Parsing message content"); - let (mut content, media) = self.parse_and_download_message(msg_type, &raw_content, &message_id).await?; + let (mut content, media) = self + .parse_and_download_message(msg_type, &raw_content, &message_id) + .await?; // Fetch and prepend quoted message content if this is a reply if let Some(ref pid) = parent_id { @@ -929,18 +1027,23 @@ impl FeishuChannel { let (text, media) = match msg_type { "text" => { let text = if let Ok(parsed) = serde_json::from_str::(content) { - parsed.get("text").and_then(|v| v.as_str()).unwrap_or(content).to_string() + parsed + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or(content) + .to_string() } else { content.to_string() }; (text, None) } - "post" => { - (parse_post_content(content), None) - } + "post" => (parse_post_content(content), None), "image" | "audio" | "file" | "media" => { if let Ok(content_json) = serde_json::from_str::(content) { - match self.download_media(msg_type, &content_json, message_id).await { + match self + .download_media(msg_type, &content_json, message_id) + .await + { Ok((text, media)) => (text, media), Err(_) => (format!("[{}: content unavailable]", msg_type), None), } @@ -951,7 +1054,10 @@ impl FeishuChannel { "share_chat" => { // Shared chat/cannel messages if let Ok(parsed) = serde_json::from_str::(content) { - let chat_id = parsed.get("chat_id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let chat_id = parsed + .get("chat_id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); (format!("[shared chat: {}]", chat_id), None) } else { ("[shared chat]".to_string(), None) @@ -960,7 +1066,10 @@ impl FeishuChannel { "share_user" => { // Shared user messages if let Ok(parsed) = serde_json::from_str::(content) { - let user_id = parsed.get("user_id").and_then(|v| v.as_str()).unwrap_or("unknown"); + let user_id = parsed + .get("user_id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); (format!("[shared user: {}]", user_id), None) } else { ("[shared user]".to_string(), None) @@ -983,20 +1092,19 @@ impl FeishuChannel { Err(_) => (content.to_string(), None), } } - "merge_forward" => { - ("[merged forward messages]".to_string(), None) - } + "merge_forward" => ("[merged forward messages]".to_string(), None), "share_calendar_event" => { if let Ok(parsed) = serde_json::from_str::(content) { - let event_key = parsed.get("event_key").and_then(|v| v.as_str()).unwrap_or("unknown"); + let event_key = parsed + .get("event_key") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); (format!("[shared calendar event: {}]", event_key), None) } else { ("[shared calendar event]".to_string(), None) } } - "system" => { - ("[system message]".to_string(), None) - } + "system" => ("[system message]".to_string(), None), _ => (content.to_string(), None), }; @@ -1006,20 +1114,35 @@ impl FeishuChannel { } /// Send acknowledgment for a message - async fn send_ack(frame: &PbFrame, write: &mut futures_util::stream::SplitSink>, tokio_tungstenite::tungstenite::Message>) -> Result<(), ChannelError> { + async fn send_ack( + frame: &PbFrame, + write: &mut futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + tokio_tungstenite::tungstenite::Message, + >, + ) -> Result<(), ChannelError> { let mut ack = frame.clone(); ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into(), }); - write.send(tokio_tungstenite::tungstenite::Message::Binary(ack.encode_to_vec().into())) + write + .send(tokio_tungstenite::tungstenite::Message::Binary( + ack.encode_to_vec().into(), + )) .await .map_err(|e| ChannelError::Other(format!("Failed to send ack: {}", e)))?; Ok(()) } - async fn run_ws_loop(&self, bus: Arc, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> { + async fn run_ws_loop( + &self, + bus: Arc, + mut shutdown_rx: broadcast::Receiver<()>, + ) -> Result<(), ChannelError> { let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?; let service_id = Self::extract_service_id(&wss_url); @@ -1027,7 +1150,9 @@ impl FeishuChannel { let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url) .await - .map_err(|e| ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("WebSocket connection failed: {}", e)) + })?; *self.connected.write().await = true; tracing::info!("Feishu WebSocket connected"); @@ -1046,12 +1171,18 @@ impl FeishuChannel { }], payload: None, }; - write.send(tokio_tungstenite::tungstenite::Message::Binary(ping_frame.encode_to_vec().into())) + write + .send(tokio_tungstenite::tungstenite::Message::Binary( + ping_frame.encode_to_vec().into(), + )) .await - .map_err(|e| ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Failed to send initial ping: {}", e)) + })?; let ping_interval = client_config.ping_interval.unwrap_or(120).max(10); - let mut ping_interval_tok = tokio::time::interval(tokio::time::Duration::from_secs(ping_interval)); + let mut ping_interval_tok = + tokio::time::interval(tokio::time::Duration::from_secs(ping_interval)); let mut timeout_check = tokio::time::interval(tokio::time::Duration::from_secs(10)); let mut seq: u64 = 1; let mut last_recv = Instant::now(); @@ -1216,7 +1347,8 @@ fn parse_post_content(content: &str) -> String { } } "a" => { - let link_text = el.get("text") + let link_text = el + .get("text") .and_then(|t| t.as_str()) .filter(|s| !s.is_empty()) .or_else(|| el.get("href").and_then(|h| h.as_str())) @@ -1224,7 +1356,8 @@ fn parse_post_content(content: &str) -> String { out.push(link_text.to_string()); } "at" => { - let name = el.get("user_name") + let name = el + .get("user_name") .and_then(|n| n.as_str()) .or_else(|| el.get("user_id").and_then(|i| i.as_str())) .unwrap_or("user"); @@ -1245,7 +1378,10 @@ fn parse_post_content(content: &str) -> String { /// Parse a single block {title, content: [[...]]} and append text to out. fn parse_block(block: &serde_json::Value, out: &mut Vec) { - let title = block.get("title").and_then(|t| t.as_str()).filter(|s| !s.is_empty()); + let title = block + .get("title") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()); if let Some(t) = title { out.push(t.to_string()); out.push("\n\n".to_string()); @@ -1372,7 +1508,8 @@ fn extract_element_content(element: &serde_json::Value, texts: &mut Vec) } "div" => { if let Some(text_obj) = element.get("text").and_then(|t| t.as_object()) { - let content = text_obj.get("content") + let content = text_obj + .get("content") .and_then(|c| c.as_str()) .unwrap_or(""); texts.push(content.to_string()); @@ -1382,12 +1519,8 @@ fn extract_element_content(element: &serde_json::Value, texts: &mut Vec) texts.push("\n".to_string()); } "a" => { - let href = element.get("href") - .and_then(|h| h.as_str()) - .unwrap_or(""); - let text = element.get("text") - .and_then(|t| t.as_str()) - .unwrap_or(""); + let href = element.get("href").and_then(|h| h.as_str()).unwrap_or(""); + let text = element.get("text").and_then(|t| t.as_str()).unwrap_or(""); if !text.is_empty() { texts.push(text.to_string()); } else if !href.is_empty() { @@ -1396,8 +1529,13 @@ fn extract_element_content(element: &serde_json::Value, texts: &mut Vec) } "img" => { let alt = element.get("alt"); - let alt_text = alt.and_then(|a| a.as_str()) - .or_else(|| alt.and_then(|a| a.as_object()).and_then(|o| o.get("content")).and_then(|c| c.as_str())) + let alt_text = alt + .and_then(|a| a.as_str()) + .or_else(|| { + alt.and_then(|a| a.as_object()) + .and_then(|o| o.get("content")) + .and_then(|c| c.as_str()) + }) .unwrap_or("[image]"); texts.push(format!("{}\n", alt_text)); } @@ -1441,7 +1579,8 @@ fn parse_list_content(content: &str) -> Result<(String, Option), Chan Err(_) => return Ok((content.to_string(), None)), }; - let items = parsed.get("items") + let items = parsed + .get("items") .and_then(|i| i.as_array()) .or_else(|| parsed.get("content").and_then(|c| c.as_array())); @@ -1501,7 +1640,11 @@ fn collect_list_items(items: &[serde_json::Value], lines: &mut Vec, dept } }) }) { - if let Some(children) = children_arr.as_object().and_then(|o| o.get("children")).and_then(|c| c.as_array()) { + if let Some(children) = children_arr + .as_object() + .and_then(|o| o.get("children")) + .and_then(|c| c.as_array()) + { collect_list_items(children, lines, depth + 1); } } @@ -1517,7 +1660,8 @@ fn extract_inline_text(el: &serde_json::Value, out: &mut String) { } } "a" => { - let text = el.get("text") + let text = el + .get("text") .and_then(|t| t.as_str()) .filter(|s| !s.is_empty()) .or_else(|| el.get("href").and_then(|h| h.as_str())) @@ -1525,7 +1669,8 @@ fn extract_inline_text(el: &serde_json::Value, out: &mut String) { out.push_str(text); } "at" => { - let name = el.get("user_name") + let name = el + .get("user_name") .and_then(|n| n.as_str()) .or_else(|| el.get("user_id").and_then(|i| i.as_str())) .unwrap_or("user"); @@ -1545,7 +1690,9 @@ fn strip_at_placeholders(text: &str) -> String { let rest: String = chars.clone().collect(); if let Some(after) = rest.strip_prefix("_user_") { // Skip until we hit a non-alphanumeric character - let placeholder_len = after.find(|c: char| !c.is_alphanumeric()).unwrap_or(after.len()); + let placeholder_len = after + .find(|c: char| !c.is_alphanumeric()) + .unwrap_or(after.len()); // Skip the placeholder for _ in 0..placeholder_len { chars.next(); @@ -1749,7 +1896,10 @@ impl FeishuChannel { let mut code_blocks: Vec = Vec::new(); for m in patterns.code_block_re.find_iter(content) { code_blocks.push(m.as_str().to_string()); - protected = protected.replace(m.as_str(), &format!("\x00CODE{}\x00", code_blocks.len() - 1)); + protected = protected.replace( + m.as_str(), + &format!("\x00CODE{}\x00", code_blocks.len() - 1), + ); } let mut elements: Vec = Vec::new(); @@ -1807,7 +1957,9 @@ impl FeishuChannel { /// Split card elements into groups with at most one table element each. /// Feishu cards have a hard limit of one table per card (API error 11310). - fn split_elements_by_table_limit(elements: &[serde_json::Value]) -> Vec> { + fn split_elements_by_table_limit( + elements: &[serde_json::Value], + ) -> Vec> { if elements.is_empty() { return vec![vec![]]; } @@ -1901,8 +2053,12 @@ impl FeishuChannel { ) -> Result<(), ChannelError> { let token = self.get_tenant_access_token().await?; - let resp = self.http_client - .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages?receive_id_type={}", + FEISHU_API_BASE, receive_id_type + )) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ @@ -1912,7 +2068,9 @@ impl FeishuChannel { })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Send interactive card HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Send interactive card HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct SendResp { @@ -1920,10 +2078,9 @@ impl FeishuChannel { msg: String, } - let send_resp: SendResp = resp - .json() - .await - .map_err(|e| ChannelError::Other(format!("Parse send interactive card response error: {}", e)))?; + let send_resp: SendResp = resp.json().await.map_err(|e| { + ChannelError::Other(format!("Parse send interactive card response error: {}", e)) + })?; if send_resp.code != 0 { return Err(ChannelError::Other(format!( @@ -2012,7 +2169,10 @@ fn sanitize_download_file_name(file_name: &str) -> String { #[cfg(test)] mod tests { - use super::{extract_file_name_from_content_disposition, infer_download_filename, sanitize_download_file_name, FeishuChannel, MsgFormat}; + use super::{ + FeishuChannel, MsgFormat, extract_file_name_from_content_disposition, + infer_download_filename, sanitize_download_file_name, + }; #[test] fn markdown_post_uses_md_tag() { @@ -2033,7 +2193,10 @@ mod tests { #[test] fn headings_still_use_interactive() { let content = "intro\n## heading"; - assert_eq!(FeishuChannel::detect_msg_format(content), MsgFormat::Interactive); + assert_eq!( + FeishuChannel::detect_msg_format(content), + MsgFormat::Interactive + ); } #[test] @@ -2044,7 +2207,8 @@ mod tests { }); let headers = reqwest::header::HeaderMap::new(); - let filename = infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"); + let filename = + infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"); assert_eq!(filename, "om_123_demo-archive.zip"); } @@ -2060,7 +2224,8 @@ mod tests { reqwest::header::HeaderValue::from_static("attachment; filename=meeting-notes.zip"), ); - let filename = infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"); + let filename = + infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"); assert_eq!(filename, "om_123_meeting-notes.zip"); } @@ -2072,7 +2237,8 @@ mod tests { }); let headers = reqwest::header::HeaderMap::new(); - let filename = infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"); + let filename = + infer_download_filename(&content, &headers, "om_123", "file_key_123", "file"); assert_eq!(filename, "om_123_file_key.bin"); } @@ -2105,7 +2271,7 @@ impl Channel for FeishuChannel { async fn start(&self, bus: Arc) -> Result<(), ChannelError> { if self.config.app_id.is_empty() || self.config.app_secret.is_empty() { return Err(ChannelError::ConfigError( - "Feishu app_id or app_secret is not configured".to_string() + "Feishu app_id or app_secret is not configured".to_string(), )); } @@ -2171,8 +2337,16 @@ impl Channel for FeishuChannel { } async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { - let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) }; - let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" }; + let receive_id = if msg.chat_id.starts_with("oc_") { + &msg.chat_id + } else { + &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) + }; + let receive_id_type = if msg.chat_id.starts_with("oc_") { + "chat_id" + } else { + "open_id" + }; // If no media, use smart format detection if msg.media.is_empty() { @@ -2189,14 +2363,18 @@ impl Channel for FeishuChannel { match fmt { MsgFormat::Text => { // Short plain text – send as simple text message - let result = self.send_message_to_feishu(receive_id, receive_id_type, "text", content).await; + let result = self + .send_message_to_feishu(receive_id, receive_id_type, "text", content) + .await; self.remove_reaction_from_metadata(&msg.metadata).await; return result; } MsgFormat::Post => { // Medium content with links – send as rich-text post let post_body = Self::markdown_to_post(content); - let result = self.send_message_to_feishu(receive_id, receive_id_type, "post", &post_body).await; + let result = self + .send_message_to_feishu(receive_id, receive_id_type, "post", &post_body) + .await; self.remove_reaction_from_metadata(&msg.metadata).await; return result; } @@ -2208,10 +2386,20 @@ impl Channel for FeishuChannel { "config": { "wide_screen_mode": true }, "elements": chunk }); - if let Err(e) = self.send_interactive_card(receive_id, receive_id_type, &card.to_string()).await { + if let Err(e) = self + .send_interactive_card(receive_id, receive_id_type, &card.to_string()) + .await + { tracing::warn!(error = %e, "Failed to send interactive card, falling back to text"); // Fallback to plain text - let result = self.send_message_to_feishu(receive_id, receive_id_type, "text", content).await; + let result = self + .send_message_to_feishu( + receive_id, + receive_id_type, + "text", + content, + ) + .await; self.remove_reaction_from_metadata(&msg.metadata).await; return result; } @@ -2232,7 +2420,10 @@ impl Channel for FeishuChannel { if !msg.content.is_empty() { const MAX_TEXT_LENGTH: usize = 60_000; let truncated_text = if msg.content.len() > MAX_TEXT_LENGTH { - format!("{}...\n\n[Content truncated due to length limit]", &msg.content[..MAX_TEXT_LENGTH]) + format!( + "{}...\n\n[Content truncated due to length limit]", + &msg.content[..MAX_TEXT_LENGTH] + ) } else { msg.content.clone() }; @@ -2246,32 +2437,28 @@ impl Channel for FeishuChannel { for media_item in &msg.media { let path = &media_item.path; match media_item.media_type.as_str() { - "image" => { - match self.upload_image(path).await { - Ok(image_key) => { - content_parts.push(serde_json::json!({ - "tag": "image", - "image_key": image_key - })); - } - Err(e) => { - tracing::warn!(error = %e, path = %path, "Failed to upload image"); - } + "image" => match self.upload_image(path).await { + Ok(image_key) => { + content_parts.push(serde_json::json!({ + "tag": "image", + "image_key": image_key + })); } - } - "audio" | "file" | "video" => { - match self.upload_file(path).await { - Ok(file_key) => { - content_parts.push(serde_json::json!({ - "tag": "file", - "file_key": file_key - })); - } - Err(e) => { - tracing::warn!(error = %e, path = %path, "Failed to upload file"); - } + Err(e) => { + tracing::warn!(error = %e, path = %path, "Failed to upload image"); } - } + }, + "audio" | "file" | "video" => match self.upload_file(path).await { + Ok(file_key) => { + content_parts.push(serde_json::json!({ + "tag": "file", + "file_key": file_key + })); + } + Err(e) => { + tracing::warn!(error = %e, path = %path, "Failed to upload file"); + } + }, _ => { tracing::warn!(media_type = %media_item.media_type, "Unsupported media type for sending"); } @@ -2280,7 +2467,9 @@ impl Channel for FeishuChannel { // If no content parts after processing, just send empty text if content_parts.is_empty() { - let result = self.send_message_to_feishu(receive_id, receive_id_type, "text", "").await; + let result = self + .send_message_to_feishu(receive_id, receive_id_type, "text", "") + .await; // Remove pending reaction after sending (using metadata propagated from inbound) self.remove_reaction_from_metadata(&msg.metadata).await; return result; @@ -2296,10 +2485,15 @@ impl Channel for FeishuChannel { let content = serde_json::json!({ "content": content_parts - }).to_string(); + }) + .to_string(); - let resp = self.http_client - .post(format!("{}/im/v1/messages?receive_id_type={}", FEISHU_API_BASE, receive_id_type)) + let resp = self + .http_client + .post(format!( + "{}/im/v1/messages?receive_id_type={}", + FEISHU_API_BASE, receive_id_type + )) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&serde_json::json!({ @@ -2309,7 +2503,9 @@ impl Channel for FeishuChannel { })) .send() .await - .map_err(|e| ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)))?; + .map_err(|e| { + ChannelError::ConnectionError(format!("Send multimodal message HTTP error: {}", e)) + })?; #[derive(Deserialize)] struct SendResp { @@ -2317,11 +2513,16 @@ impl Channel for FeishuChannel { msg: String, } - let send_resp: SendResp = resp.json().await + let send_resp: SendResp = resp + .json() + .await .map_err(|e| ChannelError::Other(format!("Parse send response error: {}", e)))?; if send_resp.code != 0 { - return Err(ChannelError::Other(format!("Send multimodal message failed: code={} msg={}", send_resp.code, send_resp.msg))); + return Err(ChannelError::Other(format!( + "Send multimodal message failed: code={} msg={}", + send_resp.code, send_resp.msg + ))); } // Remove pending reaction after successfully sending diff --git a/src/channels/manager.rs b/src/channels/manager.rs index 568639e..ae5a932 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -use crate::bus::{MessageBus, OutboundMessage}; +use crate::bus::MessageBus; use crate::channels::base::{Channel, ChannelError}; use crate::channels::feishu::FeishuChannel; use crate::config::Config; @@ -28,12 +28,18 @@ impl ChannelManager { } /// Initialize all Channel instances from config - pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> { + pub async fn init( + &self, + config: &Config, + _provider_config: crate::config::LLMProviderConfig, + ) -> Result<(), ChannelError> { // Initialize Feishu channel if enabled if let Some(feishu_config) = config.channels.get("feishu") { if feishu_config.enabled { - let channel = FeishuChannel::new(feishu_config.clone(), _provider_config) - .map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; + let channel = + FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| { + ChannelError::Other(format!("Failed to create Feishu channel: {}", e)) + })?; self.channels .write() @@ -75,13 +81,12 @@ impl ChannelManager { self.channels.read().await.get(name).cloned() } - /// Dispatch an outbound message to the appropriate channel - pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> { - let channel_name = &msg.channel; - if let Some(channel) = self.get_channel(channel_name).await { - channel.send(msg).await - } else { - Err(ChannelError::Other(format!("Channel not found: {}", channel_name))) - } + pub async fn channels(&self) -> Vec<(String, Arc)> { + self.channels + .read() + .await + .iter() + .map(|(name, channel)| (name.clone(), channel.clone())) + .collect() } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index d8a5b1d..e6a94de 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -3,5 +3,5 @@ pub mod feishu; pub mod manager; pub use base::{Channel, ChannelError}; -pub use manager::ChannelManager; pub use feishu::FeishuChannel; +pub use manager::ChannelManager; diff --git a/src/cli/channel.rs b/src/cli/channel.rs index 56030e1..8cd4d57 100644 --- a/src/cli/channel.rs +++ b/src/cli/channel.rs @@ -1,4 +1,4 @@ -use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; pub struct CliChannel { read: BufReader, diff --git a/src/cli/input.rs b/src/cli/input.rs index d024374..10bf07b 100644 --- a/src/cli/input.rs +++ b/src/cli/input.rs @@ -49,18 +49,27 @@ impl InputHandler { } pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> { - self.channel.write_line(content).await.map_err(InputError::IoError) + self.channel + .write_line(content) + .await + .map_err(InputError::IoError) } pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> { - self.channel.write_response(content).await.map_err(InputError::IoError) + self.channel + .write_response(content) + .await + .map_err(InputError::IoError) } fn handle_special_commands(&self, line: &str) -> Option { let trimmed = line.trim(); let mut parts = trimmed.splitn(2, char::is_whitespace); let command = parts.next()?; - let arg = parts.next().map(str::trim).filter(|value| !value.is_empty()); + let arg = parts + .next() + .map(str::trim) + .filter(|value| !value.is_empty()); match command { "/quit" | "/exit" | "/q" => Some(InputCommand::Exit), @@ -105,14 +114,26 @@ mod tests { fn test_special_command_parsing() { let handler = InputHandler::new(); - assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit)); - assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear)); - assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None))); + assert_eq!( + handler.handle_special_commands("/quit"), + Some(InputCommand::Exit) + ); + assert_eq!( + handler.handle_special_commands("/clear"), + Some(InputCommand::Clear) + ); + assert_eq!( + handler.handle_special_commands("/new"), + Some(InputCommand::New(None)) + ); assert_eq!( handler.handle_special_commands("/new planning"), Some(InputCommand::New(Some("planning".to_string()))) ); - assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions)); + assert_eq!( + handler.handle_special_commands("/sessions"), + Some(InputCommand::Sessions) + ); assert_eq!( handler.handle_special_commands("/use abc123"), Some(InputCommand::Use("abc123".to_string())) @@ -121,8 +142,14 @@ mod tests { handler.handle_special_commands("/rename project alpha"), Some(InputCommand::Rename("project alpha".to_string())) ); - assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive)); - assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete)); + assert_eq!( + handler.handle_special_commands("/archive"), + Some(InputCommand::Archive) + ); + assert_eq!( + handler.handle_special_commands("/delete"), + Some(InputCommand::Delete) + ); assert_eq!(handler.handle_special_commands("/unknown"), None); assert_eq!(handler.handle_special_commands("/use"), None); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 4bd50be..a580c8a 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -5,7 +5,10 @@ use tokio_tungstenite::{connect_async, tungstenite::Message}; use crate::cli::{InputCommand, InputEvent, InputHandler}; -fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String { +fn format_session_list( + sessions: &[crate::protocol::SessionSummary], + current_session_id: Option<&str>, +) -> String { if sessions.is_empty() { return "No sessions found.".to_string(); } @@ -25,11 +28,7 @@ fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_ses }; lines.push(format!( "{} {} | {} | {} messages{}", - marker, - session.session_id, - session.title, - session.message_count, - archived, + marker, session.session_id, session.title, session.message_count, archived, )); } diff --git a/src/config/mod.rs b/src/config/mod.rs index 9969173..d44f3a0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -123,7 +123,9 @@ fn default_allow_from() -> Vec { fn default_media_dir() -> String { let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from(".")); - home.join(".picobot/media/feishu").to_string_lossy().to_string() + home.join(".picobot/media/feishu") + .to_string_lossy() + .to_string() } fn default_reaction_emoji() -> String { @@ -199,7 +201,10 @@ pub struct GatewayConfig { pub show_tool_results: bool, #[serde(default, rename = "session_ttl_hours")] pub session_ttl_hours: Option, - #[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")] + #[serde( + default = "default_agent_prompt_reinject_every", + rename = "agent_prompt_reinject_every" + )] pub agent_prompt_reinject_every: u64, } @@ -388,7 +393,10 @@ impl SchedulerSchedule { } pub fn is_one_shot(&self) -> bool { - matches!(self, SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. }) + matches!( + self, + SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. } + ) } pub fn normalized_for_storage(&self) -> Self { @@ -581,13 +589,19 @@ impl Config { } pub fn get_provider_config(&self, agent_name: &str) -> Result { - let agent = self.agents.get(agent_name) + let agent = self + .agents + .get(agent_name) .ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?; - let provider = self.providers.get(&agent.provider) + let provider = self + .providers + .get(&agent.provider) .ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?; - let model = self.models.get(&agent.model) + let model = self + .models + .get(&agent.model) .ok_or(ConfigError::ModelNotFound(agent.model.clone()))?; Ok(LLMProviderConfig { @@ -621,11 +635,17 @@ pub enum ConfigError { impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path), + ConfigError::ConfigNotFound(path) => write!( + f, + "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", + path + ), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), - ConfigError::InvalidSchedulerJob(message) => write!(f, "Invalid scheduler job: {}", message), + ConfigError::InvalidSchedulerJob(message) => { + write!(f, "Invalid scheduler job: {}", message) + } ConfigError::InvalidTimezone(message) => write!(f, "Invalid timezone: {}", message), } } @@ -661,18 +681,19 @@ fn resolve_env_placeholders(content: &str) -> String { re.replace_all(content, |caps: ®ex::Captures| { let var_name = &caps[1]; env::var(var_name).unwrap_or_else(|_| caps[0].to_string()) - }).to_string() + }) + .to_string() } #[cfg(test)] mod tests { use super::*; - fn write_test_config() -> tempfile::NamedTempFile { - let file = tempfile::NamedTempFile::new().unwrap(); - std::fs::write( - file.path(), - r#"{ + fn write_test_config() -> tempfile::NamedTempFile { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ "providers": { "aliyun": { "type": "openai", @@ -708,15 +729,15 @@ mod tests { "agent_prompt_reinject_every": 120 } }"#, - ) - .unwrap(); - file - } + ) + .unwrap(); + file + } #[test] fn test_config_load() { - let file = write_test_config(); - let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let file = write_test_config(); + let config = Config::load(file.path().to_str().unwrap()).unwrap(); // Check providers assert!(config.providers.contains_key("volcengine")); @@ -876,7 +897,10 @@ mod tests { let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.time.timezone, "Asia/Shanghai"); - assert_eq!(config.time.parse_timezone().unwrap(), chrono_tz::Asia::Shanghai); + assert_eq!( + config.time.parse_timezone().unwrap(), + chrono_tz::Asia::Shanghai + ); } #[test] @@ -983,7 +1007,10 @@ mod tests { let config = Config::load(file.path().to_str().unwrap()).unwrap(); assert_eq!(config.agents["default"].max_tool_iterations, 100); assert_eq!(config.agents["default"].tool_result_max_chars, 20_000); - assert_eq!(config.agents["default"].context_tool_result_trim_chars, 2_000); + assert_eq!( + config.agents["default"].context_tool_result_trim_chars, + 2_000 + ); } #[test] @@ -1159,7 +1186,10 @@ mod tests { assert!(config.scheduler.enabled); assert_eq!(config.scheduler.tick_resolution_ms, 1_000); assert_eq!(config.scheduler.worker_queue_capacity, 64); - assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::Skip); + assert_eq!( + config.scheduler.misfire_policy, + SchedulerMisfirePolicy::Skip + ); assert!(config.scheduler.jobs.is_empty()); let effective_jobs = config.scheduler.effective_jobs(&config.time); @@ -1273,7 +1303,10 @@ mod tests { assert!(config.scheduler.enabled); assert_eq!(config.scheduler.tick_resolution_ms, 500); assert_eq!(config.scheduler.worker_queue_capacity, 8); - assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::CatchUp); + assert_eq!( + config.scheduler.misfire_policy, + SchedulerMisfirePolicy::CatchUp + ); assert_eq!(config.scheduler.jobs.len(), 1); let job = &config.scheduler.jobs[0]; @@ -1284,11 +1317,17 @@ mod tests { assert_eq!(job.startup_delay_secs, 5); assert_eq!(job.target.channel.as_deref(), Some("feishu")); assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo")); - assert_eq!(job.payload.get("content").and_then(|value| value.as_str()), Some("heartbeat")); - assert_eq!(job.resolved_schedule().unwrap(), SchedulerSchedule::Interval { - seconds: 60, - startup_delay_secs: 5, - }); + assert_eq!( + job.payload.get("content").and_then(|value| value.as_str()), + Some("heartbeat") + ); + assert_eq!( + job.resolved_schedule().unwrap(), + SchedulerSchedule::Interval { + seconds: 60, + startup_delay_secs: 5, + } + ); } #[test] @@ -1362,21 +1401,30 @@ mod tests { config.scheduler.jobs[0].resolved_schedule().unwrap(), SchedulerSchedule::Delay { seconds: 30 } ); - assert_eq!(config.scheduler.jobs[0].kind, SchedulerJobKind::InternalEvent); + assert_eq!( + config.scheduler.jobs[0].kind, + SchedulerJobKind::InternalEvent + ); assert_eq!( config.scheduler.jobs[1].resolved_schedule().unwrap(), SchedulerSchedule::At { timestamp: "2026-04-23T09:00:00+00:00".to_string(), } ); - assert_eq!(config.scheduler.jobs[1].kind, SchedulerJobKind::OutboundMessage); + assert_eq!( + config.scheduler.jobs[1].kind, + SchedulerJobKind::OutboundMessage + ); assert_eq!( config.scheduler.jobs[2].resolved_schedule().unwrap(), SchedulerSchedule::Cron { expression: "0 9 * * *".to_string(), } ); - assert_eq!(config.scheduler.jobs[2].kind, SchedulerJobKind::InternalEvent); + assert_eq!( + config.scheduler.jobs[2].kind, + SchedulerJobKind::InternalEvent + ); } #[test] @@ -1433,7 +1481,10 @@ mod tests { assert_eq!(job.kind, SchedulerJobKind::AgentTask); assert_eq!(job.target.channel.as_deref(), Some("feishu")); assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo")); - assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办")); + assert_eq!( + job.payload.get("prompt").and_then(|value| value.as_str()), + Some("请总结今天待办") + ); } #[test] @@ -1495,29 +1546,40 @@ mod tests { job.target.session_chat_id.as_deref(), Some("scheduler/agent.daily_summary.background") ); - assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请后台总结今天待办")); + assert_eq!( + job.payload.get("prompt").and_then(|value| value.as_str()), + Some("请后台总结今天待办") + ); } #[test] fn test_scheduler_schedule_validation_rejects_invalid_values() { - assert!(SchedulerSchedule::Delay { seconds: 0 } - .validate("delay.job") - .is_err()); - assert!(SchedulerSchedule::Interval { - seconds: 0, - startup_delay_secs: 0, - } - .validate("interval.job") - .is_err()); - assert!(SchedulerSchedule::At { - timestamp: "bad timestamp".to_string(), - } - .validate("at.job") - .is_err()); - assert!(SchedulerSchedule::Cron { - expression: "bad cron".to_string(), - } - .validate("cron.job") - .is_err()); + assert!( + SchedulerSchedule::Delay { seconds: 0 } + .validate("delay.job") + .is_err() + ); + assert!( + SchedulerSchedule::Interval { + seconds: 0, + startup_delay_secs: 0, + } + .validate("interval.job") + .is_err() + ); + assert!( + SchedulerSchedule::At { + timestamp: "bad timestamp".to_string(), + } + .validate("at.job") + .is_err() + ); + assert!( + SchedulerSchedule::Cron { + expression: "bad cron".to_string(), + } + .validate("cron.job") + .is_err() + ); } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 6f942e2..d69d4af 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,10 +1,11 @@ pub mod http; +pub mod processor; pub mod session; pub mod ws; +use axum::{Router, routing}; use std::collections::HashMap; use std::sync::Arc; -use axum::{routing, Router}; use tokio::net::TcpListener; use crate::bus::{MessageBus, OutboundDispatcher}; @@ -14,7 +15,8 @@ use crate::config::LLMProviderConfig; use crate::logging; use crate::scheduler::Scheduler; use crate::skills::SkillRuntime; -use session::{BusToolCallEmitter, SessionManager}; +use processor::InboundProcessor; +use session::SessionManager; pub struct GatewayState { pub config: Config, @@ -61,74 +63,17 @@ impl GatewayState { /// Start the message processing loops pub async fn start_message_processing(&self) { - let bus_for_inbound = self.bus.clone(); let bus_for_outbound = self.bus.clone(); - let session_manager = self.session_manager.clone(); - - // Spawn inbound message processor - // This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound - tokio::spawn(async move { - tracing::info!("Inbound processor started"); - loop { - let inbound = bus_for_inbound.consume_inbound().await; - #[cfg(debug_assertions)] - { - tracing::debug!( - channel = %inbound.channel, - chat_id = %inbound.chat_id, - sender = %inbound.sender_id, - content = %inbound.content, - media_count = %inbound.media.len(), - "Processing inbound message" - ); - if !inbound.media.is_empty() { - for (i, m) in inbound.media.iter().enumerate() { - tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item"); - } - } - } - - // Process via session manager - let live_emitter = Arc::new(BusToolCallEmitter::new( - bus_for_inbound.clone(), - inbound.channel.clone(), - inbound.chat_id.clone(), - inbound.forwarded_metadata.clone(), - session_manager.show_tool_results(), - )); - match session_manager.handle_message( - &inbound.channel, - &inbound.sender_id, - &inbound.chat_id, - &inbound.content, - inbound.media, - Some(live_emitter), - ).await { - Ok(outbound_messages) => { - // Forward channel-specific metadata from inbound to outbound. - // This allows channels to propagate context (e.g. feishu message_id for reaction cleanup) - // without gateway needing channel-specific code. - for mut outbound in outbound_messages { - outbound.metadata.extend(inbound.forwarded_metadata.clone()); - if let Err(e) = bus_for_inbound.publish_outbound(outbound).await { - tracing::error!(error = %e, "Failed to publish outbound"); - } - } - } - Err(e) => { - tracing::error!(error = %e, "Failed to handle message"); - } - } - } - }); + let inbound_processor = + InboundProcessor::new(self.bus.clone(), self.session_manager.clone()); + tokio::spawn(inbound_processor.run()); // Spawn outbound dispatcher let dispatcher = OutboundDispatcher::new(bus_for_outbound); let channel_manager = self.channel_manager.clone(); - // Register channels with dispatcher - if let Some(channel) = channel_manager.get_channel("feishu").await { - dispatcher.register_channel("feishu", channel).await; + for (name, channel) in channel_manager.channels().await { + dispatcher.register_channel(&name, channel).await; } tokio::spawn(async move { @@ -138,7 +83,10 @@ impl GatewayState { } } -pub async fn run(host: Option, port: Option) -> Result<(), Box> { +pub async fn run( + host: Option, + port: Option, +) -> Result<(), Box> { let config = Config::load_default()?; let timezone = config.time.parse_timezone()?; @@ -152,7 +100,10 @@ pub async fn run(host: Option, port: Option) -> Result<(), Box, + session_manager: SessionManager, +} + +impl InboundProcessor { + pub fn new(bus: Arc, session_manager: SessionManager) -> Self { + Self { + bus, + session_manager, + } + } + + pub async fn run(self) { + tracing::info!("Inbound processor started"); + + loop { + let inbound = self.bus.consume_inbound().await; + #[cfg(debug_assertions)] + { + tracing::debug!( + channel = %inbound.channel, + chat_id = %inbound.chat_id, + sender = %inbound.sender_id, + content = %inbound.content, + media_count = %inbound.media.len(), + "Processing inbound message" + ); + if !inbound.media.is_empty() { + for (i, media) in inbound.media.iter().enumerate() { + tracing::debug!(media_index = i, media_type = %media.media_type, path = %media.path, "Media item"); + } + } + } + + let live_emitter = Arc::new(BusToolCallEmitter::new( + self.bus.clone(), + inbound.channel.clone(), + inbound.chat_id.clone(), + inbound.forwarded_metadata.clone(), + self.session_manager.show_tool_results(), + )); + + match self + .session_manager + .handle_message( + &inbound.channel, + &inbound.sender_id, + &inbound.chat_id, + &inbound.content, + inbound.media, + Some(live_emitter), + ) + .await + { + Ok(outbound_messages) => { + for mut outbound in outbound_messages { + outbound.metadata.extend(inbound.forwarded_metadata.clone()); + if let Err(error) = self.bus.publish_outbound(outbound).await { + tracing::error!(error = %error, "Failed to publish outbound"); + } + } + } + Err(error) => { + tracing::error!(error = %error, "Failed to handle message"); + } + } + } + } +} diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 0763fd9..d47a70a 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,37 +1,33 @@ +use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandler}; +use crate::bus::{ + ChatMessage, MessageBus, OutboundMessage, SYSTEM_CONTEXT_AGENT_PROMPT, + SYSTEM_CONTEXT_SCHEDULED_PROMPT, +}; +use crate::config::LLMProviderConfig; +use crate::protocol::WsOutbound; +use crate::providers::{ChatCompletionRequest, Message, create_provider}; +use crate::skills::SkillRuntime; +use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; +use crate::tools::{ + BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool, + MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool, + TimeTool, ToolContext, ToolRegistry, WebFetchTool, +}; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fs; use std::path::Path; use std::sync::Arc; use std::time::{Duration, Instant}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; -use crate::bus::{ - ChatMessage, - MessageBus, - OutboundMessage, - SYSTEM_CONTEXT_AGENT_PROMPT, - SYSTEM_CONTEXT_SCHEDULED_PROMPT, -}; -use crate::config::LLMProviderConfig; -use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler}; -use crate::providers::{create_provider, ChatCompletionRequest, Message}; -use crate::protocol::WsOutbound; -use crate::skills::SkillRuntime; -use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; -use crate::tools::{ - BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, - HttpRequestTool, MemoryManageTool, MemorySearchTool, SchedulerManageTool, SkillListTool, SkillManageTool, ToolContext, ToolRegistry, - TimeTool, WebFetchTool, -}; const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md"); const MANAGED_AGENT_MEMORY_BLOCK_START: &str = ""; const MANAGED_AGENT_MEMORY_BLOCK_END: &str = ""; const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要"; -const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = - include_str!("memory_maintenance_system_prompt.md"); +const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md"); const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000]; const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。"; @@ -90,7 +86,9 @@ pub(crate) struct MemoryMaintenanceScopeResult { pub(crate) output: MemoryMaintenanceModelOutput, } -fn build_memory_maintenance_plan(memories: &[crate::storage::MemoryRecord]) -> MemoryMaintenancePlan { +fn build_memory_maintenance_plan( + memories: &[crate::storage::MemoryRecord], +) -> MemoryMaintenancePlan { let mut plan = MemoryMaintenancePlan::default(); let mut seen = HashSet::new(); @@ -132,7 +130,9 @@ fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory { match namespace.trim().to_ascii_lowercase().as_str() { "profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts, "preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences, - "patterns" | "behavior" | "habits" | "workflow" => MemoryMaintenanceCategory::BehaviorPatterns, + "patterns" | "behavior" | "habits" | "workflow" => { + MemoryMaintenanceCategory::BehaviorPatterns + } _ => MemoryMaintenanceCategory::Other, } } @@ -269,7 +269,10 @@ fn preview_text(content: &str, max_chars: usize) -> String { preview.replace('\n', "\\n") } -fn enrich_user_content_with_media_refs(content: &str, media_refs: &[String]) -> Result { +fn enrich_user_content_with_media_refs( + content: &str, + media_refs: &[String], +) -> Result { if media_refs.is_empty() { return Ok(content.to_string()); } @@ -295,19 +298,23 @@ fn combine_managed_memory_markdown(chunks: &[String]) -> String { .filter(|line| !line.is_empty()) .collect::>(); - let is_subset_of_other = normalized_chunks.iter().enumerate().any(|(other_index, other)| { - if index == other_index { - return false; - } + let is_subset_of_other = + normalized_chunks + .iter() + .enumerate() + .any(|(other_index, other)| { + if index == other_index { + return false; + } - let other_lines = other - .lines() - .map(str::trim) - .filter(|line| !line.is_empty()) - .collect::>(); + let other_lines = other + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .collect::>(); - chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines) - }); + chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines) + }); if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) { combined.push((*chunk).to_string()); @@ -355,7 +362,9 @@ fn apply_memory_maintenance_output( let existing_target_id = source_candidates .iter() - .find(|candidate| candidate.namespace == merge.namespace && candidate.key == merge.memory_key) + .find(|candidate| { + candidate.namespace == merge.namespace && candidate.key == merge.memory_key + }) .map(|candidate| candidate.id.clone()); store @@ -375,13 +384,18 @@ fn apply_memory_maintenance_output( .map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?; for candidate in source_candidates { - if existing_target_id.as_ref().is_some_and(|target_id| target_id == &candidate.id) { + if existing_target_id + .as_ref() + .is_some_and(|target_id| target_id == &candidate.id) + { continue; } if deleted_ids.insert(candidate.id.clone()) { store .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) - .map_err(|err| AgentError::Other(format!("delete merged source memory error: {}", err)))?; + .map_err(|err| { + AgentError::Other(format!("delete merged source memory error: {}", err)) + })?; } } } @@ -391,7 +405,9 @@ fn apply_memory_maintenance_output( if deleted_ids.insert(candidate.id.clone()) { store .delete_memory("user", scope_key, &candidate.namespace, &candidate.key) - .map_err(|err| AgentError::Other(format!("delete low value memory error: {}", err)))?; + .map_err(|err| { + AgentError::Other(format!("delete low value memory error: {}", err)) + })?; } } } @@ -521,7 +537,10 @@ impl Session { Ok(()) } - pub fn ensure_agent_prompt_before_user_message(&mut self, chat_id: &str) -> Result<(), AgentError> { + pub fn ensure_agent_prompt_before_user_message( + &mut self, + chat_id: &str, + ) -> Result<(), AgentError> { self.ensure_chat_loaded(chat_id)?; let session_id = self.persistent_session_id(chat_id); @@ -530,10 +549,12 @@ impl Session { .get_session(&session_id) .map_err(|err| AgentError::Other(format!("get session error: {}", err)))? .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; - let active_user_turns = self - .store - .count_active_user_messages(&session_id) - .map_err(|err| AgentError::Other(format!("count active user messages error: {}", err)))?; + let active_user_turns = + self.store + .count_active_user_messages(&session_id) + .map_err(|err| { + AgentError::Other(format!("count active user messages error: {}", err)) + })?; if self.agent_prompt_reinject_every > 0 && active_user_turns > 0 @@ -550,7 +571,9 @@ impl Session { )?; self.store .mark_agent_prompt_reinjected(&session_id) - .map_err(|err| AgentError::Other(format!("mark agent prompt reinjection error: {}", err)))?; + .map_err(|err| { + AgentError::Other(format!("mark agent prompt reinjection error: {}", err)) + })?; } } @@ -607,16 +630,26 @@ impl Session { } /// 将消息写入内存与持久化层 - pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> { + pub fn append_persisted_message( + &mut self, + chat_id: &str, + message: ChatMessage, + ) -> Result<(), AgentError> { let session_id = self.persistent_session_id(chat_id); self.store .append_message(&session_id, &message) - .map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?; + .map_err(|err| { + AgentError::Other(format!("append message persistence error: {}", err)) + })?; self.add_message(chat_id, message); Ok(()) } - pub fn append_persisted_messages(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError> + pub fn append_persisted_messages( + &mut self, + chat_id: &str, + messages: I, + ) -> Result<(), AgentError> where I: IntoIterator, { @@ -661,12 +694,18 @@ impl Session { .unwrap_or(false) } - fn stale_result_diagnostics(&self, chat_id: &str) -> (Option<&str>, Option, bool, usize) { + fn stale_result_diagnostics( + &self, + chat_id: &str, + ) -> (Option<&str>, Option, bool, usize) { let latest_user = self.latest_user_message(chat_id); let latest_user_id = latest_user.map(|message| message.id.as_str()); let latest_user_preview = latest_user.map(|message| preview_text(&message.content, 80)); let compression_in_flight = self.compression_in_flight.contains(chat_id); - let history_len = self.get_history(chat_id).map(|history| history.len()).unwrap_or(0); + let history_len = self + .get_history(chat_id) + .map(|history| history.len()) + .unwrap_or(0); ( latest_user_id, @@ -688,7 +727,9 @@ impl Session { for chat_id in chat_ids { self.store .clear_messages(&self.persistent_session_id(&chat_id)) - .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?; + .map_err(|err| { + AgentError::Other(format!("clear history persistence error: {}", err)) + })?; } Ok(()) @@ -747,7 +788,12 @@ impl Session { sender_id: Option<&str>, message_id: Option<&str>, ) -> Result { - self.create_agent_with_provider_config(chat_id, sender_id, message_id, self.provider_config.clone()) + self.create_agent_with_provider_config( + chat_id, + sender_id, + message_id, + self.provider_config.clone(), + ) } pub fn create_agent_with_provider_config( @@ -758,23 +804,19 @@ impl Session { provider_config: LLMProviderConfig, ) -> Result { let session_id = self.persistent_session_id(chat_id); - AgentLoop::with_tools_and_skills( - provider_config, - self.tools.clone(), - self.skills.clone(), - ) - .map(|agent| { - agent - .with_skill_event_store(self.store.clone(), session_id.clone()) - .with_tool_context(ToolContext { - channel_name: Some(self.channel_name.clone()), - sender_id: sender_id.map(str::to_string), - chat_id: Some(chat_id.to_string()), - session_id: Some(session_id), - message_id: message_id.map(str::to_string), - message_seq: None, - }) - }) + AgentLoop::with_tools_and_skills(provider_config, self.tools.clone(), self.skills.clone()) + .map(|agent| { + agent + .with_skill_event_store(self.store.clone(), session_id.clone()) + .with_tool_context(ToolContext { + channel_name: Some(self.channel_name.clone()), + sender_id: sender_id.map(str::to_string), + chat_id: Some(chat_id.to_string()), + session_id: Some(session_id), + message_id: message_id.map(str::to_string), + message_seq: None, + }) + }) } fn ensure_initial_agent_prompt(&mut self, chat_id: &str) -> Result<(), AgentError> { @@ -868,8 +910,8 @@ fn default_tools( registry.register(HttpRequestTool::new( vec!["*".to_string()], // 允许所有域名,实际使用时建议限制 1_000_000, // max_response_size - 30, // timeout_secs - false, // allow_private_hosts + 30, // timeout_secs + false, // allow_private_hosts )); registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs registry @@ -901,7 +943,6 @@ pub(crate) fn handle_in_chat_command( } } - pub(crate) async fn schedule_background_history_compaction( session: Arc>, chat_id: impl Into, @@ -934,14 +975,24 @@ pub(crate) async fn schedule_background_history_compaction( ) }; - let (store, session_id, expected_reset_cutoff_seq, snapshot_end_seq, history, compressor, provider_config) = snapshot; + let ( + store, + session_id, + expected_reset_cutoff_seq, + snapshot_end_seq, + history, + compressor, + provider_config, + ) = snapshot; let session_for_task = session.clone(); let chat_id_for_task = chat_id.clone(); tokio::spawn(async move { tracing::info!(chat_id = %chat_id_for_task, snapshot_end_seq, "Starting background history compaction"); - let compaction_result = compressor.build_compaction_plan(&history, &provider_config).await; + let compaction_result = compressor + .build_compaction_plan(&history, &provider_config) + .await; let mut committed = false; match compaction_result { @@ -1005,7 +1056,9 @@ impl SessionManager { ); let known_agents = provider_configs.keys().cloned().collect::>(); - if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) { + if let Err(err) = + store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) + { tracing::warn!(error = %err, "Failed to record skill discovery event"); } @@ -1062,11 +1115,15 @@ impl SessionManager { Ok(Some(build_memory_maintenance_plan(&memories))) } - pub(crate) fn upsert_managed_agent_memory_summary(&self, markdown_body: &str) -> Result<(), AgentError> { + pub(crate) fn upsert_managed_agent_memory_summary( + &self, + markdown_body: &str, + ) -> Result<(), AgentError> { let path = agent_prompt_path()?; let existing = if path.exists() { - fs::read_to_string(&path) - .map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))? + fs::read_to_string(&path).map_err(|err| { + AgentError::Other(format!("read agent prompt file error: {}", err)) + })? } else { DEFAULT_AGENT_PROMPT.to_string() }; @@ -1083,7 +1140,9 @@ impl SessionManager { return Ok(None); }; - self.summarize_memory_maintenance_plan(scope_key, &plan).await.map(Some) + self.summarize_memory_maintenance_plan(scope_key, &plan) + .await + .map(Some) } async fn summarize_memory_maintenance_plan( @@ -1091,10 +1150,10 @@ impl SessionManager { scope_key: &str, plan: &MemoryMaintenancePlan, ) -> Result { - let provider_config = self.provider_config_for_agent(None)?; - let provider = create_provider(provider_config) - .map_err(|err| AgentError::Other(format!("create maintenance provider error: {}", err)))?; + let provider = create_provider(provider_config).map_err(|err| { + AgentError::Other(format!("create maintenance provider error: {}", err)) + })?; let request = ChatCompletionRequest { messages: vec![ @@ -1129,8 +1188,8 @@ impl SessionManager { } Err(err) => { let error_text = err.to_string(); - let should_retry = delay_ms.is_some() - && is_recoverable_maintenance_llm_error(&error_text); + let should_retry = + delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text); last_error = Some(error_text.clone()); if should_retry { @@ -1141,11 +1200,15 @@ impl SessionManager { error = %error_text, "Memory maintenance model request failed, retrying" ); - tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())).await; + tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default())) + .await; continue; } - return Err(AgentError::Other(format!("memory maintenance model error: {}", error_text))); + return Err(AgentError::Other(format!( + "memory maintenance model error: {}", + error_text + ))); } } } @@ -1160,8 +1223,8 @@ impl SessionManager { let raw_content = strip_json_code_fence(&response.content); let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content); - let output: MemoryMaintenanceModelOutput = serde_json::from_str(json_candidate) - .map_err(|err| { + let output: MemoryMaintenanceModelOutput = + serde_json::from_str(json_candidate).map_err(|err| { tracing::error!( scope_key = %scope_key, error = %err, @@ -1185,7 +1248,9 @@ impl SessionManager { return Ok(None); }; - let output = self.summarize_memory_maintenance_plan(scope_key, &plan).await?; + let output = self + .summarize_memory_maintenance_plan(scope_key, &plan) + .await?; apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?; Ok(Some(output)) @@ -1198,11 +1263,16 @@ impl SessionManager { let scope_keys = if let Some(cutoff) = updated_since { self.store .list_memory_scope_keys_updated_since("user", cutoff) - .map_err(|err| AgentError::Other(format!("list memory scope keys updated since error: {}", err)))? + .map_err(|err| { + AgentError::Other(format!( + "list memory scope keys updated since error: {}", + err + )) + })? } else { - self.store - .list_memory_scope_keys("user") - .map_err(|err| AgentError::Other(format!("list memory scope keys error: {}", err)))? + self.store.list_memory_scope_keys("user").map_err(|err| { + AgentError::Other(format!("list memory scope keys error: {}", err)) + })? }; let mut results = Vec::new(); @@ -1228,7 +1298,10 @@ impl SessionManager { Ok(results) } - pub fn provider_config_for_agent(&self, agent_name: Option<&str>) -> Result { + pub fn provider_config_for_agent( + &self, + agent_name: Option<&str>, + ) -> Result { select_provider_config(&self.provider_config, &self.provider_configs, agent_name) } @@ -1238,13 +1311,19 @@ impl SessionManager { .map_err(|err| AgentError::Other(format!("create session error: {}", err))) } - pub fn get_session_record(&self, session_id: &str) -> Result, AgentError> { + pub fn get_session_record( + &self, + session_id: &str, + ) -> Result, AgentError> { self.store .get_session(session_id) .map_err(|err| AgentError::Other(format!("get session error: {}", err))) } - pub fn list_cli_sessions(&self, include_archived: bool) -> Result, AgentError> { + pub fn list_cli_sessions( + &self, + include_archived: bool, + ) -> Result, AgentError> { self.store .list_sessions("cli", include_archived) .map_err(|err| AgentError::Other(format!("list sessions error: {}", err))) @@ -1284,7 +1363,8 @@ impl SessionManager { pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> { let mut inner = self.inner.lock().await; - let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) { + let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) + { let elapsed = last_active.elapsed(); if elapsed > inner.session_ttl { tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating"); @@ -1317,7 +1397,9 @@ impl SessionManager { let arc = Arc::new(Mutex::new(session)); inner.sessions.insert(channel_name.to_string(), arc.clone()); - inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); + inner + .session_timestamps + .insert(channel_name.to_string(), Instant::now()); } Ok(()) @@ -1332,7 +1414,9 @@ impl SessionManager { /// 更新最后活跃时间 pub async fn touch(&self, channel_name: &str) { let mut inner = self.inner.lock().await; - inner.session_timestamps.insert(channel_name.to_string(), Instant::now()); + inner + .session_timestamps + .insert(channel_name.to_string(), Instant::now()); } pub async fn cleanup_expired_sessions(&self) -> usize { @@ -1401,7 +1485,9 @@ impl SessionManager { session_guard.ensure_persistent_session(chat_id)?; session_guard.ensure_chat_loaded(chat_id)?; - if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? { + if let Some(command_response) = + handle_in_chat_command(&mut session_guard, chat_id, content)? + { return Ok(vec![OutboundMessage::assistant( channel_name.to_string(), chat_id.to_string(), @@ -1427,7 +1513,8 @@ impl SessionManager { session_guard.record_skill_offer(chat_id)?; // 创建 agent 并处理 - let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message.id))?; + let mut agent = + session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message.id))?; if let Some(handler) = live_emitter.clone() { agent = agent.with_emitted_message_handler(handler); } @@ -1457,7 +1544,8 @@ impl SessionManager { Vec::new() } else { // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 - session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + session_guard + .append_persisted_messages(chat_id, result.emitted_messages.clone())?; should_schedule_compaction = true; result @@ -1481,7 +1569,9 @@ impl SessionManager { }; if should_schedule_compaction { - if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.to_string()).await { + if let Err(error) = + schedule_background_history_compaction(session.clone(), chat_id.to_string()).await + { tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction"); } } @@ -1578,13 +1668,16 @@ impl SessionManager { ); Vec::new() } else { - session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + session_guard + .append_persisted_messages(chat_id, result.emitted_messages.clone())?; should_schedule_compaction = true; result .emitted_messages .iter() - .filter(|message| should_display_message_to_user(self.show_tool_results, message)) + .filter(|message| { + should_display_message_to_user(self.show_tool_results, message) + }) .flat_map(|message| { OutboundMessage::from_chat_message( channel_name, @@ -1599,7 +1692,9 @@ impl SessionManager { }; if should_schedule_compaction { - if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.to_string()).await { + if let Err(error) = + schedule_background_history_compaction(session.clone(), chat_id.to_string()).await + { tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for scheduled task"); } } @@ -1624,17 +1719,22 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage show_tool_results || matches!( - message.tool_state.as_ref().unwrap_or(&crate::bus::message::ToolMessageState::Completed), + message + .tool_state + .as_ref() + .unwrap_or(&crate::bus::message::ToolMessageState::Completed), crate::bus::message::ToolMessageState::PendingUserAction ) } fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String { - match system_prompt.map(str::trim).filter(|value| !value.is_empty()) { + match system_prompt + .map(str::trim) + .filter(|value| !value.is_empty()) + { Some(system_prompt) => format!( "{}\n\n任务专属要求:{}", - SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, - system_prompt + SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt ), None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(), } @@ -1647,20 +1747,19 @@ fn select_provider_config( ) -> Result { match agent_name.map(str::trim).filter(|value| !value.is_empty()) { None | Some("default") => Ok(default_provider_config.clone()), - Some(agent_name) => provider_configs - .get(agent_name) - .cloned() - .ok_or_else(|| AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))), + Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| { + AgentError::Other(format!("Scheduled agent '{}' not found", agent_name)) + }), } } #[cfg(test)] mod tests { - use axum::http::StatusCode; use super::*; - use axum::{Json, Router, routing::post}; use crate::bus::MessageBus; use crate::storage::MemoryRecord; + use axum::http::StatusCode; + use axum::{Json, Router, routing::post}; use serde_json::{Value, json}; use std::collections::HashMap; use std::sync::{ @@ -1714,7 +1813,8 @@ mod tests { test_provider_config_named("planner-provider", "planner-model"), )]); - let selected = select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap(); + let selected = + select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap(); assert_eq!(selected.name, "planner-provider"); assert_eq!(selected.model_id, "planner-model"); } @@ -1811,13 +1911,19 @@ mod tests { .unwrap(); let second = session.create_user_message("second", Vec::new()); - session.append_persisted_message("chat-1", second.clone()).unwrap(); + session + .append_persisted_message("chat-1", second.clone()) + .unwrap(); session .append_persisted_message("chat-1", ChatMessage::assistant("answer-2")) .unwrap(); let session_id = session.persistent_session_id("chat-1"); - let snapshot_end_seq = store.get_session(&session_id).unwrap().unwrap().message_count; + let snapshot_end_seq = store + .get_session(&session_id) + .unwrap() + .unwrap() + .message_count; let preserved_messages = session.get_history("chat-1").unwrap().clone(); store @@ -1843,7 +1949,8 @@ mod tests { let default_provider = test_provider_config_named("default-provider", "default-model"); let provider_configs = HashMap::new(); - let selected = select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap(); + let selected = + select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap(); assert_eq!(selected.name, "default-provider"); assert_eq!(selected.model_id, "default-model"); } @@ -2126,7 +2233,11 @@ mod tests { .expect("missing scheduled system prompt"); assert!(scheduled_prompt.content.contains("已经触发的定时任务执行")); - assert!(scheduled_prompt.content.contains("不要调用任何定时任务管理工具")); + assert!( + scheduled_prompt + .content + .contains("不要调用任何定时任务管理工具") + ); assert!(scheduled_prompt.content.contains("你是邮箱待办同步助手。")); } @@ -2199,7 +2310,10 @@ mod tests { assert_eq!(output.user_facts, vec!["用户在做AI产品".to_string()]); assert_eq!(output.preferences, vec!["偏好简洁表达".to_string()]); - assert_eq!(output.behavior_patterns, vec!["习惯先问方案再要代码".to_string()]); + assert_eq!( + output.behavior_patterns, + vec!["习惯先问方案再要代码".to_string()] + ); assert!(output.managed_markdown.contains("### 用户事实")); } @@ -2208,8 +2322,12 @@ mod tests { assert!(is_recoverable_maintenance_llm_error( "error sending request for url (https://example.invalid/v1/chat/completions)" )); - assert!(is_recoverable_maintenance_llm_error("API error 504 Gateway Timeout: stream timeout")); - assert!(!is_recoverable_maintenance_llm_error("API error 401 Unauthorized")); + assert!(is_recoverable_maintenance_llm_error( + "API error 504 Gateway Timeout: stream timeout" + )); + assert!(!is_recoverable_maintenance_llm_error( + "API error 401 Unauthorized" + )); } #[test] @@ -2217,7 +2335,10 @@ mod tests { let wrapped = "下面是结果:\n```json\n{\n \"user_facts\": [],\n \"preferences\": []\n}\n```\n请查收"; let stripped = strip_json_code_fence(wrapped); let extracted = extract_json_object(stripped).unwrap(); - assert_eq!(extracted, "{\n \"user_facts\": [],\n \"preferences\": []\n}"); + assert_eq!( + extracted, + "{\n \"user_facts\": [],\n \"preferences\": []\n}" + ); } #[tokio::test] @@ -2457,7 +2578,9 @@ mod tests { }) .unwrap(); - let plan = build_memory_maintenance_plan(&store.list_memories_for_scope("user", scope_key).unwrap()); + let plan = build_memory_maintenance_plan( + &store.list_memories_for_scope("user", scope_key).unwrap(), + ); let output = MemoryMaintenanceModelOutput { user_facts: vec!["用户在做AI产品".to_string()], preferences: Vec::new(), @@ -2490,7 +2613,9 @@ mod tests { "### 用户事实\n- 用户名为区德成,昵称DC。".to_string(), ]); - assert!(combined.contains("### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达")); + assert!( + combined.contains("### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达") + ); assert!(combined.contains("### 用户事实\n- 用户名为区德成,昵称DC。")); assert_eq!(combined.matches("- 用户在做AI产品").count(), 1); } @@ -2513,27 +2638,30 @@ mod tests { #[tokio::test] async fn test_bus_tool_call_emitter_hides_completed_tool_results_when_disabled() { let bus = MessageBus::new(4); - let emitter = BusToolCallEmitter::new( - bus.clone(), - "feishu", - "chat-1", - HashMap::new(), - false, - ); + let emitter = + BusToolCallEmitter::new(bus.clone(), "feishu", "chat-1", HashMap::new(), false); emitter .handle(ChatMessage::tool("call-1", "calculator", "2")) .await; - assert!(tokio::time::timeout(std::time::Duration::from_millis(50), bus.consume_outbound()) - .await - .is_err()); + assert!( + tokio::time::timeout(std::time::Duration::from_millis(50), bus.consume_outbound()) + .await + .is_err() + ); } #[test] fn test_parse_in_chat_command_aliases() { - assert_eq!(parse_in_chat_command("/new"), Some(InChatCommand::FreshConversation)); - assert_eq!(parse_in_chat_command(" /reset \n"), Some(InChatCommand::FreshConversation)); + assert_eq!( + parse_in_chat_command("/new"), + Some(InChatCommand::FreshConversation) + ); + assert_eq!( + parse_in_chat_command(" /reset \n"), + Some(InChatCommand::FreshConversation) + ); assert_eq!(parse_in_chat_command("/new planning"), None); assert_eq!(parse_in_chat_command("please /reset"), None); } @@ -2573,10 +2701,12 @@ mod tests { assert_eq!(response, "Started a fresh conversation."); assert!(session.get_history("chat-1").unwrap().is_empty()); - assert!(store - .load_messages(&session.persistent_session_id("chat-1")) - .unwrap() - .is_empty()); + assert!( + store + .load_messages(&session.persistent_session_id("chat-1")) + .unwrap() + .is_empty() + ); assert_eq!( store .load_all_messages(&session.persistent_session_id("chat-1")) @@ -2655,10 +2785,15 @@ mod tests { .unwrap(); } - session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + session + .ensure_agent_prompt_before_user_message("chat-1") + .unwrap(); let history = session.get_history("chat-1").unwrap(); - let system_messages = history.iter().filter(|message| message.role == "system").count(); + let system_messages = history + .iter() + .filter(|message| message.role == "system") + .count(); assert_eq!(system_messages, 2); let stored = store @@ -2667,9 +2802,14 @@ mod tests { .unwrap(); assert_eq!(stored.agent_prompt_reinjection_count, 1); - session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + session + .ensure_agent_prompt_before_user_message("chat-1") + .unwrap(); let history = session.get_history("chat-1").unwrap(); - let system_messages = history.iter().filter(|message| message.role == "system").count(); + let system_messages = history + .iter() + .filter(|message| message.role == "system") + .count(); assert_eq!(system_messages, 2); } @@ -2705,10 +2845,15 @@ mod tests { .unwrap(); } - session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + session + .ensure_agent_prompt_before_user_message("chat-1") + .unwrap(); let history = session.get_history("chat-1").unwrap(); - let system_messages = history.iter().filter(|message| message.role == "system").count(); + let system_messages = history + .iter() + .filter(|message| message.role == "system") + .count(); assert_eq!(system_messages, 1); } @@ -2742,7 +2887,9 @@ mod tests { .unwrap(); handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap(); - session.ensure_agent_prompt_before_user_message("chat-1").unwrap(); + session + .ensure_agent_prompt_before_user_message("chat-1") + .unwrap(); let history = session.get_history("chat-1").unwrap(); assert_eq!(history.len(), 1); @@ -2753,12 +2900,7 @@ mod tests { fn test_default_tools_registers_get_time() { let skills = Arc::new(SkillRuntime::default()); let store = Arc::new(SessionStore::in_memory().unwrap()); - let tools = default_tools( - skills, - store, - HashSet::new(), - "Asia/Shanghai".to_string(), - ); + let tools = default_tools(skills, store, HashSet::new(), "Asia/Shanghai".to_string()); assert!(tools.tool_names().iter().any(|name| name == "get_time")); } @@ -2844,7 +2986,8 @@ mod tests { #[test] fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() { - let original = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n"; + let original = + "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n"; let updated = upsert_managed_agent_memory_block( original, "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达", @@ -2857,5 +3000,4 @@ mod tests { assert!(updated.contains("用户在做AI产品")); assert!(updated.contains("偏好简洁表达")); } - } diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 17ce8e1..a82f03f 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,15 +1,18 @@ -use std::sync::Arc; +use super::{ + GatewayState, + session::{Session, handle_in_chat_command, schedule_background_history_compaction}, +}; +use crate::agent::EmittedMessageHandler; +use crate::bus::ChatMessage; +use crate::bus::message::{ToolMessageState, format_tool_call_content}; +use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound}; use async_trait::async_trait; -use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use axum::extract::State; +use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; -use tokio::sync::{mpsc, Mutex}; -use crate::agent::EmittedMessageHandler; -use crate::bus::message::{format_tool_call_content, ToolMessageState}; -use crate::bus::ChatMessage; -use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; -use super::{GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}}; +use std::sync::Arc; +use tokio::sync::{Mutex, mpsc}; struct WsToolCallEmitter { sender: mpsc::Sender, @@ -120,7 +123,9 @@ async fn handle_socket(ws: WebSocket, state: Arc) { &runtime_session_id, &mut current_session_id, inbound, - ).await { + ) + .await + { tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message"); let _ = session .lock() @@ -182,17 +187,14 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { }); } - outbound.extend(tool_calls - .iter() - .map(|tool_call| WsOutbound::ToolCall { - id: message.id.clone(), - tool_call_id: tool_call.id.clone(), - tool_name: tool_call.name.clone(), - arguments: tool_call.arguments.clone(), - content: format_tool_call_content(&tool_call.name, &tool_call.arguments), - role: message.role.clone(), - }) - ); + outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall { + id: message.id.clone(), + tool_call_id: tool_call.id.clone(), + tool_name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + content: format_tool_call_content(&tool_call.name, &tool_call.arguments), + role: message.role.clone(), + })); outbound } else { vec![WsOutbound::AssistantResponse { @@ -202,7 +204,11 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { }] } } - "tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) { + "tool" => match message + .tool_state + .as_ref() + .unwrap_or(&ToolMessageState::Completed) + { ToolMessageState::Completed => vec![WsOutbound::ToolResult { id: message.id.clone(), tool_call_id: message.tool_call_id.clone().unwrap_or_default(), @@ -230,7 +236,10 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage show_tool_results || matches!( - message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed), + message + .tool_state + .as_ref() + .unwrap_or(&ToolMessageState::Completed), ToolMessageState::PendingUserAction ) } @@ -243,7 +252,12 @@ async fn handle_inbound( inbound: WsInbound, ) -> Result<(), crate::agent::AgentError> { match inbound { - WsInbound::UserInput { content, chat_id, sender_id, .. } => { + WsInbound::UserInput { + content, + chat_id, + sender_id, + .. + } => { let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone()); let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id); let (history, agent, user_tx) = { @@ -252,7 +266,9 @@ async fn handle_inbound( session_guard.ensure_persistent_session(&chat_id)?; session_guard.ensure_chat_loaded(&chat_id)?; - if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? { + if let Some(command_response) = + handle_in_chat_command(&mut session_guard, &chat_id, &content)? + { let _ = session_guard .send(WsOutbound::AssistantResponse { id: uuid::Uuid::new_v4().to_string(), @@ -286,13 +302,17 @@ async fn handle_inbound( match agent.process(history).await { Ok(result) => { let mut session_guard = session.lock().await; - session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?; + session_guard + .append_persisted_messages(&chat_id, result.emitted_messages.clone())?; for outbound in result .emitted_messages .iter() .filter(|message| { !message.is_assistant_tool_call_message() - && should_display_message_to_user(state.config.gateway.show_tool_results, message) + && should_display_message_to_user( + state.config.gateway.show_tool_results, + message, + ) }) .flat_map(ws_outbound_from_chat_message) { @@ -301,7 +321,10 @@ async fn handle_inbound( drop(session_guard); - if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()).await { + if let Err(error) = + schedule_background_history_compaction(session.clone(), chat_id.clone()) + .await + { tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session"); } } @@ -318,16 +341,19 @@ async fn handle_inbound( Ok(()) } - WsInbound::ClearHistory { session_id, chat_id } => { - let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone()); + WsInbound::ClearHistory { + session_id, + chat_id, + } => { + let target = session_id + .or(chat_id) + .unwrap_or_else(|| current_session_id.clone()); state.session_manager.clear_session_messages(&target)?; let mut session_guard = session.lock().await; session_guard.remove_history(&target); let _ = session_guard - .send(WsOutbound::HistoryCleared { - session_id: target, - }) + .send(WsOutbound::HistoryCleared { session_id: target }) .await; Ok(()) } @@ -452,17 +478,15 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St #[cfg(test)] mod tests { - use crate::agent::EmittedMessageHandler; use super::{ - WsToolCallEmitter, - resolve_ws_sender_id, - should_display_message_to_user, + WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user, ws_outbound_from_chat_message, }; + use crate::agent::EmittedMessageHandler; use crate::bus::ChatMessage; use crate::bus::message::ToolMessageState; - use crate::providers::ToolCall; use crate::protocol::WsOutbound; + use crate::providers::ToolCall; use serde_json::json; use tokio::sync::mpsc; @@ -481,7 +505,13 @@ mod tests { assert_eq!(outbound.len(), 1); match &outbound[0] { - WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => { + WsOutbound::ToolCall { + tool_call_id, + tool_name, + arguments, + content, + .. + } => { assert_eq!(tool_call_id, "call-1"); assert_eq!(tool_name, "calculator"); assert_eq!(arguments["expression"], "1 + 1"); @@ -551,8 +581,14 @@ mod tests { #[test] fn test_resolve_ws_sender_id_prefers_inbound_sender() { - assert_eq!(resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42"); - assert_eq!(resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42"); + assert_eq!( + resolve_ws_sender_id(Some("user-42"), "runtime-1"), + "user-42" + ); + assert_eq!( + resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), + "user-42" + ); } #[test] @@ -573,8 +609,10 @@ mod tests { .handle(ChatMessage::tool("call-1", "calculator", "2")) .await; - assert!(tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv()) - .await - .is_err()); + assert!( + tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv()) + .await + .is_err() + ); } } diff --git a/src/lib.rs b/src/lib.rs index 8d4b095..c99e424 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,16 @@ -pub mod config; -pub mod text; -pub mod providers; -pub mod bus; -pub mod cli; pub mod agent; -pub mod gateway; -pub mod client; -pub mod protocol; +pub mod bus; pub mod channels; +pub mod cli; +pub mod client; +pub mod config; +pub mod gateway; pub mod logging; pub mod observability; +pub mod protocol; +pub mod providers; pub mod scheduler; -pub mod storage; -pub mod tools; pub mod skills; +pub mod storage; +pub mod text; +pub mod tools; diff --git a/src/logging.rs b/src/logging.rs index 02d6e5c..57bfc8d 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,13 +1,9 @@ -use std::path::PathBuf; use chrono::Utc; use chrono_tz::Tz; +use std::path::PathBuf; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{ - fmt, - fmt::time::FormatTime, - layer::SubscriberExt, - util::SubscriberInitExt, - EnvFilter, + EnvFilter, fmt, fmt::time::FormatTime, layer::SubscriberExt, util::SubscriberInitExt, }; #[derive(Clone, Copy, Debug)] @@ -16,8 +12,17 @@ struct ConfiguredTimestamp { } impl FormatTime for ConfiguredTimestamp { - fn format_time(&self, writer: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result { - write!(writer, "{}", Utc::now().with_timezone(&self.timezone).to_rfc3339_opts(chrono::SecondsFormat::Millis, true)) + fn format_time( + &self, + writer: &mut tracing_subscriber::fmt::format::Writer<'_>, + ) -> std::fmt::Result { + write!( + writer, + "{}", + Utc::now() + .with_timezone(&self.timezone) + .to_rfc3339_opts(chrono::SecondsFormat::Millis, true) + ) } } @@ -41,20 +46,19 @@ pub fn init_logging(timezone: Tz) { // Create log directory if it doesn't exist if !log_dir.exists() { if let Err(e) = std::fs::create_dir_all(&log_dir) { - eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e); + eprintln!( + "Warning: Failed to create log directory {}: {}", + log_dir.display(), + e + ); } } // Create file appender with daily rotation - let file_appender = RollingFileAppender::new( - Rotation::DAILY, - &log_dir, - "picobot.log", - ); + let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log"); // Build subscriber with both console and file output - let env_filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")); + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); let file_layer = fmt::layer() .with_writer(file_appender) @@ -80,8 +84,7 @@ pub fn init_logging(timezone: Tz) { /// Initialize logging without file output (console only) pub fn init_logging_console_only(timezone: Tz) { - let env_filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")); + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); let console_layer = fmt::layer() .with_timer(ConfiguredTimestamp { timezone }) diff --git a/src/main.rs b/src/main.rs index 81d73ce..20df5e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use clap::{Parser, CommandFactory}; +use clap::{CommandFactory, Parser}; #[derive(Parser)] #[command(name = "picobot")] diff --git a/src/observability/mod.rs b/src/observability/mod.rs index e66e86c..11521e9 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -26,10 +26,7 @@ pub enum ObserverEvent { success: bool, }, /// Emitted when the agent starts processing. - AgentStart { - provider: String, - model: String, - }, + AgentStart { provider: String, model: String }, /// Emitted when the agent finishes processing. AgentEnd { provider: String, @@ -116,7 +113,11 @@ impl ToolExecutionOutcome { } /// Create a failed outcome with duration. - pub fn failure_with_duration(output: String, error_reason: Option, duration: Duration) -> Self { + pub fn failure_with_duration( + output: String, + error_reason: Option, + duration: Duration, + ) -> Self { Self { output, success: false, diff --git a/src/protocol.rs b/src/protocol.rs index c895d00..a74a8c4 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -43,9 +43,7 @@ pub enum WsInbound { include_archived: bool, }, #[serde(rename = "load_session")] - LoadSession { - session_id: String, - }, + LoadSession { session_id: String }, #[serde(rename = "rename_session")] RenameSession { #[serde(default, skip_serializing_if = "Option::is_none")] @@ -70,7 +68,11 @@ pub enum WsInbound { #[serde(tag = "type")] pub enum WsOutbound { #[serde(rename = "assistant_response")] - AssistantResponse { id: String, content: String, role: String }, + AssistantResponse { + id: String, + content: String, + role: String, + }, #[serde(rename = "tool_call")] ToolCall { id: String, diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 0c47ebd..f5563f0 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; -use crate::bus::message::ContentBlock; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::traits::Usage; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; +use crate::bus::message::ContentBlock; fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; @@ -20,7 +20,10 @@ fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { details.join("\ncaused by: ") } -fn serialize_content_blocks(blocks: &[serde_json::Value], serializer: S) -> Result +fn serialize_content_blocks( + blocks: &[serde_json::Value], + serializer: S, +) -> Result where S: serde::Serializer, { @@ -28,14 +31,15 @@ where } fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec { - blocks.iter().map(|b| match b { - ContentBlock::Text { text } => { - serde_json::json!({ "type": "text", "text": text }) - } - ContentBlock::ImageUrl { image_url } => { - convert_image_url_to_anthropic(&image_url.url) - } - }).collect() + blocks + .iter() + .map(|b| match b { + ContentBlock::Text { text } => { + serde_json::json!({ "type": "text", "text": text }) + } + ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url), + }) + .collect() } fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value { @@ -147,9 +151,13 @@ struct AnthropicResponse { #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum AnthropicContent { - Text { text: String }, + Text { + text: String, + }, #[allow(dead_code)] - Thinking { thinking: String }, + Thinking { + thinking: String, + }, #[serde(rename = "tool_use")] ToolUse { id: String, diff --git a/src/providers/mod.rs b/src/providers/mod.rs index e4eae27..f4f0675 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,12 +1,15 @@ -pub mod traits; -pub mod openai; pub mod anthropic; +pub mod openai; +pub mod traits; -pub use self::openai::OpenAIProvider; pub use self::anthropic::AnthropicProvider; +pub use self::openai::OpenAIProvider; use crate::config::LLMProviderConfig; -pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage}; +pub use traits::{ + ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, + ToolFunction, Usage, +}; pub fn create_provider(config: LLMProviderConfig) -> Result, ProviderError> { match config.provider_type.as_str() { diff --git a/src/providers/openai.rs b/src/providers/openai.rs index ca90f52..05a3f96 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,18 +1,15 @@ use async_trait::async_trait; use reqwest::Client; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::collections::HashMap; use std::time::Duration; -use crate::bus::message::ContentBlock; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::traits::Usage; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; +use crate::bus::message::ContentBlock; -const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[ - "tool_call_arguments_json", - "mock_response_content", -]; +const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; @@ -32,12 +29,17 @@ fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { return Value::String(text.clone()); } } - Value::Array(blocks.iter().map(|b| match b { - ContentBlock::Text { text } => json!({ "type": "text", "text": text }), - ContentBlock::ImageUrl { image_url } => { - json!({ "type": "image_url", "image_url": { "url": image_url.url } }) - } - }).collect()) + Value::Array( + blocks + .iter() + .map(|b| match b { + ContentBlock::Text { text } => json!({ "type": "text", "text": text }), + ContentBlock::ImageUrl { image_url } => { + json!({ "type": "image_url", "image_url": { "url": image_url.url } }) + } + }) + .collect(), + ) } pub struct OpenAIProvider { @@ -122,7 +124,9 @@ impl OpenAIProvider { fn request_model_extra(&self) -> impl Iterator { self.model_extra.iter().filter(|(key, _)| { - !INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str()) + !INTERNAL_MODEL_EXTRA_KEYS + .iter() + .any(|internal| internal == &key.as_str()) }) } @@ -265,7 +269,11 @@ impl LLMProvider for OpenAIProvider { if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { for (j, item) in content.iter().enumerate() { if item.get("type").and_then(|t| t.as_str()) == Some("image_url") { - if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) { + if let Some(url_str) = item + .get("image_url") + .and_then(|u| u.get("url")) + .and_then(|v| v.as_str()) + { let prefix: String = url_str.chars().take(20).collect(); tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); } @@ -419,7 +427,10 @@ mod tests { assert_eq!(tool_calls[0]["id"], "call_1"); assert_eq!(tool_calls[0]["type"], "function"); assert_eq!(tool_calls[0]["function"]["name"], "calculator"); - assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); + assert_eq!( + tool_calls[0]["function"]["arguments"], + "{\"expression\":\"1+1\"}" + ); } #[test] @@ -433,10 +444,7 @@ mod tests { "gpt-test".to_string(), None, None, - HashMap::from([( - "tool_call_arguments_json".to_string(), - Value::Bool(true), - )]), + HashMap::from([("tool_call_arguments_json".to_string(), Value::Bool(true))]), ); let request = ChatCompletionRequest { @@ -461,7 +469,10 @@ mod tests { let messages = body["messages"].as_array().unwrap(); let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); - assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"})); + assert_eq!( + tool_calls[0]["function"]["arguments"], + json!({"expression": "1+1"}) + ); assert!(body.get("tool_call_arguments_json").is_none()); } @@ -501,7 +512,10 @@ mod tests { let messages = body["messages"].as_array().unwrap(); let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); - assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); + assert_eq!( + tool_calls[0]["function"]["arguments"], + "{\"expression\":\"1+1\"}" + ); } #[test] @@ -517,7 +531,10 @@ mod tests { None, HashMap::from([ ("tool_call_arguments_json".to_string(), Value::Bool(true)), - ("mock_response_content".to_string(), Value::String("stub".to_string())), + ( + "mock_response_content".to_string(), + Value::String("stub".to_string()), + ), ("parallel_tool_calls".to_string(), Value::Bool(true)), ]), ); @@ -590,7 +607,10 @@ mod tests { })) .unwrap(); - assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning")); + assert_eq!( + response.choices[0].message.reasoning_content.as_deref(), + Some("hidden reasoning") + ); } #[test] diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 280911c..a1ace3c 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,6 +1,6 @@ +use crate::bus::message::ContentBlock; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use crate::bus::message::ContentBlock; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { @@ -61,7 +61,11 @@ impl Message { } } - pub fn tool(tool_call_id: impl Into, tool_name: impl Into, content: impl Into) -> Self { + pub fn tool( + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + ) -> Self { Self { role: "tool".to_string(), content: vec![ContentBlock::text(content)], diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index f2e6edf..ff6a1b0 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -8,11 +8,11 @@ use tokio::sync::watch; use crate::bus::{MessageBus, OutboundMessage}; use crate::config::{ - SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerMisfirePolicy, - SchedulerSchedule, + SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, + SchedulerMisfirePolicy, SchedulerSchedule, }; -use crate::gateway::session::SessionManager; use crate::gateway::session::ScheduledAgentTaskOptions; +use crate::gateway::session::SessionManager; use crate::storage::{ SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore, }; @@ -76,8 +76,11 @@ impl Scheduler { fn sync_config_jobs(&self) -> anyhow::Result<()> { let now = Utc::now(); - for job in self.config.effective_jobs(&crate::config::TimeConfig { timezone: self.timezone.name().to_string() }) { - let runtime = RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?; + for job in self.config.effective_jobs(&crate::config::TimeConfig { + timezone: self.timezone.name().to_string(), + }) { + let runtime = + RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?; self.store.upsert_scheduler_job(&runtime.to_upsert())?; } Ok(()) @@ -88,7 +91,9 @@ impl Scheduler { let jobs = self.store.list_scheduler_jobs(true)?; for record in jobs { - let Some(mut job) = RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)? else { + let Some(mut job) = + RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)? + else { continue; }; @@ -178,8 +183,12 @@ impl Scheduler { } SchedulerJobKind::SilentAgentTask => { let execution_chat_id = resolve_execution_chat_id(job)?; - if let Err(error) = execute_agent_task(&self.session_manager, job, &execution_chat_id).await { - if let Err(notify_error) = self.notify_silent_agent_task_failure(job, &error).await { + if let Err(error) = + execute_agent_task(&self.session_manager, job, &execution_chat_id).await + { + if let Err(notify_error) = + self.notify_silent_agent_task_failure(job, &error).await + { tracing::error!( job_id = %job.id, error = %notify_error, @@ -208,10 +217,13 @@ impl Scheduler { let mut metadata = HashMap::new(); metadata.insert("scheduler_job_id".to_string(), job.id.clone()); - metadata.insert("scheduler_job_kind".to_string(), "silent_agent_task".to_string()); + metadata.insert( + "scheduler_job_kind".to_string(), + "silent_agent_task".to_string(), + ); self.bus - .publish_outbound(OutboundMessage::assistant( + .publish_outbound(OutboundMessage::error_notification( channel, chat_id, format!( @@ -304,14 +316,24 @@ impl RuntimeJob { } }; - let schedule = deserialize_schedule(&record.schedule, record.interval_secs, record.startup_delay_secs)?; + let schedule = deserialize_schedule( + &record.schedule, + record.interval_secs, + record.startup_delay_secs, + )?; let now = Utc::now(); let next_fire_at = match (record.enabled, record.state.clone(), record.next_fire_at) { (false, _, _) => None, (_, SchedulerJobState::Paused, _) => None, (_, SchedulerJobState::Completed, _) => None, (_, _, some_next) if some_next.is_some() => some_next, - _ => compute_initial_next_fire_at(&schedule, now, record.last_fired_at, misfire_policy, timezone)?, + _ => compute_initial_next_fire_at( + &schedule, + now, + record.last_fired_at, + misfire_policy, + timezone, + )?, }; Ok(Some(Self { @@ -338,7 +360,10 @@ impl RuntimeJob { fn is_due(&self, now: DateTime) -> bool { self.enabled && self.state == SchedulerJobState::Scheduled - && self.next_fire_at.map(|value| value <= now.timestamp_millis()).unwrap_or(false) + && self + .next_fire_at + .map(|value| value <= now.timestamp_millis()) + .unwrap_or(false) } fn after_execution( @@ -371,7 +396,8 @@ impl RuntimeJob { let reference_ms = self.next_fire_at.or(self.last_fired_at); self.state = SchedulerJobState::Scheduled; self.completed_at = None; - self.next_fire_at = compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?; + self.next_fire_at = + compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?; Ok(()) } @@ -384,7 +410,8 @@ impl RuntimeJob { SchedulerJobKind::AgentTask => "agent_task".to_string(), SchedulerJobKind::SilentAgentTask => "silent_agent_task".to_string(), }, - schedule: serde_json::to_value(&self.schedule).unwrap_or_else(|_| serde_json::json!({})), + schedule: serde_json::to_value(&self.schedule) + .unwrap_or_else(|_| serde_json::json!({})), interval_secs: self.interval_secs, startup_delay_secs: self.startup_delay_secs, target: serde_json::to_value(&self.target).unwrap_or_else(|_| serde_json::json!({})), @@ -430,21 +457,36 @@ fn compute_initial_next_fire_at( timezone: Tz, ) -> anyhow::Result> { match last_fired_at { - Some(last_fired_at) => compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone), + Some(last_fired_at) => { + compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone) + } None => match schedule { - SchedulerSchedule::Delay { seconds } => Ok(Some((now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis())), + SchedulerSchedule::Delay { seconds } => Ok(Some( + (now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis(), + )), SchedulerSchedule::Interval { seconds, startup_delay_secs, } => { - let delay = if *startup_delay_secs > 0 { *startup_delay_secs } else { *seconds }; - Ok(Some((now + ChronoDuration::seconds(delay as i64)).timestamp_millis())) + let delay = if *startup_delay_secs > 0 { + *startup_delay_secs + } else { + *seconds + }; + Ok(Some( + (now + ChronoDuration::seconds(delay as i64)).timestamp_millis(), + )) + } + SchedulerSchedule::At { timestamp } => { + Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis())) } - SchedulerSchedule::At { timestamp } => Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis())), SchedulerSchedule::Cron { expression } => { let schedule = parse_scheduler_cron(expression)?; let local_now = now.with_timezone(&timezone); - Ok(schedule.after(&local_now).next().map(|next| next.with_timezone(&Utc).timestamp_millis())) + Ok(schedule + .after(&local_now) + .next() + .map(|next| next.with_timezone(&Utc).timestamp_millis())) } }, } @@ -483,7 +525,10 @@ fn compute_next_fire_at( .map(|value| value.with_timezone(&timezone)) .unwrap_or_else(|| now.with_timezone(&timezone)), }; - Ok(schedule.after(&anchor).next().map(|next| next.with_timezone(&Utc).timestamp_millis())) + Ok(schedule + .after(&anchor) + .next() + .map(|next| next.with_timezone(&Utc).timestamp_millis())) } } } @@ -525,12 +570,14 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result { .payload .get("content") .and_then(|value| value.as_str()) - .ok_or_else(|| anyhow::anyhow!("outbound scheduler job payload.content must be a string"))?; + .ok_or_else(|| { + anyhow::anyhow!("outbound scheduler job payload.content must be a string") + })?; let mut metadata = HashMap::new(); metadata.insert("scheduler_job_id".to_string(), job.id.clone()); - Ok(OutboundMessage::assistant( + Ok(OutboundMessage::scheduler_notification( channel, chat_id, content.to_string(), @@ -539,7 +586,10 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result { )) } -async fn execute_internal_event(session_manager: &SessionManager, job: &RuntimeJob) -> anyhow::Result<()> { +async fn execute_internal_event( + session_manager: &SessionManager, + job: &RuntimeJob, +) -> anyhow::Result<()> { let event = job .payload .get("event") @@ -599,7 +649,10 @@ async fn execute_agent_task( .map_err(|error| anyhow::anyhow!(error.to_string())) } -fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> anyhow::Result<&'a str> { +fn required_notification_chat_id<'a>( + job: &'a RuntimeJob, + kind_name: &str, +) -> anyhow::Result<&'a str> { job.target .chat_id .as_deref() @@ -608,7 +661,9 @@ fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> an fn resolve_execution_chat_id(job: &RuntimeJob) -> anyhow::Result { match job.kind { - SchedulerJobKind::AgentTask => Ok(required_notification_chat_id(job, "agent_task")?.to_string()), + SchedulerJobKind::AgentTask => { + Ok(required_notification_chat_id(job, "agent_task")?.to_string()) + } SchedulerJobKind::SilentAgentTask => Ok(job .target .session_chat_id @@ -633,7 +688,9 @@ fn summarize_scheduler_error(error: &anyhow::Error) -> String { } } -fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result { +fn parse_scheduled_agent_task_options( + job: &RuntimeJob, +) -> anyhow::Result { let sender_id = job .payload .get("sender_id") @@ -665,7 +722,9 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result) -> anyhow::Result> { +fn parse_metadata_map( + value: Option<&serde_json::Value>, +) -> anyhow::Result> { let Some(value) = value else { return Ok(HashMap::new()); }; @@ -685,7 +744,7 @@ fn parse_metadata_map(value: Option<&serde_json::Value>) -> anyhow::Result for SchedulerJobTarget { #[cfg(test)] mod tests { use super::*; - use std::collections::HashMap; use crate::bus::MessageBus; use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig}; use crate::gateway::session::SessionManager; use crate::skills::SkillRuntime; use crate::storage::{SchedulerJobUpsert, SessionStore}; + use std::collections::HashMap; fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { @@ -921,7 +1003,10 @@ mod tests { #[test] fn runtime_job_skip_policy_advances_from_now() { - let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap(); + let now = Utc + .timestamp_millis_opt(1_700_000_000_000) + .single() + .unwrap(); let next = compute_next_fire_at( &SchedulerSchedule::Interval { seconds: 60, @@ -940,7 +1025,10 @@ mod tests { #[test] fn runtime_job_catch_up_policy_moves_past_now() { - let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap(); + let now = Utc + .timestamp_millis_opt(1_700_000_000_000) + .single() + .unwrap(); let next = compute_next_fire_at( &SchedulerSchedule::Interval { seconds: 60, @@ -989,14 +1077,21 @@ mod tests { updated_at: 1_700_000_000_000, }; - let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai) - .unwrap() - .unwrap(); + let job = RuntimeJob::from_record( + &record, + SchedulerMisfirePolicy::Skip, + chrono_tz::Asia::Shanghai, + ) + .unwrap() + .unwrap(); - assert_eq!(job.schedule, SchedulerSchedule::Interval { - seconds: 120, - startup_delay_secs: 10, - }); + assert_eq!( + job.schedule, + SchedulerSchedule::Interval { + seconds: 120, + startup_delay_secs: 10, + } + ); assert_eq!(job.next_fire_at, Some(1_700_000_010_000)); } @@ -1050,7 +1145,10 @@ mod tests { scheduler.process_tick().await.unwrap(); - let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap(); + let saved = store + .get_scheduler_job("massage_reminder") + .unwrap() + .unwrap(); assert!(saved.next_fire_at.is_some()); assert_eq!(saved.run_count, 0); assert_eq!(saved.state, SchedulerJobState::Scheduled); @@ -1080,7 +1178,10 @@ mod tests { assert_eq!(saved.kind, "internal_event"); assert!(saved.enabled); assert_eq!(saved.state, SchedulerJobState::Scheduled); - assert_eq!(saved.payload.get("event").and_then(|value| value.as_str()), Some("memory_maintenance")); + assert_eq!( + saved.payload.get("event").and_then(|value| value.as_str()), + Some("memory_maintenance") + ); assert_eq!( saved.schedule, serde_json::json!({ @@ -1088,7 +1189,13 @@ mod tests { "expression": "0 */4 * * *" }) ); - assert_eq!(saved.payload.get("local_time").and_then(|value| value.as_str()), Some("every_4_hours")); + assert_eq!( + saved + .payload + .get("local_time") + .and_then(|value| value.as_str()), + Some("every_4_hours") + ); assert!(saved.next_fire_at.is_some()); } @@ -1155,7 +1262,10 @@ mod tests { #[test] fn cron_schedule_uses_configured_timezone() { - let now = Utc.with_ymd_and_hms(2026, 4, 23, 18, 0, 0).single().unwrap(); + let now = Utc + .with_ymd_and_hms(2026, 4, 23, 18, 0, 0) + .single() + .unwrap(); let next = compute_next_fire_at( &SchedulerSchedule::Cron { expression: "0 3 * * *".to_string(), @@ -1169,6 +1279,11 @@ mod tests { .unwrap(); let next_utc = ts_millis_to_utc(next).unwrap(); - assert_eq!(next_utc, Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0).single().unwrap()); + assert_eq!( + next_utc, + Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0) + .single() + .unwrap() + ); } -} \ No newline at end of file +} diff --git a/src/skills/mod.rs b/src/skills/mod.rs index a2d79b3..109189a 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -89,7 +89,10 @@ impl SkillRuntime { } pub fn is_empty(&self) -> bool { - self.catalog.read().expect("skills rwlock poisoned").is_empty() + self.catalog + .read() + .expect("skills rwlock poisoned") + .is_empty() } pub fn len(&self) -> usize { @@ -97,31 +100,53 @@ impl SkillRuntime { } pub fn system_index_prompt(&self) -> Option { - self.catalog.read().expect("skills rwlock poisoned").system_index_prompt() + self.catalog + .read() + .expect("skills rwlock poisoned") + .system_index_prompt() } pub fn discovery_event_payload(&self) -> serde_json::Value { - self.catalog.read().expect("skills rwlock poisoned").discovery_event_payload() + self.catalog + .read() + .expect("skills rwlock poisoned") + .discovery_event_payload() } pub fn offered_event_payload(&self) -> serde_json::Value { - self.catalog.read().expect("skills rwlock poisoned").offered_event_payload() + self.catalog + .read() + .expect("skills rwlock poisoned") + .offered_event_payload() } pub fn skill_tool_definition(&self) -> Option { - self.catalog.read().expect("skills rwlock poisoned").skill_tool_definition() + self.catalog + .read() + .expect("skills rwlock poisoned") + .skill_tool_definition() } pub fn activation_payload(&self, name: &str) -> Result { - self.catalog.read().expect("skills rwlock poisoned").activation_payload(name) + self.catalog + .read() + .expect("skills rwlock poisoned") + .activation_payload(name) } pub fn activation_event_payload(&self, name: &str) -> Result { - self.catalog.read().expect("skills rwlock poisoned").activation_event_payload(name) + self.catalog + .read() + .expect("skills rwlock poisoned") + .activation_event_payload(name) } pub fn list_skills(&self) -> Vec { - self.catalog.read().expect("skills rwlock poisoned").skills.clone() + self.catalog + .read() + .expect("skills rwlock poisoned") + .skills + .clone() } pub fn get_skill(&self, name: &str) -> Option { @@ -143,7 +168,11 @@ impl SkillRuntime { validate_skill_name(name)?; let path = skill_file_path(scope, name)?; if path.exists() { - return Err(format!("skill '{}' already exists at {}", name, path.display())); + return Err(format!( + "skill '{}' already exists at {}", + name, + path.display() + )); } write_skill_file(&path, name, description, body)?; @@ -180,14 +209,20 @@ impl SkillRuntime { Ok(skill) } - pub fn delete_skill(&self, scope: SkillScope, name: &str, reload: bool) -> Result { + pub fn delete_skill( + &self, + scope: SkillScope, + name: &str, + reload: bool, + ) -> Result { validate_skill_name(name)?; let dir = skill_dir_path(scope, name)?; if !dir.exists() { return Err(format!("skill '{}' not found at {}", name, dir.display())); } - fs::remove_dir_all(&dir).map_err(|err| format!("failed to delete skill directory: {}", err))?; + fs::remove_dir_all(&dir) + .map_err(|err| format!("failed to delete skill directory: {}", err))?; if reload { let _ = self.reload()?; } @@ -439,7 +474,8 @@ fn validate_skill_name(name: &str) -> Result<(), String> { } pub fn project_skills_root() -> Result { - let cwd = std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?; + let cwd = + std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?; Ok(cwd.join(".picobot").join("skills")) } @@ -466,7 +502,9 @@ fn source_root(source: SkillSource, cwd: &Path) -> Option { fn root_for_scope(scope: SkillScope) -> Result { match scope { - SkillScope::User => user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string()), + SkillScope::User => { + user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string()) + } SkillScope::Project => project_skills_root(), } } @@ -508,7 +546,8 @@ fn render_skill_file(name: &str, description: &str, body: &str) -> Result Result<(), String> { let content = render_skill_file(name, description, body)?; if let Some(parent) = path.parent() { - fs::create_dir_all(parent).map_err(|err| format!("failed to create skill directory: {}", err))?; + fs::create_dir_all(parent) + .map_err(|err| format!("failed to create skill directory: {}", err))?; } fs::write(path, content).map_err(|err| format!("failed to write skill file: {}", err)) } @@ -556,11 +595,10 @@ struct SkillFrontmatter { } fn parse_skill_file(path: &Path, source: SkillSource) -> Result { - let content = fs::read_to_string(path) - .map_err(|e| format!("failed to read file: {}", e))?; + let content = fs::read_to_string(path).map_err(|e| format!("failed to read file: {}", e))?; - let (frontmatter_raw, body) = split_frontmatter(&content) - .ok_or_else(|| "missing YAML frontmatter block".to_string())?; + let (frontmatter_raw, body) = + split_frontmatter(&content).ok_or_else(|| "missing YAML frontmatter block".to_string())?; let frontmatter: SkillFrontmatter = serde_yaml::from_str(frontmatter_raw) .map_err(|e| format!("invalid YAML frontmatter: {}", e))?; @@ -576,11 +614,7 @@ fn parse_skill_file(path: &Path, source: SkillSource) -> Result { .map(|s| s.to_string_lossy().to_string()) .unwrap_or_else(|| "unknown-skill".to_string()); - let name = frontmatter - .name - .unwrap_or(dir_name) - .trim() - .to_string(); + let name = frontmatter.name.unwrap_or(dir_name).trim().to_string(); Ok(Skill { name, @@ -656,7 +690,10 @@ mod tests { ) .unwrap(); - let skills = load_skills_from_root(&dir.path().join(".picobot").join("skills"), SkillSource::Project); + let skills = load_skills_from_root( + &dir.path().join(".picobot").join("skills"), + SkillSource::Project, + ); let catalog = SkillCatalog { skills, max_index_chars: 4000, @@ -707,7 +744,13 @@ mod tests { assert_eq!(runtime.len(), 0); let created = runtime - .create_skill(SkillScope::Project, "demo-skill", "demo desc", "line 1", true) + .create_skill( + SkillScope::Project, + "demo-skill", + "demo desc", + "line 1", + true, + ) .unwrap(); assert_eq!(created.name, "demo-skill"); assert_eq!(runtime.len(), 1); @@ -722,7 +765,12 @@ mod tests { ) .unwrap(); assert_eq!(updated.description, "updated desc"); - assert!(runtime.activation_payload("demo-skill").unwrap().contains("line 2")); + assert!( + runtime + .activation_payload("demo-skill") + .unwrap() + .contains("line 2") + ); let deleted_path = runtime .delete_skill(SkillScope::Project, "demo-skill", true) @@ -759,7 +807,11 @@ mod tests { let temp_dir = tempfile::tempdir().unwrap(); let _guard = CurrentDirGuard::enter(temp_dir.path()); - let agent_skill_dir = temp_dir.path().join(".agents").join("skills").join("demo-agent"); + let agent_skill_dir = temp_dir + .path() + .join(".agents") + .join("skills") + .join("demo-agent"); fs::create_dir_all(&agent_skill_dir).unwrap(); fs::write( agent_skill_dir.join("SKILL.md"), diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 0a0c288..2bdde00 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -391,7 +391,8 @@ impl SessionStore { )?; drop(conn); - self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) + self.get_session(&id)? + .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } pub fn ensure_channel_session( @@ -419,7 +420,8 @@ impl SessionStore { )?; drop(conn); - self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) + self.get_session(&session_id)? + .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } pub fn get_session(&self, session_id: &str) -> Result, StorageError> { @@ -495,7 +497,10 @@ impl SessionStore { pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); - conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?; + conn.execute( + "DELETE FROM messages WHERE session_id = ?1", + params![session_id], + )?; conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?; Ok(()) } @@ -503,7 +508,10 @@ impl SessionStore { pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> { let now = current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); - conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?; + conn.execute( + "DELETE FROM messages WHERE session_id = ?1", + params![session_id], + )?; conn.execute( " UPDATE sessions @@ -549,7 +557,11 @@ impl SessionStore { Ok(()) } - pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> { + pub fn append_message( + &self, + session_id: &str, + message: &ChatMessage, + ) -> Result<(), StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let tx = conn.unchecked_transaction()?; @@ -560,7 +572,11 @@ impl SessionStore { )?; let media_refs_json = serde_json::to_string(&message.media_refs)?; - let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?; + let tool_calls_json = message + .tool_calls + .as_ref() + .map(serde_json::to_string) + .transpose()?; tx.execute( " INSERT INTO messages ( @@ -630,7 +646,8 @@ impl SessionStore { return Ok(false); } - let delta_messages = load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?; + let delta_messages = + load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?; let mut next_seq = current_max_seq + 1; let now = current_timestamp(); let mut inserted_count = 0_i64; @@ -782,8 +799,7 @@ impl SessionStore { ) .optional()?; - let (id, created_at) = existing - .unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now)); + let (id, created_at) = existing.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now)); tx.execute( " @@ -881,7 +897,10 @@ impl SessionStore { LIMIT ?4 ", )?; - let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?; + let rows = stmt.query_map( + params![scope_kind, scope_key, namespace, limit], + map_memory_record, + )?; for row in rows { memories.push(row?); } @@ -940,7 +959,9 @@ impl SessionStore { ", )?; - let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| row.get::<_, String>(0))?; + let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| { + row.get::<_, String>(0) + })?; let mut scope_keys = Vec::new(); for row in rows { scope_keys.push(row?); @@ -1010,7 +1031,10 @@ impl SessionStore { Ok(changed > 0) } - pub fn upsert_scheduler_job(&self, input: &SchedulerJobUpsert) -> Result { + pub fn upsert_scheduler_job( + &self, + input: &SchedulerJobUpsert, + ) -> Result { let now = current_timestamp(); let conn = self.conn.lock().expect("session db mutex poisoned"); conn.execute( @@ -1067,7 +1091,10 @@ impl SessionStore { .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into()) } - pub fn get_scheduler_job(&self, job_id: &str) -> Result, StorageError> { + pub fn get_scheduler_job( + &self, + job_id: &str, + ) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let mut stmt = conn.prepare( " @@ -1085,7 +1112,10 @@ impl SessionStore { .map_err(StorageError::from) } - pub fn list_scheduler_jobs(&self, enabled_only: bool) -> Result, StorageError> { + pub fn list_scheduler_jobs( + &self, + enabled_only: bool, + ) -> Result, StorageError> { let conn = self.conn.lock().expect("session db mutex poisoned"); let sql = if enabled_only { " @@ -1195,7 +1225,10 @@ impl SessionStore { LIMIT ?5 ", )?; - let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?; + let rows = stmt.query_map( + params![query, scope_kind, scope_key, namespace, limit], + map_memory_record, + )?; for row in rows { memories.push(row?); } @@ -1214,7 +1247,10 @@ impl SessionStore { LIMIT ?4 ", )?; - let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?; + let rows = stmt.query_map( + params![query, scope_kind, scope_key, limit], + map_memory_record, + )?; for row in rows { memories.push(row?); } @@ -1256,7 +1292,10 @@ impl SessionStore { LIMIT ?5 ", )?; - let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?; + let rows = stmt.query_map( + params![query, scope_kind, scope_key, namespace, limit], + map_memory_record, + )?; for row in rows { memories.push(row?); } @@ -1275,7 +1314,10 @@ impl SessionStore { LIMIT ?4 ", )?; - let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?; + let rows = stmt.query_map( + params![query, scope_kind, scope_key, limit], + map_memory_record, + )?; for row in rows { memories.push(row?); } @@ -1347,11 +1389,7 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result) -> rusqlite::Result { let payload_json: String = row.get(4)?; let payload = serde_json::from_str(&payload_json).map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - 4, - rusqlite::types::Type::Text, - Box::new(err), - ) + rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(err)) })?; Ok(SkillEventRecord { @@ -1391,25 +1429,13 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result = row.get(9)?; let schedule = serde_json::from_str(&schedule_json).map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - 2, - rusqlite::types::Type::Text, - Box::new(err), - ) + rusqlite::Error::FromSqlConversionFailure(2, rusqlite::types::Type::Text, Box::new(err)) })?; let target = serde_json::from_str(&target_json).map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - 5, - rusqlite::types::Type::Text, - Box::new(err), - ) + rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(err)) })?; let payload = serde_json::from_str(&payload_json).map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - 6, - rusqlite::types::Type::Text, - Box::new(err), - ) + rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(err)) })?; Ok(SchedulerJobRecord { @@ -1472,7 +1498,10 @@ fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> { } if !has_column(conn, "messages", "reasoning_content")? { - add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN reasoning_content TEXT")?; + add_column_if_missing( + conn, + "ALTER TABLE messages ADD COLUMN reasoning_content TEXT", + )?; } Ok(()) @@ -1494,17 +1523,11 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> { } if !has_column(conn, "scheduler_jobs", "last_status")? { - conn.execute( - "ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT", - [], - )?; + conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT", [])?; } if !has_column(conn, "scheduler_jobs", "last_error")? { - conn.execute( - "ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT", - [], - )?; + conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT", [])?; } if !has_column(conn, "scheduler_jobs", "run_count")? { @@ -1515,10 +1538,7 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> { } if !has_column(conn, "scheduler_jobs", "max_runs")? { - conn.execute( - "ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER", - [], - )?; + conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER", [])?; } if !has_column(conn, "scheduler_jobs", "paused_at")? { @@ -1538,7 +1558,11 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> { Ok(()) } -fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result { +fn has_column( + conn: &Connection, + table_name: &str, + column_name: &str, +) -> Result { let pragma = format!("PRAGMA table_info({})", table_name); let mut stmt = conn.prepare(&pragma)?; let mut rows = stmt.query([])?; @@ -1557,7 +1581,10 @@ fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageErro match conn.execute(sql, []) { Ok(_) => Ok(()), Err(rusqlite::Error::SqliteFailure(_, Some(message))) - if message.contains("duplicate column name") => Ok(()), + if message.contains("duplicate column name") => + { + Ok(()) + } Err(error) => Err(StorageError::Database(error)), } } @@ -1581,7 +1608,11 @@ fn insert_message_with_seq( message: &ChatMessage, ) -> Result<(), StorageError> { let media_refs_json = serde_json::to_string(&message.media_refs)?; - let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?; + let tool_calls_json = message + .tool_calls + .as_ref() + .map(serde_json::to_string) + .transpose()?; conn.execute( " INSERT INTO messages ( @@ -1638,43 +1669,47 @@ fn load_messages_between( ", )?; - let rows = stmt.query_map(params![session_id, start_seq_exclusive, end_seq_inclusive], |row| { - let media_refs_json: String = row.get(5)?; - let media_refs: Vec = serde_json::from_str(&media_refs_json).map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - media_refs_json.len(), - rusqlite::types::Type::Text, - Box::new(err), - ) - })?; + let rows = stmt.query_map( + params![session_id, start_seq_exclusive, end_seq_inclusive], + |row| { + let media_refs_json: String = row.get(5)?; + let media_refs: Vec = + serde_json::from_str(&media_refs_json).map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + media_refs_json.len(), + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; - let tool_calls_json: Option = row.get(9)?; - let tool_calls = tool_calls_json - .as_deref() - .map(serde_json::from_str) - .transpose() - .map_err(|err| { - rusqlite::Error::FromSqlConversionFailure( - 9, - rusqlite::types::Type::Text, - Box::new(err), - ) - })?; + let tool_calls_json: Option = row.get(9)?; + let tool_calls = tool_calls_json + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + 9, + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; - Ok(ChatMessage { - id: row.get(0)?, - role: row.get(1)?, - content: row.get(2)?, - system_context: row.get(3)?, - reasoning_content: row.get(4)?, - media_refs, - timestamp: row.get(6)?, - tool_call_id: row.get(7)?, - tool_name: row.get(8)?, - tool_state: None, - tool_calls, - }) - })?; + Ok(ChatMessage { + id: row.get(0)?, + role: row.get(1)?, + content: row.get(2)?, + system_context: row.get(3)?, + reasoning_content: row.get(4)?, + media_refs, + timestamp: row.get(6)?, + tool_call_id: row.get(7)?, + tool_name: row.get(8)?, + tool_state: None, + tool_calls, + }) + }, + )?; let mut messages = Vec::new(); for row in rows { @@ -1866,7 +1901,10 @@ mod tests { assert_eq!(messages[0].role, "assistant"); assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1); assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1"); - assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator"); + assert_eq!( + messages[0].tool_calls.as_ref().unwrap()[0].name, + "calculator" + ); } #[test] @@ -1874,17 +1912,17 @@ mod tests { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("reasoning")).unwrap(); - let assistant = ChatMessage::assistant_with_reasoning( - "final answer", - "hidden reasoning", - ); + let assistant = ChatMessage::assistant_with_reasoning("final answer", "hidden reasoning"); store.append_message(&session.id, &assistant).unwrap(); let messages = store.load_messages(&session.id).unwrap(); assert_eq!(messages.len(), 1); assert_eq!(messages[0].content, "final answer"); - assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning")); + assert_eq!( + messages[0].reasoning_content.as_deref(), + Some("hidden reasoning") + ); } #[test] @@ -1892,8 +1930,12 @@ mod tests { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("reset")).unwrap(); - store.append_message(&session.id, &ChatMessage::user("before")).unwrap(); - store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap(); + store + .append_message(&session.id, &ChatMessage::user("before")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::assistant("context")) + .unwrap(); store.reset_session(&session.id).unwrap(); let stored = store.get_session(&session.id).unwrap().unwrap(); @@ -1909,7 +1951,9 @@ mod tests { assert_eq!(all_messages[0].content, "before"); assert_eq!(all_messages[1].content, "context"); - store.append_message(&session.id, &ChatMessage::user("after")).unwrap(); + store + .append_message(&session.id, &ChatMessage::user("after")) + .unwrap(); let active_messages = store.load_messages(&session.id).unwrap(); assert_eq!(active_messages.len(), 1); assert_eq!(active_messages[0].content, "after"); @@ -2010,19 +2054,33 @@ mod tests { let store = SessionStore::in_memory().unwrap(); let session = store.create_cli_session(Some("count-users")).unwrap(); - store.append_message(&session.id, &ChatMessage::system("agent")).unwrap(); - store.append_message(&session.id, &ChatMessage::user("u1")).unwrap(); - store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap(); - store.append_message(&session.id, &ChatMessage::user("u2")).unwrap(); + store + .append_message(&session.id, &ChatMessage::system("agent")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::user("u1")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::assistant("a1")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::user("u2")) + .unwrap(); assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2); store.reset_session(&session.id).unwrap(); assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0); - store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap(); - store.append_message(&session.id, &ChatMessage::user("u3")).unwrap(); - store.append_message(&session.id, &ChatMessage::user("u4")).unwrap(); + store + .append_message(&session.id, &ChatMessage::system("agent-again")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::user("u3")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::user("u4")) + .unwrap(); assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2); } @@ -2052,12 +2110,20 @@ mod tests { store.append_message(&session.id, message).unwrap(); } - let snapshot_end_seq = store.get_session(&session.id).unwrap().unwrap().message_count; + let snapshot_end_seq = store + .get_session(&session.id) + .unwrap() + .unwrap() + .message_count; let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec(); let preserved_system_messages = vec![agent_prompt]; - store.append_message(&session.id, &ChatMessage::user("u5")).unwrap(); - store.append_message(&session.id, &ChatMessage::assistant("a5")).unwrap(); + store + .append_message(&session.id, &ChatMessage::user("u5")) + .unwrap(); + store + .append_message(&session.id, &ChatMessage::assistant("a5")) + .unwrap(); let summary_message = ChatMessage::system("[Compressed History]\n\nsummary"); let compacted = store @@ -2074,16 +2140,22 @@ mod tests { assert!(compacted); let active_messages = store.load_messages(&session.id).unwrap(); - assert_eq!(active_messages.len(), 10); - assert_eq!(active_messages[0].role, "system"); - assert_eq!(active_messages[0].content, "agent"); - assert_eq!(active_messages[0].system_context.as_deref(), Some(SYSTEM_CONTEXT_AGENT_PROMPT)); - assert_eq!(active_messages[1].role, "system"); - assert_eq!(active_messages[1].content, "[Compressed History]\n\nsummary"); - assert_eq!(active_messages[2].content, "u2"); - assert_eq!(active_messages[3].content, "a2"); - assert_eq!(active_messages[8].content, "u5"); - assert_eq!(active_messages[9].content, "a5"); + assert_eq!(active_messages.len(), 10); + assert_eq!(active_messages[0].role, "system"); + assert_eq!(active_messages[0].content, "agent"); + assert_eq!( + active_messages[0].system_context.as_deref(), + Some(SYSTEM_CONTEXT_AGENT_PROMPT) + ); + assert_eq!(active_messages[1].role, "system"); + assert_eq!( + active_messages[1].content, + "[Compressed History]\n\nsummary" + ); + assert_eq!(active_messages[2].content, "u2"); + assert_eq!(active_messages[3].content, "a2"); + assert_eq!(active_messages[8].content, "u5"); + assert_eq!(active_messages[9].content, "a5"); let stored = store.get_session(&session.id).unwrap().unwrap(); assert_eq!(stored.reset_cutoff_seq, 11); @@ -2128,12 +2200,7 @@ mod tests { let session = store.create_cli_session(Some("skill-events")).unwrap(); store - .append_skill_event( - None, - "discovered", - None, - &serde_json::json!({"count": 2}), - ) + .append_skill_event(None, "discovered", None, &serde_json::json!({"count": 2})) .unwrap(); store .append_skill_event( @@ -2383,13 +2450,26 @@ mod tests { .unwrap(); let scope_keys = store.list_memory_scope_keys("user").unwrap(); - assert_eq!(scope_keys, vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]); + assert_eq!( + scope_keys, + vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()] + ); - let full_scope = store.list_memories_for_scope("user", "feishu:user-1").unwrap(); + let full_scope = store + .list_memories_for_scope("user", "feishu:user-1") + .unwrap(); assert_eq!(full_scope.len(), 2); - assert!(full_scope.iter().all(|memory| memory.scope_key == "feishu:user-1")); + assert!( + full_scope + .iter() + .all(|memory| memory.scope_key == "feishu:user-1") + ); assert!(full_scope.iter().any(|memory| memory.memory_key == "work")); - assert!(full_scope.iter().any(|memory| memory.memory_key == "workflow")); + assert!( + full_scope + .iter() + .any(|memory| memory.memory_key == "workflow") + ); } #[test] @@ -2499,4 +2579,4 @@ mod tests { assert_eq!(recent_scope_keys, vec!["feishu:user-2".to_string()]); } -} \ No newline at end of file +} diff --git a/src/text.rs b/src/text.rs index 519e29e..a252ad4 100644 --- a/src/text.rs +++ b/src/text.rs @@ -17,4 +17,4 @@ pub fn truncate_with_ellipsis(text: &str, max_chars: usize) -> String { } format!("{}...", take_prefix_chars(text, max_chars)) -} \ No newline at end of file +} diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 3354e5f..6ce0424 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -15,7 +15,8 @@ use crate::tools::traits::{Tool, ToolResult}; const MAX_TIMEOUT_SECS: u64 = 600; const MAX_OUTPUT_CHARS: usize = 50_000; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; -const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; +const USER_ACTION_HINT: &str = + "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; pub struct BashTool { timeout_secs: u64, @@ -208,7 +209,10 @@ impl Tool for BashTool { .map(|d| Path::new(d)) .unwrap_or_else(|| Path::new(".")); - match self.run_command(command, cwd, timeout_secs, interactive).await { + match self + .run_command(command, cwd, timeout_secs, interactive) + .await + { Ok(output) => Ok(ToolResult { success: true, output, @@ -366,10 +370,7 @@ mod tests { #[tokio::test] async fn test_pwd_command() { let tool = BashTool::new(); - let result = tool - .execute(json!({ "command": "pwd" })) - .await - .unwrap(); + let result = tool.execute(json!({ "command": "pwd" })).await.unwrap(); assert!(result.success); } @@ -377,7 +378,10 @@ mod tests { #[tokio::test] async fn test_ls_command() { let tool = BashTool::new(); - let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap(); + let result = tool + .execute(json!({ "command": "ls -la /tmp" })) + .await + .unwrap(); assert!(result.success); } diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index 8f033dc..2b42e8d 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -659,10 +659,7 @@ mod tests { #[tokio::test] async fn test_evaluate_missing_expression() { let tool = CalculatorTool::new(); - let result = tool - .execute(json!({"function": "evaluate"})) - .await - .unwrap(); + let result = tool.execute(json!({"function": "evaluate"})).await.unwrap(); assert!(!result.success); } diff --git a/src/tools/file_edit.rs b/src/tools/file_edit.rs index 78ab2f4..a097988 100644 --- a/src/tools/file_edit.rs +++ b/src/tools/file_edit.rs @@ -268,8 +268,8 @@ impl Tool for FileEditTool { #[cfg(test)] mod tests { use super::*; - use tempfile::NamedTempFile; use std::io::Write; + use tempfile::NamedTempFile; #[tokio::test] async fn test_edit_simple() { diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index 10725f1..c3e0baf 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -218,7 +218,7 @@ impl Tool for FileReadTool { // Try to read as binary and encode as base64 match std::fs::read(&resolved) { Ok(bytes) => { - use base64::{engine::general_purpose::STANDARD, Engine}; + use base64::{Engine, engine::general_purpose::STANDARD}; let encoded = STANDARD.encode(&bytes); let mime = mime_guess::from_path(&resolved) .first_or_octet_stream() @@ -248,8 +248,8 @@ impl Tool for FileReadTool { #[cfg(test)] mod tests { use super::*; - use tempfile::NamedTempFile; use std::io::Write; + use tempfile::NamedTempFile; #[tokio::test] async fn test_read_simple_file() { @@ -308,10 +308,7 @@ mod tests { #[tokio::test] async fn test_is_directory() { let tool = FileReadTool::new(); - let result = tool - .execute(json!({ "path": "." })) - .await - .unwrap(); + let result = tool.execute(json!({ "path": "." })).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("Not a file")); diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index 3472c70..a3da100 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -195,10 +195,7 @@ mod tests { #[tokio::test] async fn test_write_missing_path() { let tool = FileWriteTool::new(); - let result = tool - .execute(json!({ "content": "Hello" })) - .await - .unwrap(); + let result = tool.execute(json!({ "content": "Hello" })).await.unwrap(); assert!(!result.success); assert!(result.error.unwrap().contains("path")); diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index cc70cce..0c05dd0 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -50,10 +50,7 @@ impl HttpRequestTool { } if !host_matches_allowlist(&host, &self.allowed_domains) { - return Err(format!( - "Host '{}' is not in allowed_domains", - host - )); + return Err(format!("Host '{}' is not in allowed_domains", host)); } Ok(url.to_string()) @@ -80,9 +77,7 @@ impl HttpRequestTool { for (key, value) in obj { if let Some(str_val) = value.as_str() { if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) { - if let Ok(val) = - reqwest::header::HeaderValue::from_str(str_val) - { + if let Ok(val) = reqwest::header::HeaderValue::from_str(str_val) { header_map.insert(name, val); } } @@ -193,7 +188,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { allowed_domains.iter().any(|domain| { host == domain - || host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.')) + || host + .strip_suffix(domain) + .is_some_and(|prefix| prefix.ends_with('.')) }) } @@ -204,7 +201,11 @@ fn is_private_host(host: &str) -> bool { } // Check .local TLD - if host.rsplit('.').next().is_some_and(|label| label == "local") { + if host + .rsplit('.') + .next() + .is_some_and(|label| label == "local") + { return true; } @@ -226,9 +227,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool { || v4.is_broadcast() || v4.is_multicast() } - std::net::IpAddr::V6(v6) => { - v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() - } + std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), } } @@ -280,10 +279,7 @@ impl Tool for HttpRequestTool { } }; - let method_str = args - .get("method") - .and_then(|v| v.as_str()) - .unwrap_or("GET"); + let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET"); let headers_val = args.get("headers").cloned().unwrap_or(json!({})); let body = args.get("body").and_then(|v| v.as_str()); diff --git a/src/tools/memory_manage.rs b/src/tools/memory_manage.rs index 4df41cd..bb29b8c 100644 --- a/src/tools/memory_manage.rs +++ b/src/tools/memory_manage.rs @@ -94,7 +94,7 @@ impl Tool for MemoryManageTool { return Ok(error_result(&format!( "memory '{}.{}' not found", input.namespace, input.memory_key - ))) + ))); } } } @@ -108,9 +108,14 @@ impl Tool for MemoryManageTool { None => return Ok(error_result("Missing required parameter: key")), }; - let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?; + let deleted = self + .store + .delete_memory("user", &scope_key, namespace, key)?; if !deleted { - return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))); + return Ok(error_result(&format!( + "memory '{}.{}' not found", + namespace, key + ))); } json!({ @@ -289,4 +294,4 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("Unsupported action")); } -} \ No newline at end of file +} diff --git a/src/tools/memory_search.rs b/src/tools/memory_search.rs index 68eac82..b7f268a 100644 --- a/src/tools/memory_search.rs +++ b/src/tools/memory_search.rs @@ -90,7 +90,9 @@ impl Tool for MemorySearchTool { .get("limit") .and_then(|value| value.as_u64()) .unwrap_or(10) as usize; - let memories = self.store.list_memories("user", &scope_key, namespace, limit)?; + let memories = self + .store + .list_memories("user", &scope_key, namespace, limit)?; json!({ "count": memories.len(), "memories": memories.into_iter().map(memory_to_json).collect::>() @@ -135,7 +137,12 @@ impl Tool for MemorySearchTool { match self.store.get_memory("user", &scope_key, namespace, key)? { Some(memory) => memory_to_json(memory), - None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))), + None => { + return Ok(error_result(&format!( + "memory '{}.{}' not found", + namespace, key + ))); + } } } _ => return Ok(error_result("Unsupported action")), @@ -286,4 +293,4 @@ mod tests { assert!(!result.success); assert!(result.error.unwrap().contains("queries")); } -} \ No newline at end of file +} diff --git a/src/tools/scheduler_manage.rs b/src/tools/scheduler_manage.rs index e5b3c10..aaa63d7 100644 --- a/src/tools/scheduler_manage.rs +++ b/src/tools/scheduler_manage.rs @@ -5,9 +5,7 @@ use async_trait::async_trait; use serde_json::json; use crate::config::SchedulerSchedule; -use crate::storage::{ - SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore, -}; +use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore}; use crate::tools::traits::{Tool, ToolResult}; pub struct SchedulerManageTool { @@ -35,11 +33,7 @@ impl Tool for SchedulerManageTool { } fn parameters_schema(&self) -> serde_json::Value { - let mut allowed_agents = self - .known_agents - .iter() - .cloned() - .collect::>(); + let mut allowed_agents = self.known_agents.iter().cloned().collect::>(); allowed_agents.sort(); let agent_hint = if allowed_agents.is_empty() { "agent_task payload.agent may be omitted or set to 'default'.".to_string() @@ -225,8 +219,15 @@ fn build_upsert( startup_delay_secs, target, payload, - enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true), - state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) { + enabled: args + .get("enabled") + .and_then(|value| value.as_bool()) + .unwrap_or(true), + state: if args + .get("enabled") + .and_then(|value| value.as_bool()) + .unwrap_or(true) + { SchedulerJobState::Scheduled } else { SchedulerJobState::Paused @@ -252,14 +253,28 @@ fn enrich_target_from_context( }; if !has_non_empty_string(&object, "channel") { - if let Some(channel_name) = context.channel_name.as_ref().filter(|value| !value.trim().is_empty()) { - object.insert("channel".to_string(), serde_json::Value::String(channel_name.clone())); + if let Some(channel_name) = context + .channel_name + .as_ref() + .filter(|value| !value.trim().is_empty()) + { + object.insert( + "channel".to_string(), + serde_json::Value::String(channel_name.clone()), + ); } } if !has_non_empty_string(&object, "chat_id") { - if let Some(chat_id) = context.chat_id.as_ref().filter(|value| !value.trim().is_empty()) { - object.insert("chat_id".to_string(), serde_json::Value::String(chat_id.clone())); + if let Some(chat_id) = context + .chat_id + .as_ref() + .filter(|value| !value.trim().is_empty()) + { + object.insert( + "chat_id".to_string(), + serde_json::Value::String(chat_id.clone()), + ); } } @@ -274,7 +289,10 @@ fn has_non_empty_string(object: &serde_json::Map, fie .unwrap_or(false) } -fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet) -> anyhow::Result<()> { +fn validate_agent_task_payload( + payload: &serde_json::Value, + known_agents: &HashSet, +) -> anyhow::Result<()> { let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else { anyhow::bail!("agent_task payload.prompt is required and must be a string") }; @@ -299,7 +317,8 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet) -> St configured_agents.sort(); let configured_hint = if configured_agents.is_empty() { - "No named agents are configured; use payload.agent='default' or omit payload.agent.".to_string() + "No named agents are configured; use payload.agent='default' or omit payload.agent." + .to_string() } else { format!( "payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.", @@ -309,9 +328,7 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet) -> St format!( "Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.", - agent_name, - configured_hint, - agent_name, + agent_name, configured_hint, agent_name, ) } @@ -517,7 +534,10 @@ mod tests { .unwrap() .unwrap(); assert_eq!(saved.kind, "silent_agent_task"); - assert_eq!(saved.target["session_chat_id"], "scheduler/agent.daily_summary.background"); + assert_eq!( + saved.target["session_chat_id"], + "scheduler/agent.daily_summary.background" + ); } #[tokio::test] @@ -654,7 +674,9 @@ mod tests { assert!(!result.success); assert_eq!( result.error.as_deref(), - Some("Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}") + Some( + "Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}" + ) ); } @@ -668,8 +690,10 @@ mod tests { .as_str() .unwrap(); - assert!(payload_description.contains("avoid repeating schedule phrases or execution times")); + assert!( + payload_description.contains("avoid repeating schedule phrases or execution times") + ); assert!(payload_description.contains("每天9点")); assert!(payload_description.contains("每小时")); } -} \ No newline at end of file +} diff --git a/src/tools/schema.rs b/src/tools/schema.rs index 91bca47..89649f1 100644 --- a/src/tools/schema.rs +++ b/src/tools/schema.rs @@ -408,7 +408,10 @@ impl SchemaCleanr { match non_null.len() { 0 => Value::String("null".to_string()), - 1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())), + 1 => non_null + .into_iter() + .next() + .unwrap_or(Value::String("null".to_string())), _ => Value::Array(non_null), } } else { diff --git a/src/tools/skill_manage.rs b/src/tools/skill_manage.rs index 0b162f5..e5a93b3 100644 --- a/src/tools/skill_manage.rs +++ b/src/tools/skill_manage.rs @@ -83,7 +83,11 @@ impl Tool for SkillManageTool { let scope = match args.get("scope").and_then(|v| v.as_str()) { Some(value) => match SkillScope::parse(value) { Some(scope) => scope, - None => return Ok(error_result("scope must be 'project' or 'user'; .agents sources are discovery-only")), + None => { + return Ok(error_result( + "scope must be 'project' or 'user'; .agents sources are discovery-only", + )); + } }, None => SkillScope::Project, }; @@ -91,9 +95,7 @@ impl Tool for SkillManageTool { let name = args.get("name").and_then(|v| v.as_str()); let result = match action { - "list" => { - list_skills_payload(&self.skills) - } + "list" => list_skills_payload(&self.skills), "get" => { let name = match name { Some(name) => name, @@ -127,7 +129,10 @@ impl Tool for SkillManageTool { }; let body = args.get("body").and_then(|v| v.as_str()).unwrap_or(""); - match self.skills.create_skill(scope, name, description, body, reload) { + match self + .skills + .create_skill(scope, name, description, body, reload) + { Ok(skill) => json!({ "status": "created", "name": skill.name, @@ -149,7 +154,10 @@ impl Tool for SkillManageTool { return Ok(error_result("update requires description or body")); } - match self.skills.update_skill(scope, name, description, body, reload) { + match self + .skills + .update_skill(scope, name, description, body, reload) + { Ok(skill) => json!({ "status": "updated", "name": skill.name, diff --git a/src/tools/time.rs b/src/tools/time.rs index a24f9ae..da1c7a4 100644 --- a/src/tools/time.rs +++ b/src/tools/time.rs @@ -99,9 +99,7 @@ fn execute_time_request( .and_then(Value::as_str) .unwrap_or(default_timezone); let timezone = timezone_name.parse::().map_err(|_| { - format!( - "Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai" - ) + format!("Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai") })?; let now_local = now_utc.with_timezone(&timezone); @@ -168,13 +166,14 @@ fn parse_offset_request(args: &Value) -> Result, String> { let direction = direction.ok_or_else(|| { "Missing required parameter: direction when requesting a relative time".to_string() })?; - let amount = amount - .and_then(Value::as_u64) - .ok_or_else(|| "Missing required parameter: amount when requesting a relative time".to_string())?; + let amount = amount.and_then(Value::as_u64).ok_or_else(|| { + "Missing required parameter: amount when requesting a relative time".to_string() + })?; let amount = u32::try_from(amount) .map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?; - let unit = unit - .ok_or_else(|| "Missing required parameter: unit when requesting a relative time".to_string())?; + let unit = unit.ok_or_else(|| { + "Missing required parameter: unit when requesting a relative time".to_string() + })?; Ok(Some(OffsetRequest { direction: OffsetDirection::parse(direction)?, @@ -188,10 +187,18 @@ fn apply_offset( offset: &OffsetRequest, ) -> Result, String> { match (offset.direction, offset.unit) { - (OffsetDirection::Future, TimeUnit::Minute) => Ok(now_local + Duration::minutes(i64::from(offset.amount))), - (OffsetDirection::Past, TimeUnit::Minute) => Ok(now_local - Duration::minutes(i64::from(offset.amount))), - (OffsetDirection::Future, TimeUnit::Hour) => Ok(now_local + Duration::hours(i64::from(offset.amount))), - (OffsetDirection::Past, TimeUnit::Hour) => Ok(now_local - Duration::hours(i64::from(offset.amount))), + (OffsetDirection::Future, TimeUnit::Minute) => { + Ok(now_local + Duration::minutes(i64::from(offset.amount))) + } + (OffsetDirection::Past, TimeUnit::Minute) => { + Ok(now_local - Duration::minutes(i64::from(offset.amount))) + } + (OffsetDirection::Future, TimeUnit::Hour) => { + Ok(now_local + Duration::hours(i64::from(offset.amount))) + } + (OffsetDirection::Past, TimeUnit::Hour) => { + Ok(now_local - Duration::hours(i64::from(offset.amount))) + } (OffsetDirection::Future, TimeUnit::Day) => now_local .checked_add_days(Days::new(u64::from(offset.amount))) .ok_or_else(|| "Failed to add days to the current time".to_string()), @@ -439,9 +446,9 @@ mod tests { assert!(result.success); let payload: Value = serde_json::from_str(&result.output).unwrap(); - let result_time = chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap()) - .unwrap(); + let result_time = + chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap()).unwrap(); assert_eq!(result_time.hour(), 12); assert_eq!(result_time.minute(), 30); } -} \ No newline at end of file +} diff --git a/src/tools/web_fetch.rs b/src/tools/web_fetch.rs index f53fef8..16adb87 100644 --- a/src/tools/web_fetch.rs +++ b/src/tools/web_fetch.rs @@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool { return true; } - if host.rsplit('.').next().is_some_and(|label| label == "local") { + if host + .rsplit('.') + .next() + .is_some_and(|label| label == "local") + { return true; } @@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool { std::net::IpAddr::V4(v4) => { v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() } - std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), + std::net::IpAddr::V6(v6) => { + v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() + } }; } diff --git a/tests/test_integration.rs b/tests/test_integration.rs index ee9858a..f428a1c 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; -use picobot::providers::{create_provider, ChatCompletionRequest, Message}; use picobot::config::{Config, LLMProviderConfig}; +use picobot::providers::{ChatCompletionRequest, Message, create_provider}; +use std::collections::HashMap; fn load_config() -> Option { dotenv::from_filename("tests/test.env").ok()?; @@ -44,8 +44,7 @@ fn create_request(content: &str) -> ChatCompletionRequest { #[tokio::test] #[ignore] async fn test_openai_simple_completion() { - let config = load_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); let response = provider.chat(create_request("Say 'ok'")).await.unwrap(); @@ -59,8 +58,7 @@ async fn test_openai_simple_completion() { #[tokio::test] #[ignore] async fn test_openai_conversation() { - let config = load_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); @@ -84,7 +82,9 @@ async fn test_openai_conversation() { async fn test_config_load() { // Test that config.json can be loaded and provider config created let config = Config::load("config.json").expect("Failed to load config.json"); - let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); + let provider_config = config + .get_provider_config("default") + .expect("Failed to get provider config"); assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.name, "aliyun"); diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index eb6ea16..de665d5 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -1,5 +1,5 @@ -use picobot::providers::{ChatCompletionRequest, Message}; use picobot::protocol::{SessionSummary, WsInbound, WsOutbound}; +use picobot::providers::{ChatCompletionRequest, Message}; /// Test that message with special characters is properly escaped #[test] @@ -19,7 +19,9 @@ fn test_message_special_characters() { #[test] fn test_multiline_system_prompt() { let messages = vec![ - Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"), + Message::system( + "You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate", + ), Message::user("Hi"), ]; @@ -33,10 +35,7 @@ fn test_multiline_system_prompt() { #[test] fn test_chat_request_serialization() { let request = ChatCompletionRequest { - messages: vec![ - Message::system("You are helpful"), - Message::user("Hello"), - ], + messages: vec![Message::system("You are helpful"), Message::user("Hello")], temperature: Some(0.7), max_tokens: Some(100), tools: None, @@ -136,7 +135,12 @@ fn test_tool_call_outbound_serialization() { let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); match decoded { - WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => { + WsOutbound::ToolCall { + tool_call_id, + tool_name, + arguments, + .. + } => { assert_eq!(tool_call_id, "call-1"); assert_eq!(tool_name, "calculator"); assert_eq!(arguments["expression"], "1 + 1"); @@ -161,7 +165,12 @@ fn test_tool_result_outbound_serialization() { let decoded: WsOutbound = serde_json::from_str(&json).unwrap(); match decoded { - WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => { + WsOutbound::ToolResult { + tool_call_id, + tool_name, + content, + .. + } => { assert_eq!(tool_call_id, "call-1"); assert_eq!(tool_name, "calculator"); assert!(content.contains('2')); diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 5fec8ce..961832e 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; -use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction}; use picobot::config::LLMProviderConfig; +use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider}; +use std::collections::HashMap; fn load_openai_config() -> Option { dotenv::from_filename("tests/test.env").ok()?; @@ -55,8 +55,7 @@ fn make_weather_tool() -> Tool { #[tokio::test] #[ignore] async fn test_openai_tool_call() { - let config = load_openai_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); @@ -70,7 +69,11 @@ async fn test_openai_tool_call() { let response = provider.chat(request).await.unwrap(); // Should have tool calls - assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content); + assert!( + !response.tool_calls.is_empty(), + "Expected tool call, got: {}", + response.content + ); let tool_call = &response.tool_calls[0]; assert_eq!(tool_call.name, "get_weather"); @@ -80,8 +83,7 @@ async fn test_openai_tool_call() { #[tokio::test] #[ignore] async fn test_openai_tool_call_with_manual_execution() { - let config = load_openai_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider"); @@ -94,8 +96,7 @@ async fn test_openai_tool_call_with_manual_execution() { }; let response1 = provider.chat(request1).await.unwrap(); - let tool_call = response1.tool_calls.first() - .expect("Expected tool call"); + let tool_call = response1.tool_calls.first().expect("Expected tool call"); assert_eq!(tool_call.name, "get_weather"); // Second request with tool result @@ -118,8 +119,7 @@ async fn test_openai_tool_call_with_manual_execution() { #[tokio::test] #[ignore] async fn test_openai_no_tool_when_not_provided() { - let config = load_openai_config() - .expect("Please configure tests/test.env with valid API keys"); + let config = load_openai_config().expect("Please configure tests/test.env with valid API keys"); let provider = create_provider(config).expect("Failed to create provider");