Format codebase with rustfmt

This commit is contained in:
xiaoski 2026-06-15 23:47:24 +08:00
parent c6f4392e63
commit 8f4ee79d8d
65 changed files with 1807 additions and 1061 deletions

View File

@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt;
use crate::bus::message::ContentBlock;
use crate::bus::{ChatMessage, MediaRef};
use crate::config::LLMProviderConfig;
use crate::observability::{
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
};
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
use crate::tools::ToolRegistry;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
@ -256,7 +254,10 @@ impl AgentLoop {
}
/// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
pub fn with_tools(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone();
@ -279,7 +280,13 @@ impl AgentLoop {
}
/// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
pub fn with_provider(
provider: Arc<dyn LLMProvider>,
max_iterations: usize,
model_name: String,
workspace_dir: PathBuf,
input_types: Vec<String>,
) -> Self {
Self {
provider,
tools: Arc::new(ToolRegistry::new()),
@ -379,7 +386,12 @@ impl AgentLoop {
let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)]
} else {
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
build_content_blocks(
&m.content,
&m.media_refs,
&self.input_types,
&self.media_registry,
)
};
Message {
@ -399,14 +411,28 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Build and inject system prompt if not present
let has_system = messages.first().is_some_and(|m| m.role == "system");
if !has_system {
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
let system_prompt = build_system_prompt(
&self.workspace_dir,
&self.model_name,
&self.tools,
None,
None,
false,
);
#[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt));
@ -427,9 +453,7 @@ impl AgentLoop {
let estimated = estimate_tokens(&messages);
let danger = (self.context_window as f64 * 0.8) as usize;
if estimated > danger {
let trimmed = self.preemptive_trim_old_tool_results(
&mut messages, 2000, 4,
);
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
if trimmed > 0 {
#[cfg(debug_assertions)]
tracing::debug!(
@ -463,11 +487,10 @@ impl AgentLoop {
};
// Call LLM
let response = (*self.provider).chat(request).await
.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
let response = (*self.provider).chat(request).await.map_err(|e| {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
accumulated_tokens += response.usage.total_tokens;
@ -493,7 +516,9 @@ impl AgentLoop {
// Execute tool calls — log and notify immediately
{
let tools_info: Vec<String> = response.tool_calls.iter()
let tools_info: Vec<String> = response
.tool_calls
.iter()
.map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
let s = format!("{}:{}", tc.name, args);
@ -522,7 +547,9 @@ impl AgentLoop {
// Log function call with name and arguments
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
@ -562,7 +589,11 @@ impl AgentLoop {
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
}
// Max iterations reached - ask LLM for a summary based on completed work
@ -571,7 +602,7 @@ impl AgentLoop {
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far."
Please provide your best answer based on the work completed so far.",
);
messages.push(summary_request);
@ -603,14 +634,19 @@ impl AgentLoop {
Err(e) => {
// Fallback if summary call fails
tracing::error!(error = %e, "Failed to get summary from LLM");
let final_message = ChatMessage::assistant(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
);
let final_message = ChatMessage::assistant(format!(
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
self.max_iterations
));
emitted_messages.push(final_message.clone());
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
total_tokens: if accumulated_tokens > 0 { Some(accumulated_tokens) } else { None },
total_tokens: if accumulated_tokens > 0 {
Some(accumulated_tokens)
} else {
None
},
})
}
}
@ -698,10 +734,7 @@ impl AgentLoop {
}
// Apply duration
ToolExecutionOutcome {
duration,
..result
}
ToolExecutionOutcome { duration, ..result }
}
/// Internal tool execution without event tracking.
@ -723,18 +756,12 @@ impl AgentLoop {
ToolExecutionOutcome::success(result.output)
} else {
let error = result.error.unwrap_or_default();
ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
)
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
}
}
Err(e) => {
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
ToolExecutionOutcome::failure(
format!("Error: {}", e),
Some(e.to_string()),
)
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
}
}
}
@ -822,8 +849,14 @@ mod tests {
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].id,
"call_1"
);
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].name,
"calculator"
);
}
}

View File

@ -234,27 +234,33 @@ impl ContextCompressor {
}
} else if messages[i].role == "tool"
&& let Some(ref tid) = messages[i].tool_call_id
&& !declared.contains(tid.as_str()) {
messages.remove(i);
continue;
}
&& !declared.contains(tid.as_str())
{
messages.remove(i);
continue;
}
i += 1;
}
let broken: Vec<usize> = messages.iter().enumerate()
let broken: Vec<usize> = messages
.iter()
.enumerate()
.filter_map(|(idx, msg)| {
if msg.role == "assistant"
&& let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty() {
let all_present = tcs.iter().all(|tc| {
messages.iter().any(|m| {
m.role == "tool"
&& m.tool_call_id.as_deref() == Some(tc.id.as_str())
})
});
if !all_present { Some(idx) } else { None }
} else { None }
}).collect();
&& !tcs.is_empty()
{
let all_present = tcs.iter().all(|tc| {
messages.iter().any(|m| {
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
})
});
if !all_present { Some(idx) } else { None }
} else {
None
}
})
.collect();
for idx in broken {
let msg = &mut messages[idx];
@ -262,7 +268,8 @@ impl ContextCompressor {
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results are no longer available]",
msg.content, names.join(", ")
msg.content,
names.join(", ")
);
}
}
@ -275,7 +282,10 @@ impl ContextCompressor {
// Check if compression is needed
let tokens = self.token_estimate_with_history(&history);
if tokens <= self.threshold() {
return Ok(CompressionResult { history, created_timelines: false });
return Ok(CompressionResult {
history,
created_timelines: false,
});
}
#[cfg(debug_assertions)]
@ -299,7 +309,10 @@ impl ContextCompressor {
}
if tokens_after <= self.threshold() {
self.invalidate_token_cache();
return Ok(CompressionResult { history, created_timelines: false });
return Ok(CompressionResult {
history,
created_timelines: false,
});
}
// LLM summarization pass
@ -312,11 +325,7 @@ impl ContextCompressor {
}
#[cfg(debug_assertions)]
tracing::debug!(
pass = pass + 1,
tokens = tokens,
"Compression pass"
);
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
match self.compress_once(&current_history).await {
Ok(Some(compressed)) => {
@ -352,18 +361,24 @@ impl ContextCompressor {
let m = &current_history[scan];
if m.role == "assistant" {
if let Some(tcs) = &m.tool_calls
&& !tcs.is_empty() {
let has_post = current_history[scan + 1..]
.iter()
.filter(|r| r.role == "tool")
.any(|r| tcs.iter().any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str())));
if has_post {
tail_start = scan;
}
&& !tcs.is_empty()
{
let has_post = current_history[scan + 1..]
.iter()
.filter(|r| r.role == "tool")
.any(|r| {
tcs.iter()
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
});
if has_post {
tail_start = scan;
}
}
break;
}
if scan == 0 {
break;
}
if scan == 0 { break; }
scan -= 1;
}
}
@ -390,14 +405,16 @@ impl ContextCompressor {
for msg in &mut truncated[..self.config.protect_first_n] {
if msg.role == "assistant" {
if let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty() {
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
msg.content, names.join(", ")
);
msg.tool_calls = None;
}
&& !tcs.is_empty()
{
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!(
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
msg.content,
names.join(", ")
);
msg.tool_calls = None;
}
}
}
@ -424,7 +441,10 @@ impl ContextCompressor {
"Context compression completed"
);
Ok(CompressionResult { history: current_history, created_timelines })
Ok(CompressionResult {
history: current_history,
created_timelines,
})
}
/// Try to extract the actual context token limit from an LLM error message.
@ -447,20 +467,21 @@ impl ContextCompressor {
// Look for a number in the vicinity (up to 10 chars after marker)
if let Some(num_str) = find_number_nearby(after, 50)
&& let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n) {
return Some(n);
}
&& (1024..=10_000_000).contains(&n)
{
return Some(n);
}
}
}
// Also try: "XXXX token context" or "XXXX limit"
if let Some(num_str) = find_number_nearby(&lower, lower.len())
&& let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n)
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
{
return Some(n);
}
&& (1024..=10_000_000).contains(&n)
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
{
return Some(n);
}
None
}
@ -509,19 +530,26 @@ impl ContextCompressor {
// Persist compressed summary as timeline memory entry
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
ts, between.len(), summary);
let timeline_content = format!(
"[{}] Compressed {} conversation segments:\n{}",
ts,
between.len(),
summary
);
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
let mm = self.memory.clone();
let sid = self.session_id.clone();
tokio::spawn(async move {
if let Err(e) = mm.store(
&key,
&timeline_content,
crate::memory::MemoryCategory::Timeline,
sid.as_deref(),
Some(0.3),
).await {
if let Err(e) = mm
.store(
&key,
&timeline_content,
crate::memory::MemoryCategory::Timeline,
sid.as_deref(),
Some(0.3),
)
.await
{
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
}
});
@ -552,10 +580,7 @@ impl ContextCompressor {
}
/// Summarize a segment of messages using LLM.
async fn summarize_segment(
&self,
messages: &[ChatMessage],
) -> Result<String, AgentError> {
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
if messages.is_empty() {
return Ok(String::new());
}
@ -569,7 +594,8 @@ impl ContextCompressor {
"tool" => "Tool",
_ => m.role.as_str(),
};
let name = m.tool_name
let name = m
.tool_name
.as_ref()
.map(|n| format!(" ({})", n))
.unwrap_or_default();
@ -614,7 +640,10 @@ Be concise, aim for {} characters or less.
);
let request = ChatCompletionRequest {
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
messages: vec![
Message::system("You are a helpful assistant."),
Message::user(&prompt),
],
temperature: Some(0.3),
max_tokens: Some(1000),
tools: None,
@ -686,13 +715,23 @@ mod tests {
content: "[summarized]".into(),
reasoning_content: None,
tool_calls: vec![],
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
}
fn ptype(&self) -> &str { "mock" }
fn name(&self) -> &str { "mock" }
fn model_id(&self) -> &str { "mock" }
fn ptype(&self) -> &str {
"mock"
}
fn name(&self) -> &str {
"mock"
}
fn model_id(&self) -> &str {
"mock"
}
}
fn mock_summarizer() -> Arc<dyn LLMProvider> {
@ -704,11 +743,13 @@ mod tests {
MM.get_or_init(|| {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
let tmp = std::env::temp_dir()
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
})
}).clone()
})
.clone()
}
#[test]
@ -724,7 +765,11 @@ mod tests {
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
// raw = 19, with 1.2x = ~23
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
assert!(
tokens > 18 && tokens < 30,
"Expected ~23 tokens, got {}",
tokens
);
}
#[test]
@ -733,7 +778,8 @@ mod tests {
tool_result_trim_chars: 50,
..Default::default()
};
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let compressor =
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let mut messages = vec![
ChatMessage::user("Hello"),
@ -774,7 +820,11 @@ mod tests {
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
];
let result = compressor.compress_if_needed(messages).await.unwrap().history;
let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
assert!(
@ -798,13 +848,14 @@ mod tests {
// - B2B (L275): last user message lost when it is the final history message
//
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
let tmp =
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig {
tool_result_trim_chars: 2000,
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
protect_last_n: 2,
max_passes: 1,
..Default::default()
@ -818,25 +869,43 @@ mod tests {
let big = "x".repeat(3000);
let messages = vec![
ChatMessage::system("You are a helper."), // 0: protected
ChatMessage::user("Q1"), // 1: first user
ChatMessage::tool("t1", "bash", &big), // 2
ChatMessage::user("Q2"), // 3
ChatMessage::assistant("thinking"), // 4
ChatMessage::tool("t2", "bash", &big), // 5
ChatMessage::user("Q3"), // 6
ChatMessage::assistant("thinking"), // 7
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
ChatMessage::user("Q1"), // 1: first user
ChatMessage::tool("t1", "bash", &big), // 2
ChatMessage::user("Q2"), // 3
ChatMessage::assistant("thinking"), // 4
ChatMessage::tool("t2", "bash", &big), // 5
ChatMessage::user("Q3"), // 6
ChatMessage::assistant("thinking"), // 7
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
];
let result = compressor.compress_if_needed(messages).await.unwrap().history;
let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
// B2A: "Q1" must appear exactly once
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
let q1_count = result
.iter()
.filter(|m| m.role == "user" && m.content == "Q1")
.count();
assert_eq!(
q1_count, 1,
"Q1 should appear exactly once, got {}",
q1_count
);
// B2B: "Q4" must NOT be lost
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
let q4_count = result
.iter()
.filter(|m| m.role == "user" && m.content == "Q4")
.count();
assert_eq!(
q4_count, 1,
"Q4 should appear exactly once (not lost), got {}",
q4_count
);
let _ = std::fs::remove_file(&tmp);
}
@ -850,10 +919,10 @@ mod tests {
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig {
tool_result_trim_chars: 500, // trim reduces but not enough
tool_result_trim_chars: 500, // trim reduces but not enough
protect_first_n: 1,
protect_last_n: 2,
max_passes: 0, // no LLM summarization → will exceed danger
max_passes: 0, // no LLM summarization → will exceed danger
..Default::default()
};
// context_window=100, danger_threshold=90.
@ -872,13 +941,23 @@ mod tests {
ChatMessage::tool("t3", "bash", &big),
];
let result = compressor.compress_if_needed(messages).await.unwrap().history;
let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
assert!(
result.len() < 7,
"expected truncation reduction, got {} messages",
result.len()
);
// Truncation notice should be present
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
let has_notice = result
.iter()
.any(|m| m.content.contains("Context truncation"));
assert!(has_notice, "hard truncation notice missing");
let _ = std::fs::remove_file(&tmp);
@ -893,9 +972,9 @@ mod tests {
let mut messages = vec![
ChatMessage::user("Q1"),
ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
ChatMessage::assistant("done"), // declares tc2
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
ChatMessage::assistant("done"), // declares tc2
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
];
// Set tool_call_id on tool messages and tool_calls on assistant
messages[2].tool_call_id = Some("tc1".into());
@ -910,8 +989,16 @@ mod tests {
// orphan should be removed; legitimate should stay
assert_eq!(messages.len(), 4);
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
assert!(
messages
.iter()
.all(|m| m.tool_call_id != Some("tc1".into()))
);
assert!(
messages
.iter()
.any(|m| m.tool_call_id == Some("tc2".into()))
);
}
#[test]

View File

@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
}
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{engine::general_purpose::STANDARD, Engine as _};
use base64::{Engine as _, engine::general_purpose::STANDARD};
let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new();

View File

@ -4,10 +4,13 @@ pub mod media_handler;
pub mod sub_agent;
pub mod system_prompt;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
pub use context_compressor::{ContextCompressor, estimate_tokens};
pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus};
pub use system_prompt::{
build_system_prompt, build_sub_agent_system_prompt, PromptContext, PromptSection,
SystemPromptBuilder,
pub use sub_agent::{
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
TaskNotification, TaskStatus,
};
pub use system_prompt::{
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
build_system_prompt,
};

View File

@ -6,12 +6,12 @@ use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::agent::system_prompt::build_sub_agent_system_prompt;
use crate::agent::AgentLoop;
use crate::agent::AgentError;
use crate::agent::AgentLoop;
use crate::agent::system_prompt::build_sub_agent_system_prompt;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider};
use crate::providers::{LLMProvider, create_provider};
use crate::skills::SkillsLoader;
use crate::tools::ToolRegistry;
@ -21,7 +21,8 @@ tokio::task_local! {
/// Read the delegate context from the current task. Returns an error if not set.
pub fn get_delegate_context() -> Result<DelegateContext, String> {
DELEGATE_CONTEXT.try_with(|ctx| ctx.clone())
DELEGATE_CONTEXT
.try_with(|ctx| ctx.clone())
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
}
@ -207,7 +208,10 @@ impl SubAgentManager {
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none()
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"));
|| config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt(
&config.prompt,
@ -219,7 +223,8 @@ impl SubAgentManager {
http_get_only,
);
let agent = self.build_sub_agent(&config, tools)
let agent = self
.build_sub_agent(&config, tools)
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
let history = vec![
@ -241,10 +246,14 @@ impl SubAgentManager {
Ok(Ok(agent_result)) => {
let (content, truncated) =
truncate_sub_agent_result(&agent_result.final_response.content);
let tool_calls_count = agent_result.emitted_messages.iter()
let tool_calls_count = agent_result
.emitted_messages
.iter()
.filter(|m| m.tool_calls.is_some())
.count();
let iterations = agent_result.emitted_messages.iter()
let iterations = agent_result
.emitted_messages
.iter()
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
.count();
Ok(SubAgentResult {
@ -343,7 +352,10 @@ impl SubAgentManager {
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none()
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"));
|| config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt(
&config.prompt,
@ -372,8 +384,12 @@ impl SubAgentManager {
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid, "running", None, None,
Some(started_at), None,
&tid,
"running",
None,
None,
Some(started_at),
None,
)
.await;
}
@ -384,8 +400,7 @@ impl SubAgentManager {
p.set_storage(s.clone());
}
}
let provider_result: Option<Arc<dyn LLMProvider>> =
provider.map(|p| Arc::from(p));
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
let result = match provider_result {
Some(provider) => {
@ -474,9 +489,12 @@ impl SubAgentManager {
if let Some(ref s) = storage {
let _ = s
.update_background_task_status(
&tid, &status_str,
Some(&result.content), error_val.as_deref(),
Some(started_at), Some(finished_at),
&tid,
&status_str,
Some(&result.content),
error_val.as_deref(),
Some(started_at),
Some(finished_at),
)
.await;
}
@ -514,15 +532,13 @@ impl SubAgentManager {
Ok(true)
} else if let Some(ref s) = self.storage {
match s.get_background_task(task_id).await {
Ok(task) => {
match task.status.as_str() {
"pending" | "running" => {
tracing::warn!(task_id, "task in DB but not in active_tasks");
Ok(false)
}
_ => Ok(false),
Ok(task) => match task.status.as_str() {
"pending" | "running" => {
tracing::warn!(task_id, "task in DB but not in active_tasks");
Ok(false)
}
}
_ => Ok(false),
},
Err(_) => Ok(false),
}
} else {
@ -530,10 +546,7 @@ impl SubAgentManager {
}
}
pub async fn check_task(
&self,
task_id: &str,
) -> Option<crate::storage::BackgroundTask> {
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.get_background_task(task_id).await.ok()
} else {
@ -541,12 +554,11 @@ impl SubAgentManager {
}
}
pub async fn list_tasks(
&self,
session_id: &str,
) -> Vec<crate::storage::BackgroundTask> {
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage {
s.list_background_tasks(session_id).await.unwrap_or_default()
s.list_background_tasks(session_id)
.await
.unwrap_or_default()
} else {
vec![]
}

View File

@ -196,10 +196,10 @@ impl PromptSection for UserProfileSection {
if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) =
load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS)
{
output.push_str(&content);
return output;
}
{
output.push_str(&content);
return output;
}
// No USER.md found, return empty
String::new()
@ -220,10 +220,10 @@ impl PromptSection for AgentProfileSection {
if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) =
load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS)
{
output.push_str(&content);
return output;
}
{
output.push_str(&content);
return output;
}
String::new()
}
@ -465,7 +465,9 @@ impl PromptSection for SubAgentToolsSection {
let mut s = String::from("## 可用工具\n\n");
s.push_str(&ctx.tools.describe_for_prompt());
if self.http_get_only {
s.push_str("\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。");
s.push_str(
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
);
}
s
}
@ -560,13 +562,8 @@ pub fn build_sub_agent_system_prompt(
memory_context: None,
has_compressed_history: false,
};
SystemPromptBuilder::with_sub_agent_defaults(
task,
timeout_human,
skills_prompt,
http_get_only,
)
.build(&ctx)
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
.build(&ctx)
}
#[cfg(test)]

View File

@ -1,8 +1,8 @@
use std::sync::Arc;
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::ChannelManager;
use crate::channels::base::{Channel, ChannelError};
/// OutboundDispatcher consumes outbound messages from the MessageBus
/// and dispatches them to the appropriate Channel

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::providers::ToolCall;
@ -23,7 +23,9 @@ pub struct ImageUrlBlock {
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text { text: content.into() }
Self::Text {
text: content.into(),
}
}
pub fn image_url(url: impl Into<String>) -> Self {
@ -49,10 +51,10 @@ pub struct MediaRef {
#[derive(Debug, Clone)]
pub struct MediaItem {
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub mime_type: Option<String>,
pub original_key: Option<String>, // Feishu file_key for download
pub original_key: Option<String>, // Feishu file_key for download
}
impl MediaItem {
@ -161,7 +163,10 @@ impl ChatMessage {
}
}
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
pub fn assistant_with_tool_calls(
content: impl Into<String>,
tool_calls: Vec<ToolCall>,
) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
@ -206,7 +211,11 @@ impl ChatMessage {
}
}
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(),

View File

@ -2,10 +2,13 @@ pub mod dispatcher;
pub mod message;
pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
pub use message::{
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
OutboundMessage, SourceKind,
};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{Mutex, mpsc};
// ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication
@ -49,7 +52,8 @@ impl MessageBus {
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self.inbound_rx
let msg = self
.inbound_rx
.lock()
.await
.recv()

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,10 @@ impl ChannelManager {
}
}
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
pub fn with_bus(
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
bus: Arc<MessageBus>,
) -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
cli_chat_channel,
@ -39,7 +42,10 @@ impl ChannelManager {
/// Register a channel with the manager
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels.write().await.insert(name.to_string(), channel);
self.channels
.write()
.await
.insert(name.to_string(), channel);
}
/// Get CLI chat channel
@ -56,14 +62,19 @@ impl ChannelManager {
// Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled {
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir)
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
let channel =
FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
})?;
self.channels
.write()
.await
.insert("feishu".to_string(), Arc::new(channel));
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display());
tracing::info!(
"Feishu channel registered (media_dir: {}/media/feishu)",
workspace_dir.display()
);
} else {
tracing::info!("Feishu channel disabled in config");
}
@ -118,7 +129,10 @@ impl ChannelManager {
if let Some(channel) = self.get_channel(channel_name).await {
channel.send(msg).await
} else {
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
Err(ChannelError::Other(format!(
"Channel not found: {}",
channel_name
)))
}
}
}

View File

@ -1,11 +1,11 @@
pub mod base;
pub mod feishu;
pub mod cli_chat;
pub mod feishu;
pub mod manager;
pub mod slash_command;
pub use base::{Channel, ChannelError};
pub use manager::ChannelManager;
pub use feishu::FeishuChannel;
pub use cli_chat::CliChatChannel;
pub use slash_command::{parse_slash_command, command_matches};
pub use feishu::FeishuChannel;
pub use manager::ChannelManager;
pub use slash_command::{command_matches, parse_slash_command};

View File

@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
/// 检查内容是否匹配指定命令
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
let trimmed = content.trim();
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
aliases
.iter()
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
}
#[cfg(test)]
@ -27,7 +29,10 @@ mod tests {
fn test_parse_slash_command() {
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world")));
assert_eq!(
parse_slash_command("/new hello world"),
Some(("new", "hello world"))
);
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
assert_eq!(parse_slash_command("/?"), Some(("?", "")));

View File

@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
use crossterm::{
event::{self, Event},
execute,
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
};
use futures_util::{SinkExt, StreamExt};
use ratatui::{prelude::CrosstermBackend, Terminal};
use ratatui::{Terminal, prelude::CrosstermBackend};
use std::io;
use tokio_tungstenite::{connect_async, tungstenite::Message};
@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
WsOutbound::SessionCreated { session_id, .. } => {
app.set_current_session(Some(session_id));
}
WsOutbound::SessionList { sessions, current_session_id } => {
WsOutbound::SessionList {
sessions,
current_session_id,
} => {
app.set_sessions(sessions);
if let Some(id) = current_session_id {
app.set_current_session(Some(id));

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::{App, MessageRole};
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
text::Line,
widgets::{Block, Borders, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,8 +1,8 @@
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
widgets::{Block, Borders, Clear, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Style},
widgets::{Block, Borders, Paragraph},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
widgets::{Block, Borders, List, ListItem},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {
@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
.sessions
.iter()
.map(|session| {
let is_current = app
.current_session_id
.as_ref() == Some(&session.session_id);
let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
let archived = session.archived_at.is_some();
let mut content = if is_current {

View File

@ -1,15 +1,18 @@
use crate::client::tui::app::App;
use ratatui::{
Frame,
layout::Rect,
style::{Color, Modifier, Style},
widgets::{Block, Borders, Paragraph},
Frame,
};
pub fn render(f: &mut Frame, area: Rect, app: &App) {
let (title, style) = if app.pending_quit {
let msg = if let Some(session_id) = &app.current_session_id {
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id)
format!(
"PicoBot | Session: {} | Press Ctrl+C again to quit",
session_id
)
} else {
"PicoBot | Press Ctrl+C again to quit".to_string()
};

View File

@ -1,6 +1,6 @@
use crate::client::tui::app::{App, MessageRole};
use crate::protocol::serialize_inbound;
use crate::protocol::WsInbound;
use crate::protocol::serialize_inbound;
use crossterm::event::{KeyCode, KeyEvent};
use futures_util::SinkExt;
@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
// Handle Ctrl+C for quit (double press to exit)
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL);
let is_ctrl_c = key.code == KeyCode::Char('c')
&& key
.modifiers
.contains(crossterm::event::KeyModifiers::CONTROL);
if is_ctrl_c {
if app.handle_ctrl_c_for_quit() {
return;
@ -63,9 +66,11 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
}
KeyCode::Char(c) => {
app.input_insert_char(c);
// Show command menu when input starts with /
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) {
if !app.show_command_menu
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
{
app.show_command_menu = true;
app.selected_command_idx = 0;
} else if app.show_command_menu && !app.input.starts_with('/') {
@ -74,7 +79,7 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
}
KeyCode::Backspace => {
app.input_delete_char();
// Hide menu if input no longer starts with /
if app.show_command_menu && !app.input.starts_with('/') {
app.show_command_menu = false;
@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) {
sender_id: None,
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await;
let _ = sender
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
.await;
}
}
}

View File

@ -1,8 +1,8 @@
use crate::client::tui::app::App;
use crate::client::tui::components::*;
use ratatui::{
layout::{Constraint, Direction, Layout, Rect},
Frame,
layout::{Constraint, Direction, Layout, Rect},
};
pub fn render_ui(f: &mut Frame, app: &App) {

View File

@ -273,12 +273,16 @@ impl Default for MemoryConfig {
impl MemoryConfig {
/// Resolve consolidation provider name, falling back to the main agent's provider.
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
self.consolidation_provider
.clone()
.unwrap_or_else(|| default.to_string())
}
/// Resolve consolidation model name, falling back to the main agent's model.
pub fn resolve_consolidation_model(&self, default: &str) -> String {
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
self.consolidation_model
.clone()
.unwrap_or_else(|| default.to_string())
}
}
@ -366,10 +370,18 @@ impl Default for BrowserConfig {
}
}
fn default_recall_limit() -> usize { 5 }
fn default_idle_consolidation_minutes() -> u64 { 10 }
fn default_timeline_retention_days() -> u64 { 90 }
fn default_max_failures_before_degrade() -> usize { 3 }
fn default_recall_limit() -> usize {
5
}
fn default_idle_consolidation_minutes() -> u64 {
10
}
fn default_timeline_retention_days() -> u64 {
90
}
fn default_max_failures_before_degrade() -> usize {
3
}
#[derive(Debug, Clone)]
pub struct LLMProviderConfig {
@ -469,7 +481,11 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
ConfigError::ConfigNotFound(path) => write!(
f,
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
path
),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),

View File

@ -1,12 +1,12 @@
use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use super::GatewayState;
use crate::protocol::WsOutbound;
use crate::protocol::serialize_outbound;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::protocol::serialize_outbound;
use crate::protocol::WsOutbound;
use super::GatewayState;
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async move {
@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
// Send session established message
let _ = sender.send(WsOutbound::SessionEstablished {
session_id: session_id.clone(),
}).await;
let _ = sender
.send(WsOutbound::SessionEstablished {
session_id: session_id.clone(),
})
.await;
tracing::info!(session_id = %session_id, "CLI session established");
@ -37,9 +39,10 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg)
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
break;
}
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err()
{
break;
}
}
});

View File

@ -1,17 +1,17 @@
pub mod config;
pub mod providers;
pub mod bus;
pub mod agent;
pub mod gateway;
pub mod session;
pub mod client;
pub mod protocol;
pub mod bus;
pub mod channels;
pub mod client;
pub mod config;
pub mod gateway;
pub mod logging;
pub mod mcp;
pub mod memory;
pub mod observability;
pub mod protocol;
pub mod providers;
pub mod scheduler;
pub mod session;
pub mod skills;
pub mod storage;
pub mod tools;

View File

@ -1,11 +1,7 @@
use std::path::PathBuf;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{
fmt,
layer::SubscriberExt,
util::SubscriberInitExt,
fmt::time::LocalTime,
EnvFilter,
EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
};
/// Get the default log directory path: ~/.picobot/logs
@ -27,20 +23,20 @@ pub fn init_logging() {
// Create log directory if it doesn't exist
if !log_dir.exists()
&& let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
}
&& let Err(e) = std::fs::create_dir_all(&log_dir)
{
eprintln!(
"Warning: Failed to create log directory {}: {}",
log_dir.display(),
e
);
}
// Create file appender with daily rotation
let file_appender = RollingFileAppender::new(
Rotation::DAILY,
&log_dir,
"picobot.log",
);
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
// Build subscriber with both console and file output
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let file_layer = fmt::layer()
.with_writer(file_appender)
@ -66,8 +62,7 @@ pub fn init_logging() {
/// Initialize logging without file output (console only)
pub fn init_logging_console_only() {
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer()
.with_timer(LocalTime::rfc_3339())

View File

@ -1,4 +1,4 @@
use clap::{Parser, CommandFactory};
use clap::{CommandFactory, Parser};
#[derive(Parser)]
#[command(name = "picobot")]

View File

@ -92,24 +92,19 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
parts.push(text.text.clone());
}
RawContent::Image(image) => {
parts.push(format!(
"[image: {}]",
image.mime_type,
));
parts.push(format!("[image: {}]", image.mime_type,));
}
RawContent::Resource(resource) => {
match &resource.resource {
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
parts.push(format!(
"[resource text: {}]",
text.chars().take(200).collect::<String>(),
));
}
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
RawContent::Resource(resource) => match &resource.resource {
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
parts.push(format!(
"[resource text: {}]",
text.chars().take(200).collect::<String>(),
));
}
}
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
},
_ => {
parts.push("[unsupported content]".to_string());
}
@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
cmd.env(k, v);
}
let service = ()
.serve(
let service =
().serve(
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
)
.await
@ -261,14 +256,14 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
} else {
StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(url.to_string())
.custom_headers(headers_map)
.custom_headers(headers_map),
)
};
let service = ()
.serve(transport)
.await
.context("failed to connect to HTTP/SSE MCP server")?;
let service =
().serve(transport)
.await
.context("failed to connect to HTTP/SSE MCP server")?;
let peer = service.peer().clone();

View File

@ -102,7 +102,11 @@ mod tests {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
let mm = Arc::new(MemoryManager::new(
storage,
"default".into(),
"default".into(),
));
(mm, dir)
}
@ -131,15 +135,9 @@ mod tests {
async fn test_upsert_overwrites() {
let (mm, _dir) = setup_memory_manager().await;
mm.store(
"dup_key",
"original",
MemoryCategory::Knowledge,
None,
None,
)
.await
.unwrap();
mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
.await
.unwrap();
mm.store(
"dup_key",
"updated",
@ -247,7 +245,12 @@ mod tests {
// Recall scoped to session A — should get only tl_a
let scoped = mm
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
.recall(
"summary",
10,
Some(MemoryCategory::Timeline),
Some("chan:chat:dialog_a"),
)
.await
.unwrap();
assert_eq!(scoped.len(), 1);

View File

@ -20,10 +20,7 @@ pub enum ObserverEvent {
success: bool,
},
/// Emitted when the agent starts processing.
AgentStart {
provider: String,
model: String,
},
AgentStart { provider: String, model: String },
/// Emitted when the agent finishes processing.
AgentEnd {
provider: String,
@ -94,7 +91,11 @@ impl ToolExecutionOutcome {
}
/// Create a failed outcome with duration.
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
pub fn failure_with_duration(
output: String,
error_reason: Option<String>,
duration: Duration,
) -> Self {
Self {
output,
success: false,

View File

@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage;
use std::sync::Arc;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage;
use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => {
convert_image_url_to_anthropic(&image_url.url)
}
}).collect()
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
})
.collect()
}
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider {
};
let content = if let Some(ref tc_id) = m.tool_call_id {
// Tool result: wrap as tool_result content block
let output = m.content.iter()
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
let output = m
.content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
vec![serde_json::json!({
@ -244,19 +250,18 @@ impl LLMProvider for AnthropicProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await
.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let status = resp.status();
let body_text = resp.text().await?;
@ -281,32 +286,38 @@ impl LLMProvider for AnthropicProvider {
"LLM API returned error"
);
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&body_text), Some(&error_msg),
start.elapsed().as_millis() as u64,
).await;
let _ = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
Some(&error_msg),
start.elapsed().as_millis() as u64,
)
.await;
}
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
}
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
.map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
});
}
err_msg
})?;
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s
.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
.await;
});
}
err_msg
})?;
let mut content = String::new();
let mut reasoning = None;
@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider {
reasoning_content: reasoning,
tool_calls,
usage: Usage {
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
prompt_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.input_tokens)
.unwrap_or(0),
completion_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.output_tokens)
.unwrap_or(0),
total_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.input_tokens + u.output_tokens)
.unwrap_or(0),
},
};
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
).await;
let _ = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
)
.await;
}
Ok(response)

View File

@ -1,12 +1,15 @@
pub mod traits;
pub mod openai;
pub mod anthropic;
pub mod openai;
pub mod traits;
pub use self::openai::OpenAIProvider;
pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider;
use crate::config::LLMProviderConfig;
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
ToolFunction, Usage,
};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() {

View File

@ -1,29 +1,35 @@
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage;
use std::sync::Arc;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage;
use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1
&& let ContentBlock::Text { text } = &blocks[0] {
return Value::String(text.clone());
}
Value::Array(blocks.iter().map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
}).collect())
&& let ContentBlock::Text { text } = &blocks[0]
{
return Value::String(text.clone());
}
Value::Array(
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
})
.collect(),
)
}
pub struct OpenAIProvider {
@ -201,10 +207,14 @@ impl LLMProvider for OpenAIProvider {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
&& let Some(url_str) = item
.get("image_url")
.and_then(|u| u.get("url"))
.and_then(|v| v.as_str())
{
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
}
}
}
@ -224,19 +234,18 @@ impl LLMProvider for OpenAIProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await
.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
let is_timeout = e.is_timeout();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
timeout = is_timeout,
error = %e,
elapsed_ms = %start.elapsed().as_millis(),
"LLM API request failed"
);
})?;
let status = resp.status();
let text = resp.text().await?;
@ -253,37 +262,48 @@ impl LLMProvider for OpenAIProvider {
"LLM API returned error"
);
if let Some(ref storage) = self.storage
&& let Err(e) = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), Some(&error),
start.elapsed().as_millis() as u64,
).await {
tracing::warn!("failed to persist LLM call: {}", e);
}
&& let Err(e) = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&text),
Some(&error),
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
return Err(error.into());
}
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
.map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
tracing::warn!("failed to persist LLM call (decode error): {}", e);
}
});
}
err_msg
})?;
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
if let Err(e) = s
.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
.await
{
tracing::warn!("failed to persist LLM call (decode error): {}", e);
}
});
}
err_msg
})?;
let first_choice = openai_resp.choices.into_iter().next()
let first_choice = openai_resp
.choices
.into_iter()
.next()
.ok_or("no choices in response")?;
let content = first_choice
@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider {
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
arguments: serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Null),
})
.collect();
@ -318,13 +339,19 @@ impl LLMProvider for OpenAIProvider {
};
if let Some(ref storage) = self.storage
&& let Err(e) = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), None,
start.elapsed().as_millis() as u64,
).await {
tracing::warn!("failed to persist LLM call: {}", e);
}
&& let Err(e) = storage
.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&text),
None,
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
Ok(response)
}
@ -386,6 +413,9 @@ mod tests {
assert_eq!(tool_calls[0]["id"], "call_1");
assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
assert_eq!(
tool_calls[0]["function"]["arguments"],
"{\"expression\":\"1+1\"}"
);
}
}

View File

@ -1,6 +1,6 @@
use crate::bus::message::ContentBlock;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
@ -61,7 +61,11 @@ impl Message {
}
}
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self {
role: "tool".to_string(),
content: vec![ContentBlock::text(content)],

View File

@ -5,11 +5,11 @@ use std::time::Instant;
use tokio::time;
use crate::config::SchedulerConfig;
use crate::session::session::HandleResult;
use crate::session::SessionManager;
use crate::session::session::HandleResult;
use crate::storage::JobRun;
use crate::storage::ScheduledJob;
use crate::storage::Storage;
use crate::storage::JobRun;
pub use types::Schedule;
@ -89,7 +89,11 @@ impl Scheduler {
let now = now_ms();
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
let due = match self
.storage
.due_scheduled_jobs(now, self.config.max_concurrent)
.await
{
Ok(jobs) => jobs,
Err(e) => {
tracing::error!("scheduler: failed to query due jobs: {}", e);
@ -107,7 +111,11 @@ impl Scheduler {
let start = Instant::now();
let started_at = now_ms();
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
if let Err(e) = self
.storage
.touch_scheduled_job_last_run(&job.id, started_at)
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
continue;
}
@ -135,7 +143,10 @@ impl Scheduler {
match result {
Ok(HandleResult::AgentResponse(output)) => {
let output_truncated = if output.len() > 8000 {
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
format!(
"{}...[truncated]",
&output[..output.ceil_char_boundary(8000)]
)
} else {
output.clone()
};
@ -155,7 +166,11 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
}
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
if let Err(e) = self
.storage
.set_scheduled_job_last_status(&job.id, "ok", None)
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
}
@ -199,9 +214,11 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
}
if let Err(e2) = self.storage.set_scheduled_job_last_status(
&job.id, "error", Some(&error_str),
).await {
if let Err(e2) = self
.storage
.set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
}
@ -231,17 +248,23 @@ impl Scheduler {
self.storage.remove_scheduled_job(&job.id).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
} else {
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
}
}
Schedule::Every { .. } | Schedule::Cron { .. } => {
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
self.storage
.set_scheduled_job_next_run(&job.id, next)
.await?;
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
} else {
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
}
}
}

View File

@ -22,32 +22,20 @@ pub enum SessionCommand {
dialog_id: String,
},
/// Get the current dialog for a chat
GetCurrentDialog {
channel: String,
chat_id: String,
},
GetCurrentDialog { channel: String, chat_id: String },
/// Rename a dialog
RenameDialog {
session_id: UnifiedSessionId,
title: String,
},
/// Archive a dialog
ArchiveDialog {
session_id: UnifiedSessionId,
},
ArchiveDialog { session_id: UnifiedSessionId },
/// Delete a dialog
DeleteDialog {
session_id: UnifiedSessionId,
},
DeleteDialog { session_id: UnifiedSessionId },
/// Clear dialog history
ClearHistory {
session_id: UnifiedSessionId,
},
ClearHistory { session_id: UnifiedSessionId },
/// Get list of available slash commands
GetSlashCommands {
channel: String,
chat_id: String,
},
GetSlashCommands { channel: String, chat_id: String },
/// Execute a slash command
ExecuteSlashCommand {
command: String,
@ -60,7 +48,11 @@ pub enum SessionCommand {
impl SessionCommand {
/// Create a CreateDialog command
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self {
pub fn create_dialog(
channel: impl Into<String>,
chat_id: impl Into<String>,
title: Option<String>,
) -> Self {
Self::CreateDialog {
channel: channel.into(),
chat_id: chat_id.into(),
@ -69,7 +61,11 @@ impl SessionCommand {
}
/// Create a ListDialogs command
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self {
pub fn list_dialogs(
channel: impl Into<String>,
chat_id: impl Into<String>,
include_archived: bool,
) -> Self {
Self::ListDialogs {
channel: channel.into(),
chat_id: chat_id.into(),

View File

@ -1,5 +1,5 @@
use super::session_id::UnifiedSessionId;
use super::session::SlashCommand;
use super::session_id::UnifiedSessionId;
/// Dialog information returned by SessionManager
#[derive(Debug, Clone)]
@ -30,30 +30,20 @@ pub enum SessionEvent {
session_id: Option<UnifiedSessionId>,
},
/// Dialog switched successfully
DialogSwitched {
session_id: UnifiedSessionId,
},
DialogSwitched { session_id: UnifiedSessionId },
/// Dialog renamed
DialogRenamed {
session_id: UnifiedSessionId,
title: String,
},
/// Dialog archived
DialogArchived {
session_id: UnifiedSessionId,
},
DialogArchived { session_id: UnifiedSessionId },
/// Dialog deleted
DialogDeleted {
session_id: UnifiedSessionId,
},
DialogDeleted { session_id: UnifiedSessionId },
/// Dialog history cleared
HistoryCleared {
session_id: UnifiedSessionId,
},
HistoryCleared { session_id: UnifiedSessionId },
/// List of available slash commands
SlashCommandsList {
commands: Vec<SlashCommand>,
},
SlashCommandsList { commands: Vec<SlashCommand> },
/// Slash command executed successfully
SlashCommandExecuted {
new_session_id: Option<UnifiedSessionId>,
@ -70,8 +60,5 @@ pub enum SessionEvent {
message_count: usize,
},
/// Error occurred
Error {
code: String,
message: String,
},
Error { code: String, message: String },
}

View File

@ -1,11 +1,11 @@
pub mod error;
pub mod commands;
pub mod error;
pub mod events;
pub mod session;
pub mod session_id;
pub use error::SessionError;
pub use commands::SessionCommand;
pub use events::{SessionEvent, DialogInfo};
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS};
pub use error::SessionError;
pub use events::{DialogInfo, SessionEvent};
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
pub use session_id::UnifiedSessionId;

View File

@ -8,7 +8,6 @@
///
/// For simple cases where only one dialog exists per chat:
/// - `dialog_id` defaults to `"default"`
use serde::{Deserialize, Serialize};
pub const DEFAULT_DIALOG_ID: &str = "default";
@ -22,7 +21,11 @@ pub struct UnifiedSessionId {
impl UnifiedSessionId {
/// Create a new UnifiedSessionId
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self {
pub fn new(
channel: impl Into<String>,
chat_id: impl Into<String>,
dialog_id: impl Into<String>,
) -> Self {
Self {
channel: channel.into(),
chat_id: chat_id.into(),

View File

@ -1,6 +1,6 @@
use std::path::Path;
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS};
use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
pub fn install_builtin_skills(target_dir: &Path) {
for skill in EMBEDDED_SKILLS {
@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) {
}
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
let decompressed = zstd::decode_all(skill.data)
.map_err(|e| format!("zstd decode: {}", e))?;
let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
let mut archive = tar::Archive::new(decompressed.as_slice());
archive

View File

@ -120,7 +120,11 @@ impl SkillsLoader {
let count = loaded.len();
let mut replaced = 0usize;
for skill in loaded {
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
if let Some(existing) = state
.loaded_skills
.iter_mut()
.find(|s| s.name == skill.name)
{
*existing = skill;
replaced += 1;
} else {
@ -138,33 +142,42 @@ impl SkillsLoader {
// Load from workspace skills dir (highest priority) — replace same-name skills
if let Some(ref ws_dir) = self.workspace_skills_dir
&& ws_dir.exists() {
let loaded = self.load_skills_from_dir(ws_dir);
let count = loaded.len();
let mut replaced = 0usize;
for skill in loaded {
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
*existing = skill;
replaced += 1;
} else {
state.loaded_skills.push(skill);
}
&& ws_dir.exists()
{
let loaded = self.load_skills_from_dir(ws_dir);
let count = loaded.len();
let mut replaced = 0usize;
for skill in loaded {
if let Some(existing) = state
.loaded_skills
.iter_mut()
.find(|s| s.name == skill.name)
{
*existing = skill;
replaced += 1;
} else {
state.loaded_skills.push(skill);
}
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
}
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
}
state.last_load_time = SystemTime::now();
if state.loaded_skills.is_empty() {
tracing::debug!("No skills found in any skills directory");
} else {
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
tracing::info!(
count = state.loaded_skills.len(),
"Loaded {} skills total",
state.loaded_skills.len()
);
}
}
@ -215,18 +228,20 @@ impl SkillsLoader {
let mut max_mtime = None;
if let Ok(metadata) = std::fs::metadata(dir)
&& let Ok(mtime) = metadata.modified() {
max_mtime = Some(mtime);
}
&& let Ok(mtime) = metadata.modified()
{
max_mtime = Some(mtime);
}
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if let Ok(metadata) = std::fs::metadata(&path)
&& let Ok(mtime) = metadata.modified()
&& max_mtime.is_none_or(|current| mtime > current) {
max_mtime = Some(mtime);
}
&& max_mtime.is_none_or(|current| mtime > current)
{
max_mtime = Some(mtime);
}
}
}
@ -244,7 +259,12 @@ impl SkillsLoader {
pub fn get_always_skills(&self) -> Vec<Skill> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
state
.loaded_skills
.iter()
.filter(|s| s.always)
.cloned()
.collect()
}
/// Get a specific skill by name (checks for changes first)
@ -258,7 +278,8 @@ impl SkillsLoader {
pub fn list_skills(&self) -> Vec<(String, String)> {
self.reload_if_changed();
let state = self.state.lock().unwrap();
state.loaded_skills
state
.loaded_skills
.iter()
.map(|s| (s.name.clone(), s.description.clone()))
.collect()
@ -279,15 +300,21 @@ impl SkillsLoader {
prompt.push_str("### 目录说明\n\n");
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill\n");
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n");
prompt.push_str("安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n");
prompt.push_str(
"- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n",
);
prompt.push_str(
"安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n",
);
// Always skills summary
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
if !always_skills.is_empty() {
prompt.push_str("### 常用技能\n\n");
for skill in &always_skills {
let path_str = skill.path.as_ref()
let path_str = skill
.path
.as_ref()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string());
prompt.push_str(&format!(
@ -300,8 +327,12 @@ impl SkillsLoader {
// Usage instructions
prompt.push_str("### 使用方法\n\n");
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n");
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n");
prompt.push_str(
"- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
);
prompt.push_str(
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
);
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
// Always skills full content
@ -338,25 +369,23 @@ impl SkillsLoader {
}
match std::fs::read_to_string(&skill_file) {
Ok(content) => {
match self.parse_skill(&path, &content) {
Some(skill) => {
tracing::debug!(
skill = %skill.name,
path = %skill_file.display(),
always = skill.always,
"Loaded skill"
);
skills.push(skill);
}
None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
Ok(content) => match self.parse_skill(&path, &content) {
Some(skill) => {
tracing::debug!(
skill = %skill.name,
path = %skill_file.display(),
always = skill.always,
"Loaded skill"
);
skills.push(skill);
}
}
None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
},
Err(e) => {
tracing::warn!(
path = %skill_file.display(),
@ -447,7 +476,6 @@ impl Default for SkillsLoader {
}
}
/// Extract first non-empty, non-heading line as description
fn extract_description(content: &str) -> String {
content

View File

@ -241,12 +241,11 @@ impl super::Storage {
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
let cutoff_str = cutoff.to_rfc3339();
let result = sqlx::query(
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
)
.bind(&cutoff_str)
.execute(self.pool())
.await?;
let result =
sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
.bind(&cutoff_str)
.execute(self.pool())
.await?;
Ok(result.rows_affected())
}
@ -276,9 +275,7 @@ impl super::Storage {
}
}
fn parse_memory_rows(
rows: &[sqlx::sqlite::SqliteRow],
) -> Result<Vec<MemoryEntry>, StorageError> {
fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
rows.iter()
.map(|row| {
Ok(MemoryEntry {

View File

@ -165,7 +165,11 @@ impl crate::storage::Storage {
}
/// Update next_run_at and last_run_at for a job.
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
pub async fn set_scheduled_job_next_run(
&self,
id: &str,
next_run_at: i64,
) -> anyhow::Result<()> {
let now = now_ms();
sqlx::query(
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
@ -331,7 +335,9 @@ mod tests {
async fn setup_storage() -> Storage {
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
let storage = Storage { pool };
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
Storage::init_scheduler_schema(storage.pool())
.await
.unwrap();
storage
}
@ -450,7 +456,10 @@ mod tests {
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
storage
.set_scheduled_job_enabled("job-toggle", false)
.await
.unwrap();
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
assert!(!got.enabled);
}
@ -461,31 +470,55 @@ mod tests {
let t = now();
let jobs = vec![
ScheduledJob {
id: "due".into(), name: "due".into(),
schedule: Schedule::At { at: t }, prompt: "1".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t - 1000, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
id: "due".into(),
name: "due".into(),
schedule: Schedule::At { at: t },
prompt: "1".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
},
ScheduledJob {
id: "future".into(), name: "future".into(),
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t + 99999999, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
id: "future".into(),
name: "future".into(),
schedule: Schedule::At { at: t + 99999999 },
prompt: "2".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 99999999,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
},
ScheduledJob {
id: "disabled-due".into(), name: "disabled due".into(),
schedule: Schedule::At { at: t }, prompt: "3".into(),
channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: false, delete_after_run: false,
next_run_at: t - 1000, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
id: "disabled-due".into(),
name: "disabled due".into(),
schedule: Schedule::At { at: t },
prompt: "3".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: false,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
},
];
for j in &jobs {
@ -501,24 +534,39 @@ mod tests {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-run".into(), name: "run test".into(),
id: "job-run".into(),
name: "run test".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
model: None, enabled: true, delete_after_run: false,
next_run_at: t, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
prompt: "hi".into(),
channel: "cli_chat".into(),
chat_id: "c".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
let run = super::JobRun {
id: 0, job_id: "job-run".into(),
started_at: t, finished_at: t + 500,
status: "ok".into(), output: Some("hello".into()),
error: None, duration_ms: 500,
id: 0,
job_id: "job-run".into(),
started_at: t,
finished_at: t + 500,
status: "ok".into(),
output: Some("hello".into()),
error: None,
duration_ms: 500,
};
storage.record_scheduled_job_run(&run).await.unwrap();
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
let runs = storage
.list_scheduled_job_runs("job-run", 10)
.await
.unwrap();
assert_eq!(runs.len(), 1);
assert_eq!(runs[0].status, "ok");
assert_eq!(runs[0].output.as_deref(), Some("hello"));
@ -529,22 +577,34 @@ mod tests {
let storage = setup_storage().await;
let t = now();
let job = ScheduledJob {
id: "job-update".into(), name: "old name".into(),
id: "job-update".into(),
name: "old name".into(),
schedule: Schedule::Every { every_ms: 1000 },
prompt: "old prompt".into(), channel: "feishu".into(),
chat_id: "oc_1".into(), model: None,
enabled: true, delete_after_run: false,
next_run_at: t, last_run_at: None,
last_status: None, last_error: None,
created_at: t, updated_at: t,
prompt: "old prompt".into(),
channel: "feishu".into(),
chat_id: "oc_1".into(),
model: None,
enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
};
storage.add_scheduled_job(&job).await.unwrap();
storage.update_scheduled_job(
"job-update",
Some("new prompt".into()),
Some(Schedule::Every { every_ms: 60000 }),
None, None, None,
).await.unwrap();
storage
.update_scheduled_job(
"job-update",
Some("new prompt".into()),
Some(Schedule::Every { every_ms: 60000 }),
None,
None,
None,
)
.await
.unwrap();
let got = storage.get_scheduled_job("job-update").await.unwrap();
assert_eq!(got.prompt, "new prompt");
}

View File

@ -167,10 +167,7 @@ impl Tool for BashTool {
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Command timed out after {} seconds",
timeout_secs
)),
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
}),
}
}
@ -249,10 +246,7 @@ mod tests {
#[tokio::test]
async fn test_pwd_command() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": "pwd" }))
.await
.unwrap();
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
assert!(result.success);
}
@ -260,7 +254,10 @@ mod tests {
#[tokio::test]
async fn test_ls_command() {
let tool = BashTool::new();
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
let result = tool
.execute(json!({ "command": "ls -la /tmp" }))
.await
.unwrap();
assert!(result.success);
}

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::Context;
use async_trait::async_trait;
use base64::Engine;
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT};
use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
use fantoccini::key::Key;
use fantoccini::{Client, ClientBuilder, Locator};
use serde::{Deserialize, Serialize};
@ -63,7 +63,9 @@ impl BrowserTool {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BrowserAction {
Open { url: String },
Open {
url: String,
},
Snapshot {
#[serde(default)]
interactive_only: bool,
@ -72,10 +74,20 @@ pub enum BrowserAction {
#[serde(default)]
depth: Option<i64>,
},
Click { selector: String },
Fill { selector: String, value: String },
Type { selector: Option<String>, text: String },
GetText { selector: String },
Click {
selector: String,
},
Fill {
selector: String,
value: String,
},
Type {
selector: Option<String>,
text: String,
},
GetText {
selector: String,
},
GetTitle,
GetUrl,
Screenshot {
@ -84,7 +96,9 @@ pub enum BrowserAction {
#[serde(default)]
return_base64: bool,
},
Focus { selector: String },
Focus {
selector: String,
},
Wait {
#[serde(default)]
selector: Option<String>,
@ -93,9 +107,16 @@ pub enum BrowserAction {
#[serde(default)]
text: Option<String>,
},
Press { key: String },
Hover { selector: String },
ClickAt { x: u32, y: u32 },
Press {
key: String,
},
Hover {
selector: String,
},
ClickAt {
x: u32,
y: u32,
},
Scroll {
direction: String,
#[serde(default)]
@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.get("interactive_only")
.and_then(Value::as_bool)
.unwrap_or(true),
compact: args
.get("compact")
.and_then(Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(|v| v.as_i64()),
compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
depth: args.get("depth").and_then(|v| v.as_i64()),
}),
"click" => {
let selector = args
@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.and_then(|v| v.as_str())
.map(String::from),
ms: args.get("ms").and_then(|v| v.as_u64()),
text: args
.get("text")
.and_then(|v| v.as_str())
.map(String::from),
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
}),
"press" => {
let key = args
@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
let x = args
.get("x")
.and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32;
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
as u32;
let y = args
.get("y")
.and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32;
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
as u32;
Ok(BrowserAction::ClickAt { x, y })
}
other => anyhow::bail!("Unsupported browser action: {}", other),
@ -488,7 +503,11 @@ impl BrowserState {
}
Err(e) => return Err(e.into()),
}
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed");
tracing::debug!(
action = "fill",
output_len = value.len(),
"Browser action completed"
);
Ok(ToolResult {
success: true,
output: format!("Filled {} with {}", selector, value),
@ -573,7 +592,10 @@ impl BrowserState {
error: None,
})
}
BrowserAction::Screenshot { path, return_base64 } => {
BrowserAction::Screenshot {
path,
return_base64,
} => {
let client = self.active_client()?;
let png = client.screenshot().await?;
let save_path = path.unwrap_or_else(|| {
@ -588,14 +610,25 @@ impl BrowserState {
tokio::fs::write(&save_path, &png).await?;
if return_base64 {
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed");
tracing::debug!(
action = "screenshot",
output_len = b64.len(),
"Browser action completed"
);
return Ok(ToolResult {
success: true,
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64),
output: format!(
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
save_path, b64
),
error: None,
});
}
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed");
tracing::debug!(
action = "screenshot",
output_len = save_path.len(),
"Browser action completed"
);
Ok(ToolResult {
success: true,
output: format!("Screenshot saved to {}", save_path),
@ -611,18 +644,18 @@ impl BrowserState {
vec![serde_json::to_value(el)?],
)
.await?;
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed");
tracing::debug!(
action = "focus",
output_len = selector.len(),
"Browser action completed"
);
Ok(ToolResult {
success: true,
output: format!("Focused {}", selector),
error: None,
})
}
BrowserAction::Wait {
selector,
ms,
text,
} => {
BrowserAction::Wait { selector, ms, text } => {
if let Some(sel) = selector {
let client = self.active_client()?;
wait_for_selector(client, &sel).await?;
@ -719,9 +752,21 @@ impl BrowserState {
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") };
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") };
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") };
let id_str = if id.is_empty() {
String::new()
} else {
format!("#{id}")
};
let type_str = if el_type.is_empty() {
String::new()
} else {
format!("[type={el_type}]")
};
let text_str = if text.is_empty() {
String::new()
} else {
format!(" ({text})")
};
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
}
None => format!("Clicked at ({}, {})", x, y),
@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String {
}
fn xpath_contains_text(text: &str) -> String {
format!(
"//*[contains(normalize-space(.), {})]",
xpath_literal(text)
)
format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
}
fn xpath_literal(input: &str) -> String {
@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String {
"pagedown" => Key::PageDown.to_string(),
"space" => " ".to_string(),
other => {
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other);
tracing::warn!(
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
other
);
other.to_string()
}
}

View File

@ -659,10 +659,7 @@ mod tests {
#[tokio::test]
async fn test_evaluate_missing_expression() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate"}))
.await
.unwrap();
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
assert!(!result.success);
}

View File

@ -31,10 +31,7 @@ impl ContentSearchTool {
for (i, line) in lines.iter().enumerate() {
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
let omitted = lines.len() - i;
output.push_str(&format!(
"\n... ({} matches omitted) ...",
omitted
));
output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
break;
}
if !output.is_empty() {
@ -113,18 +110,40 @@ impl Tool for ContentSearchTool {
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
let case_sensitive = args
.get("case_sensitive")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let context_lines = args
.get("context_lines")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
let result = self
.run_search(
pattern,
&dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await;
match result {
Ok(lines) => {
let count = lines.len();
let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 条匹配", count));
Ok(ToolResult { success: true, output, error: None })
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
@ -146,22 +165,52 @@ impl ContentSearchTool {
max_results: usize,
) -> anyhow::Result<Vec<String>> {
if which::which("rg").is_ok() {
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
match self
.search_with_rg(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
Ok(lines) => return Ok(lines),
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
}
}
if which::which("grep").is_ok() {
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
match self
.search_with_grep(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {},
Ok(_) => {}
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
}
}
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
tracing::warn!(
"No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
);
self.search_with_rust(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
}
async fn search_with_rg(
@ -176,8 +225,10 @@ impl ContentSearchTool {
let mut cmd = Command::new("rg");
cmd.arg("-n")
.arg("--no-heading")
.arg("--color").arg("never")
.arg("--max-count").arg(max_results.to_string())
.arg("--color")
.arg("never")
.arg("--max-count")
.arg(max_results.to_string())
.arg(pattern)
.arg(dir)
.stdout(Stdio::piped())
@ -193,12 +244,9 @@ impl ContentSearchTool {
cmd.arg("--glob").arg(fp);
}
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() && output.status.code() != Some(1) {
let stderr = String::from_utf8_lossy(&output.stderr);
@ -206,7 +254,8 @@ impl ContentSearchTool {
}
let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines()
let lines: Vec<String> = text
.lines()
.take(max_results)
.map(|l| l.to_string())
.collect();
@ -242,15 +291,13 @@ impl ContentSearchTool {
cmd.arg("--include").arg(fp);
}
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines()
let lines: Vec<String> = text
.lines()
.take(max_results)
.map(|l| l.to_string())
.collect();
@ -280,7 +327,9 @@ impl ContentSearchTool {
if case_sensitive {
regex::Regex::new(&re_str)
} else {
regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
regex::RegexBuilder::new(&re_str)
.case_insensitive(true)
.build()
}
});
@ -291,7 +340,14 @@ impl ContentSearchTool {
};
let mut results = Vec::new();
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
grep_dir(
Path::new(dir),
Path::new(dir),
&re,
file_re.as_ref(),
&mut results,
max_results,
)?;
Ok(results)
}
@ -350,16 +406,19 @@ fn grep_dir(
if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.') && name.len() > 1 {
continue;
}
&& name.starts_with('.')
&& name.len() > 1
{
continue;
}
grep_dir(base, &path, re, file_re, results, max)?;
} else if path.is_file() {
if let Some(file_re) = file_re
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& !file_re.is_match(name) {
continue;
}
&& !file_re.is_match(name)
{
continue;
}
if let Ok(content) = std::fs::read_to_string(&path) {
for (line_num, line) in content.lines().enumerate() {
@ -391,8 +450,16 @@ mod tests {
#[tokio::test]
async fn test_content_search_rust_fallback() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
fs::write(
dir.path().join("main.rs"),
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
)
.unwrap();
fs::write(
dir.path().join("lib.rs"),
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
)
.unwrap();
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
let tool = ContentSearchTool::new();

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{json, Value};
use serde_json::{Value, json};
use uuid::Uuid;
use crate::scheduler::{next_run_for_schedule, Schedule};
use crate::scheduler::{Schedule, next_run_for_schedule};
use crate::storage::{ScheduledJob, Storage};
use crate::tools::traits::{Tool, ToolResult};
@ -229,10 +229,7 @@ impl Tool for CronListTool {
}
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let filter = args
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("all");
let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
let jobs = self.storage.list_scheduled_jobs().await?;
let filtered: Vec<&ScheduledJob> = match filter {
@ -397,7 +394,9 @@ impl Tool for CronEnableTool {
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
let next = next_run_for_schedule(&job.schedule, now_ms());
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
self.storage
.set_scheduled_job_enabled(&job_id, true)
.await?;
if let Some(n) = next {
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
}
@ -464,7 +463,9 @@ impl Tool for CronDisableTool {
.get_scheduled_job(&job_id)
.await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
self.storage
.set_scheduled_job_enabled(&job_id, false)
.await?;
Ok(ToolResult {
success: true,
@ -580,7 +581,9 @@ impl Tool for CronUpdateTool {
if args.get("schedule").is_some() {
let job = self.storage.get_scheduled_job(&job_id).await?;
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
self.storage
.set_scheduled_job_next_run(&job_id, next)
.await?;
}
}
@ -765,9 +768,7 @@ mod tests {
let job = ScheduledJob {
id: "job-update-tool".into(),
name: "old".into(),
schedule: Schedule::Every {
every_ms: 3600000,
},
schedule: Schedule::Every { every_ms: 3600000 },
prompt: "old prompt".into(),
channel: "feishu".into(),
chat_id: "oc_1".into(),

View File

@ -102,7 +102,10 @@ impl Tool for DelegateTool {
_ => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)),
error: Some(format!(
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
action
)),
}),
}
}
@ -115,9 +118,11 @@ impl DelegateTool {
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
.to_string();
let allowed_tools: Option<Vec<String>> = args["allowed_tools"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
let timeout_secs = args["timeout_secs"].as_u64();
@ -141,15 +146,21 @@ impl DelegateTool {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)),
})
error: Some(format!(
"unknown mode: {}. Supported: inline, background, parallel",
mode_str
)),
});
}
};
match mode {
ExecutionMode::Inline => {
let config = self.parse_config_from_args(args)?;
let result = self.sub_agent_manager.run_inline(config).await
let result = self
.sub_agent_manager
.run_inline(config)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
match result.status {
@ -177,10 +188,14 @@ impl DelegateTool {
}
ExecutionMode::Background => {
let config = self.parse_config_from_args(args)?;
let ctx = crate::agent::sub_agent::get_delegate_context()
.map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?;
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
anyhow::anyhow!("delegate context not available: not in an agent worker")
})?;
let task_id = self.sub_agent_manager.run_background(config, ctx).await
let task_id = self
.sub_agent_manager
.run_background(config, ctx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
Ok(ToolResult {
@ -200,9 +215,12 @@ impl DelegateTool {
.as_str()
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
.to_string();
let allowed_tools: Option<Vec<String>> = task["allowed_tools"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
let allowed_tools: Option<Vec<String>> =
task["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
configs.push(SubAgentConfig {
prompt,
@ -216,13 +234,18 @@ impl DelegateTool {
let has_args_allowed = args["allowed_tools"].as_array().is_some();
for c in &mut configs {
if c.allowed_tools.is_none() && has_args_allowed {
c.allowed_tools = args["allowed_tools"]
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
}
}
let results = self.sub_agent_manager.run_parallel(configs).await
let results = self
.sub_agent_manager
.run_parallel(configs)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
let mut output = String::new();
@ -243,7 +266,9 @@ impl DelegateTool {
}
}
let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed));
let all_success = results
.iter()
.all(|r| matches!(r.status, TaskStatus::Completed));
Ok(ToolResult {
success: all_success,
output: output.trim().to_string(),

View File

@ -243,8 +243,8 @@ impl Tool for FileEditTool {
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test]
async fn test_edit_simple() {

View File

@ -181,10 +181,7 @@ impl Tool for FileReadTool {
}
result = lines[..end_idx].join("\n");
let truncated = original_len - result.len();
result.push_str(&format!(
"\n\n... ({} chars truncated) ...",
truncated
));
result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
}
if end < total {
@ -196,10 +193,7 @@ impl Tool for FileReadTool {
end + 1
));
} else {
result.push_str(&format!(
"\n\n(End of file — {} lines total)",
total
));
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
}
if let Some(label) = encoding_label {
@ -214,7 +208,7 @@ impl Tool for FileReadTool {
}
None => {
// Truly binary file — base64 encode
use base64::{engine::general_purpose::STANDARD, Engine};
use base64::{Engine, engine::general_purpose::STANDARD};
let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test]
async fn test_read_simple_file() {
@ -338,10 +332,7 @@ mod tests {
#[tokio::test]
async fn test_is_directory() {
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": "." }))
.await
.unwrap();
let result = tool.execute(json!({ "path": "." })).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file"));

View File

@ -101,17 +101,29 @@ impl Tool for FileSearchTool {
};
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
let case_sensitive = args
.get("case_sensitive")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
let result = self
.run_search(pattern, &dir, case_sensitive, max_results)
.await;
match result {
Ok(lines) => {
let count = lines.len();
let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 个文件", count));
Ok(ToolResult { success: true, output, error: None })
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
@ -139,9 +151,12 @@ impl FileSearchTool {
};
if !fd_cmd.is_empty() {
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await {
match self
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
.await
{
Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {},
Ok(_) => {}
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
}
}
@ -149,13 +164,14 @@ impl FileSearchTool {
if which::which("find").is_ok() {
match self.search_with_find(pattern, dir, max_results).await {
Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {},
Ok(_) => {}
Err(e) => tracing::warn!("find failed: {}, falling back", e),
}
}
tracing::warn!("No fd/find available, using built-in file search (slower)");
self.search_with_rust(pattern, dir, case_sensitive, max_results).await
self.search_with_rust(pattern, dir, case_sensitive, max_results)
.await
}
async fn search_with_fd(
@ -167,11 +183,15 @@ impl FileSearchTool {
fd_cmd: &str,
) -> anyhow::Result<Vec<String>> {
let mut cmd = Command::new(fd_cmd);
cmd.arg("--search-path").arg(dir)
.arg("--glob").arg(pattern)
.arg("--color").arg("never")
cmd.arg("--search-path")
.arg(dir)
.arg("--glob")
.arg(pattern)
.arg("--color")
.arg("never")
.arg("--strip-cwd-prefix")
.arg("--max-results").arg(max_results.to_string())
.arg("--max-results")
.arg(max_results.to_string())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
@ -179,12 +199,9 @@ impl FileSearchTool {
cmd.arg("--ignore-case");
}
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
@ -192,7 +209,8 @@ impl FileSearchTool {
}
let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines()
let lines: Vec<String> = text
.lines()
.filter(|l| !l.is_empty())
.map(|l| l.to_string())
.collect();
@ -215,15 +233,13 @@ impl FileSearchTool {
.stdout(Stdio::piped())
.stderr(Stdio::null());
let output = timeout(
std::time::Duration::from_secs(TIMEOUT_SECS),
cmd.output(),
)
.await
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
.await
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout);
let mut lines: Vec<String> = text.lines()
let mut lines: Vec<String> = text
.lines()
.filter(|l| !l.is_empty())
.map(|l| {
let p = Path::new(l);
@ -254,7 +270,13 @@ impl FileSearchTool {
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
let mut results = Vec::new();
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?;
walk_dir(
Path::new(dir),
Path::new(dir),
&re,
&mut results,
max_results,
)?;
Ok(results)
}
}
@ -311,15 +333,18 @@ fn walk_dir(
if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.') && name.len() > 1 {
continue;
}
&& name.starts_with('.')
&& name.len() > 1
{
continue;
}
walk_dir(base, &path, re, results, max)?;
} else if path.is_file() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& re.is_match(name) {
results.push(rel.to_string_lossy().to_string());
}
&& re.is_match(name)
{
results.push(rel.to_string_lossy().to_string());
}
if results.len() >= max {
return Ok(());
}

View File

@ -90,13 +90,14 @@ impl Tool for FileWriteTool {
// Create parent directories if needed
if let Some(parent) = resolved.parent()
&& !parent.exists()
&& let Err(e) = std::fs::create_dir_all(parent) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create parent directory: {}", e)),
});
}
&& let Err(e) = std::fs::create_dir_all(parent)
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to create parent directory: {}", e)),
});
}
match std::fs::write(&resolved, content) {
Ok(_) => Ok(ToolResult {
@ -168,10 +169,7 @@ mod tests {
#[tokio::test]
async fn test_write_missing_path() {
let tool = FileWriteTool::new();
let result = tool
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("path"));

View File

@ -129,7 +129,9 @@ impl GetSkillTool {
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
for s in &skills {
let always_mark = if s.always { " [常驻]" } else { "" };
let path_str = s.path.as_ref()
let path_str = s
.path
.as_ref()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string());
output.push_str(&format!(
@ -148,10 +150,10 @@ impl GetSkillTool {
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use tempfile::tempdir;
#[tokio::test]
async fn test_get_existing_skill() {

View File

@ -50,10 +50,7 @@ impl HttpRequestTool {
}
if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!(
"Host '{}' is not in allowed_domains",
host
));
return Err(format!("Host '{}' is not in allowed_domains", host));
}
Ok(url.to_string())
@ -80,11 +77,10 @@ impl HttpRequestTool {
for (key, value) in obj {
if let Some(str_val) = value.as_str()
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
&& let Ok(val) =
reqwest::header::HeaderValue::from_str(str_val)
{
header_map.insert(name, val);
}
&& let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
{
header_map.insert(name, val);
}
}
}
@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
allowed_domains.iter().any(|domain| {
host == domain
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|| host
.strip_suffix(domain)
.is_some_and(|prefix| prefix.ends_with('.'))
})
}
@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool {
}
// Check .local TLD
if host.rsplit('.').next().is_some_and(|label| label == "local") {
if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true;
}
@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
}
}
@ -278,10 +278,7 @@ impl Tool for HttpRequestTool {
}
};
let method_str = args
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str());

View File

@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool {
.and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
.recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Knowledge),
None,
)
.await?
} else {
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
self.memory
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
.await?
};
if entries.is_empty() {
@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool {
let formatted = entries
.iter()
.map(|e| {
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
format!(
"- {} [{}]{} [importance: {:.1}]: {}",
e.key,
@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool {
.and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
.recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Timeline),
session_id,
)
.await?
} else {
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
self.memory
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
.await?
};
if entries.is_empty() {
@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool {
let formatted = entries
.iter()
.map(|e| {
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
format!(
"- {} [{}]{} [importance: {:.1}]: {}",
e.key,

View File

@ -37,11 +37,11 @@ pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult};
pub use web_fetch::WebFetchTool;
use std::sync::Arc;
use crate::agent::SubAgentManager;
use crate::config::BrowserConfig;
use crate::memory::MemoryManager;
use crate::skills::SkillsLoader;
use std::sync::Arc;
/// Create the base tool registry (without send_message).
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`

View File

@ -17,7 +17,10 @@ impl ToolRegistry {
}
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
self.tools
.lock()
.unwrap()
.insert(tool.name().to_string(), Arc::new(tool));
}
/// Register an existing Arc-wrapped tool by name

View File

@ -115,9 +115,11 @@ impl SchemaCleanr {
}
if let Some(Value::String(t)) = obj.get("type")
&& t == "object" && !obj.contains_key("properties") {
tracing::warn!("Object schema without 'properties' field may cause issues");
}
&& t == "object"
&& !obj.contains_key("properties")
{
tracing::warn!("Object schema without 'properties' field may cause issues");
}
Ok(())
}
@ -173,9 +175,10 @@ impl SchemaCleanr {
// Handle anyOf/oneOf simplification
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
return simplified;
}
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
{
return simplified;
}
// Build cleaned object
let mut cleaned = Map::new();
@ -243,12 +246,13 @@ impl SchemaCleanr {
}
if let Some(def_name) = Self::parse_local_ref(ref_value)
&& let Some(definition) = defs.get(def_name.as_str()) {
ref_stack.insert(ref_value.to_string());
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
ref_stack.remove(ref_value);
return Self::preserve_meta(obj, cleaned);
}
&& let Some(definition) = defs.get(def_name.as_str())
{
ref_stack.insert(ref_value.to_string());
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
ref_stack.remove(ref_value);
return Self::preserve_meta(obj, cleaned);
}
tracing::warn!("Cannot resolve $ref: {}", ref_value);
Self::preserve_meta(obj, Value::Object(Map::new()))
@ -340,13 +344,16 @@ impl SchemaCleanr {
return true;
}
if let Some(Value::Array(arr)) = obj.get("enum")
&& arr.len() == 1 && matches!(arr[0], Value::Null) {
return true;
}
&& arr.len() == 1
&& matches!(arr[0], Value::Null)
{
return true;
}
if let Some(Value::String(t)) = obj.get("type")
&& t == "null" {
return true;
}
&& t == "null"
{
return true;
}
}
false
}
@ -403,7 +410,10 @@ impl SchemaCleanr {
match non_null.len() {
0 => Value::String("null".to_string()),
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
1 => non_null
.into_iter()
.next()
.unwrap_or(Value::String("null".to_string())),
_ => Value::Array(non_null),
}
} else {

View File

@ -1,5 +1,5 @@
use std::sync::Arc;
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use mime_guess::mime;
@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
match parts.len() {
2 => {
if parts[0].is_empty() || parts[1].is_empty() {
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
Err(format!(
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
raw
))
} else {
Ok((parts[0], parts[1], None))
}
}
3 => {
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
Err(format!(
"Invalid target_chat_id format '{}': all three parts must not be empty",
raw
))
} else {
Ok((parts[0], parts[1], Some(parts[2])))
}
@ -98,8 +104,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
// 1. Parse target_chat_id
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
.map_err(|e| anyhow::anyhow!(e))?;
let (channel, chat_id, dialog_id) =
parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
// 2. Validate channel
if !self.available_channels.contains(channel) {
@ -109,7 +115,11 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
error: Some(format!(
"Channel '{}' is not available. Available channels: {}",
channel,
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
self.available_channels
.iter()
.cloned()
.collect::<Vec<_>>()
.join(", ")
)),
});
}
@ -129,7 +139,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
let media = parse_files_arg(&args);
// 4. Send via messenger
match self.messenger
match self
.messenger
.send_message(channel, chat_id, dialog_id, content, source, media)
.await
{

View File

@ -1,5 +1,5 @@
use async_trait::async_trait;
use crate::bus::{MediaItem, MessageSource};
use async_trait::async_trait;
#[derive(Debug, Clone)]
pub struct ToolResult {

View File

@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
return true;
}
if host.rsplit('.').next().is_some_and(|label| label == "local") {
if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true;
}
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
};
}

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
use std::collections::HashMap;
fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
#[tokio::test]
#[ignore]
async fn test_openai_simple_completion() {
let config = load_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
@ -57,8 +56,7 @@ async fn test_openai_simple_completion() {
#[tokio::test]
#[ignore]
async fn test_openai_conversation() {
let config = load_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -82,7 +80,9 @@ async fn test_openai_conversation() {
async fn test_config_load() {
// Test that config.json can be loaded and provider config created
let config = Config::load("config.json").expect("Failed to load config.json");
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
let provider_config = config
.get_provider_config("default")
.expect("Failed to get provider config");
assert_eq!(provider_config.provider_type, "openai");
assert_eq!(provider_config.name, "aliyun");

View File

@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
/// Verify that next_run_for_schedule produces valid future timestamps.
#[test]
fn test_next_run_always_future() {
use picobot::scheduler::{next_run_for_schedule, Schedule};
use picobot::scheduler::{Schedule, next_run_for_schedule};
let now = 1700000000000_i64;
@ -56,6 +56,10 @@ fn test_next_run_always_future() {
for s in &schedules {
let next = next_run_for_schedule(s, now);
assert!(next.is_some(), "expected next run for {:?}", s);
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
assert!(
next.unwrap() > now,
"next run should be after now for {:?}",
s
);
}
}

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use picobot::config::LLMProviderConfig;
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
use std::collections::HashMap;
fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool {
#[tokio::test]
#[ignore]
async fn test_openai_tool_call() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -68,7 +67,11 @@ async fn test_openai_tool_call() {
let response = provider.chat(request).await.unwrap();
// Should have tool calls
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
assert!(
!response.tool_calls.is_empty(),
"Expected tool call, got: {}",
response.content
);
let tool_call = &response.tool_calls[0];
assert_eq!(tool_call.name, "get_weather");
@ -78,8 +81,7 @@ async fn test_openai_tool_call() {
#[tokio::test]
#[ignore]
async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() {
};
let response1 = provider.chat(request1).await.unwrap();
let tool_call = response1.tool_calls.first()
.expect("Expected tool call");
let tool_call = response1.tool_calls.first().expect("Expected tool call");
assert_eq!(tool_call.name, "get_weather");
// Second request with tool result
@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() {
#[tokio::test]
#[ignore]
async fn test_openai_no_tool_when_not_provided() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");