Format codebase with rustfmt

This commit is contained in:
xiaoski 2026-06-15 23:47:24 +08:00
parent c6f4392e63
commit 8f4ee79d8d
65 changed files with 1807 additions and 1061 deletions

View File

@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use crate::bus::{ChatMessage, MediaRef}; use crate::bus::{ChatMessage, MediaRef};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::observability::{ use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
};
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
@ -256,7 +254,10 @@ impl AgentLoop {
} }
/// Create a new AgentLoop with provider created from config and given tools. /// Create a new AgentLoop with provider created from config and given tools.
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> { pub fn with_tools(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
let model_name = provider_config.model_id.clone(); let model_name = provider_config.model_id.clone();
let workspace_dir = provider_config.workspace_dir.clone(); let workspace_dir = provider_config.workspace_dir.clone();
@ -279,7 +280,13 @@ impl AgentLoop {
} }
/// Create a new AgentLoop with an existing shared provider. /// Create a new AgentLoop with an existing shared provider.
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self { pub fn with_provider(
provider: Arc<dyn LLMProvider>,
max_iterations: usize,
model_name: String,
workspace_dir: PathBuf,
input_types: Vec<String>,
) -> Self {
Self { Self {
provider, provider,
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
@ -379,7 +386,12 @@ impl AgentLoop {
let content = if m.media_refs.is_empty() { let content = if m.media_refs.is_empty() {
vec![ContentBlock::text(&m.content)] vec![ContentBlock::text(&m.content)]
} else { } else {
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry) build_content_blocks(
&m.content,
&m.media_refs,
&self.input_types,
&self.media_registry,
)
}; };
Message { Message {
@ -399,14 +411,28 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either: /// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response) /// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached /// - Maximum iterations are reached
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> { pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Build and inject system prompt if not present // Build and inject system prompt if not present
let has_system = messages.first().is_some_and(|m| m.role == "system"); let has_system = messages.first().is_some_and(|m| m.role == "system");
if !has_system { if !has_system {
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false); let system_prompt = build_system_prompt(
&self.workspace_dir,
&self.model_name,
&self.tools,
None,
None,
false,
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!("System prompt injected:\n{}", system_prompt); tracing::debug!("System prompt injected:\n{}", system_prompt);
messages.insert(0, ChatMessage::system(system_prompt)); messages.insert(0, ChatMessage::system(system_prompt));
@ -427,9 +453,7 @@ impl AgentLoop {
let estimated = estimate_tokens(&messages); let estimated = estimate_tokens(&messages);
let danger = (self.context_window as f64 * 0.8) as usize; let danger = (self.context_window as f64 * 0.8) as usize;
if estimated > danger { if estimated > danger {
let trimmed = self.preemptive_trim_old_tool_results( let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
&mut messages, 2000, 4,
);
if trimmed > 0 { if trimmed > 0 {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(
@ -463,11 +487,10 @@ impl AgentLoop {
}; };
// Call LLM // Call LLM
let response = (*self.provider).chat(request).await let response = (*self.provider).chat(request).await.map_err(|e| {
.map_err(|e| { tracing::error!(error = %e, "LLM request failed");
tracing::error!(error = %e, "LLM request failed"); AgentError::LlmError(e.to_string())
AgentError::LlmError(e.to_string()) })?;
})?;
accumulated_tokens += response.usage.total_tokens; accumulated_tokens += response.usage.total_tokens;
@ -493,7 +516,9 @@ impl AgentLoop {
// Execute tool calls — log and notify immediately // Execute tool calls — log and notify immediately
{ {
let tools_info: Vec<String> = response.tool_calls.iter() let tools_info: Vec<String> = response
.tool_calls
.iter()
.map(|tc| { .map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
let s = format!("{}:{}", tc.name, args); let s = format!("{}:{}", tc.name, args);
@ -522,7 +547,9 @@ impl AgentLoop {
// Log function call with name and arguments // Log function call with name and arguments
let args_str = match &tool_call.arguments { let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(), serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
}; };
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool"); tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
@ -562,7 +589,11 @@ impl AgentLoop {
// Loop continues to next iteration with updated messages // Loop continues to next iteration with updated messages
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
} }
// Max iterations reached - ask LLM for a summary based on completed work // Max iterations reached - ask LLM for a summary based on completed work
@ -571,7 +602,7 @@ impl AgentLoop {
// Add a message asking for summary // Add a message asking for summary
let summary_request = ChatMessage::user( let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \ "You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far." Please provide your best answer based on the work completed so far.",
); );
messages.push(summary_request); messages.push(summary_request);
@ -603,14 +634,19 @@ impl AgentLoop {
Err(e) => { Err(e) => {
// Fallback if summary call fails // Fallback if summary call fails
tracing::error!(error = %e, "Failed to get summary from LLM"); tracing::error!(error = %e, "Failed to get summary from LLM");
let final_message = ChatMessage::assistant( let final_message = ChatMessage::assistant(format!(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) "I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
); self.max_iterations
));
emitted_messages.push(final_message.clone()); emitted_messages.push(final_message.clone());
Ok(AgentProcessResult { Ok(AgentProcessResult {
final_response: final_message, final_response: final_message,
emitted_messages, emitted_messages,
total_tokens: if accumulated_tokens > 0 { Some(accumulated_tokens) } else { None }, total_tokens: if accumulated_tokens > 0 {
Some(accumulated_tokens)
} else {
None
},
}) })
} }
} }
@ -698,10 +734,7 @@ impl AgentLoop {
} }
// Apply duration // Apply duration
ToolExecutionOutcome { ToolExecutionOutcome { duration, ..result }
duration,
..result
}
} }
/// Internal tool execution without event tracking. /// Internal tool execution without event tracking.
@ -723,18 +756,12 @@ impl AgentLoop {
ToolExecutionOutcome::success(result.output) ToolExecutionOutcome::success(result.output)
} else { } else {
let error = result.error.unwrap_or_default(); let error = result.error.unwrap_or_default();
ToolExecutionOutcome::failure( ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
format!("Error: {}", error),
Some(error),
)
} }
} }
Err(e) => { Err(e) => {
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed"); tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
ToolExecutionOutcome::failure( ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
format!("Error: {}", e),
Some(e.to_string()),
)
} }
} }
} }
@ -822,8 +849,14 @@ mod tests {
assert_eq!(provider_message.role, "assistant"); assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); assert_eq!(
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); provider_message.tool_calls.as_ref().unwrap()[0].id,
"call_1"
);
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].name,
"calculator"
);
} }
} }

View File

@ -234,27 +234,33 @@ impl ContextCompressor {
} }
} else if messages[i].role == "tool" } else if messages[i].role == "tool"
&& let Some(ref tid) = messages[i].tool_call_id && let Some(ref tid) = messages[i].tool_call_id
&& !declared.contains(tid.as_str()) { && !declared.contains(tid.as_str())
messages.remove(i); {
continue; messages.remove(i);
} continue;
}
i += 1; i += 1;
} }
let broken: Vec<usize> = messages.iter().enumerate() let broken: Vec<usize> = messages
.iter()
.enumerate()
.filter_map(|(idx, msg)| { .filter_map(|(idx, msg)| {
if msg.role == "assistant" if msg.role == "assistant"
&& let Some(ref tcs) = msg.tool_calls && let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty() { && !tcs.is_empty()
let all_present = tcs.iter().all(|tc| { {
messages.iter().any(|m| { let all_present = tcs.iter().all(|tc| {
m.role == "tool" messages.iter().any(|m| {
&& m.tool_call_id.as_deref() == Some(tc.id.as_str()) m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
}) })
}); });
if !all_present { Some(idx) } else { None } if !all_present { Some(idx) } else { None }
} else { None } } else {
}).collect(); None
}
})
.collect();
for idx in broken { for idx in broken {
let msg = &mut messages[idx]; let msg = &mut messages[idx];
@ -262,7 +268,8 @@ impl ContextCompressor {
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
msg.content = format!( msg.content = format!(
"{}\n\n[Tool calls ({}) — results are no longer available]", "{}\n\n[Tool calls ({}) — results are no longer available]",
msg.content, names.join(", ") msg.content,
names.join(", ")
); );
} }
} }
@ -275,7 +282,10 @@ impl ContextCompressor {
// Check if compression is needed // Check if compression is needed
let tokens = self.token_estimate_with_history(&history); let tokens = self.token_estimate_with_history(&history);
if tokens <= self.threshold() { if tokens <= self.threshold() {
return Ok(CompressionResult { history, created_timelines: false }); return Ok(CompressionResult {
history,
created_timelines: false,
});
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -299,7 +309,10 @@ impl ContextCompressor {
} }
if tokens_after <= self.threshold() { if tokens_after <= self.threshold() {
self.invalidate_token_cache(); self.invalidate_token_cache();
return Ok(CompressionResult { history, created_timelines: false }); return Ok(CompressionResult {
history,
created_timelines: false,
});
} }
// LLM summarization pass // LLM summarization pass
@ -312,11 +325,7 @@ impl ContextCompressor {
} }
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!( tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
pass = pass + 1,
tokens = tokens,
"Compression pass"
);
match self.compress_once(&current_history).await { match self.compress_once(&current_history).await {
Ok(Some(compressed)) => { Ok(Some(compressed)) => {
@ -352,18 +361,24 @@ impl ContextCompressor {
let m = &current_history[scan]; let m = &current_history[scan];
if m.role == "assistant" { if m.role == "assistant" {
if let Some(tcs) = &m.tool_calls if let Some(tcs) = &m.tool_calls
&& !tcs.is_empty() { && !tcs.is_empty()
let has_post = current_history[scan + 1..] {
.iter() let has_post = current_history[scan + 1..]
.filter(|r| r.role == "tool") .iter()
.any(|r| tcs.iter().any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))); .filter(|r| r.role == "tool")
if has_post { .any(|r| {
tail_start = scan; tcs.iter()
} .any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
});
if has_post {
tail_start = scan;
} }
}
break;
}
if scan == 0 {
break; break;
} }
if scan == 0 { break; }
scan -= 1; scan -= 1;
} }
} }
@ -390,14 +405,16 @@ impl ContextCompressor {
for msg in &mut truncated[..self.config.protect_first_n] { for msg in &mut truncated[..self.config.protect_first_n] {
if msg.role == "assistant" { if msg.role == "assistant" {
if let Some(ref tcs) = msg.tool_calls if let Some(ref tcs) = msg.tool_calls
&& !tcs.is_empty() { && !tcs.is_empty()
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect(); {
msg.content = format!( let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
"{}\n\n[Tool calls ({}) — results dropped during truncation]", msg.content = format!(
msg.content, names.join(", ") "{}\n\n[Tool calls ({}) — results dropped during truncation]",
); msg.content,
msg.tool_calls = None; names.join(", ")
} );
msg.tool_calls = None;
}
} }
} }
@ -424,7 +441,10 @@ impl ContextCompressor {
"Context compression completed" "Context compression completed"
); );
Ok(CompressionResult { history: current_history, created_timelines }) Ok(CompressionResult {
history: current_history,
created_timelines,
})
} }
/// Try to extract the actual context token limit from an LLM error message. /// Try to extract the actual context token limit from an LLM error message.
@ -447,20 +467,21 @@ impl ContextCompressor {
// Look for a number in the vicinity (up to 10 chars after marker) // Look for a number in the vicinity (up to 10 chars after marker)
if let Some(num_str) = find_number_nearby(after, 50) if let Some(num_str) = find_number_nearby(after, 50)
&& let Ok(n) = num_str.parse::<usize>() && let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n) { && (1024..=10_000_000).contains(&n)
return Some(n); {
} return Some(n);
}
} }
} }
// Also try: "XXXX token context" or "XXXX limit" // Also try: "XXXX token context" or "XXXX limit"
if let Some(num_str) = find_number_nearby(&lower, lower.len()) if let Some(num_str) = find_number_nearby(&lower, lower.len())
&& let Ok(n) = num_str.parse::<usize>() && let Ok(n) = num_str.parse::<usize>()
&& (1024..=10_000_000).contains(&n) && (1024..=10_000_000).contains(&n)
&& (lower.contains("token") || lower.contains("context") || lower.contains("limit")) && (lower.contains("token") || lower.contains("context") || lower.contains("limit"))
{ {
return Some(n); return Some(n);
} }
None None
} }
@ -509,19 +530,26 @@ impl ContextCompressor {
// Persist compressed summary as timeline memory entry // Persist compressed summary as timeline memory entry
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string(); let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}", let timeline_content = format!(
ts, between.len(), summary); "[{}] Compressed {} conversation segments:\n{}",
ts,
between.len(),
summary
);
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4()); let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
let mm = self.memory.clone(); let mm = self.memory.clone();
let sid = self.session_id.clone(); let sid = self.session_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = mm.store( if let Err(e) = mm
&key, .store(
&timeline_content, &key,
crate::memory::MemoryCategory::Timeline, &timeline_content,
sid.as_deref(), crate::memory::MemoryCategory::Timeline,
Some(0.3), sid.as_deref(),
).await { Some(0.3),
)
.await
{
tracing::warn!(error = %e, "Failed to store compressed context as timeline"); tracing::warn!(error = %e, "Failed to store compressed context as timeline");
} }
}); });
@ -552,10 +580,7 @@ impl ContextCompressor {
} }
/// Summarize a segment of messages using LLM. /// Summarize a segment of messages using LLM.
async fn summarize_segment( async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
&self,
messages: &[ChatMessage],
) -> Result<String, AgentError> {
if messages.is_empty() { if messages.is_empty() {
return Ok(String::new()); return Ok(String::new());
} }
@ -569,7 +594,8 @@ impl ContextCompressor {
"tool" => "Tool", "tool" => "Tool",
_ => m.role.as_str(), _ => m.role.as_str(),
}; };
let name = m.tool_name let name = m
.tool_name
.as_ref() .as_ref()
.map(|n| format!(" ({})", n)) .map(|n| format!(" ({})", n))
.unwrap_or_default(); .unwrap_or_default();
@ -614,7 +640,10 @@ Be concise, aim for {} characters or less.
); );
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)], messages: vec![
Message::system("You are a helpful assistant."),
Message::user(&prompt),
],
temperature: Some(0.3), temperature: Some(0.3),
max_tokens: Some(1000), max_tokens: Some(1000),
tools: None, tools: None,
@ -686,13 +715,23 @@ mod tests {
content: "[summarized]".into(), content: "[summarized]".into(),
reasoning_content: None, reasoning_content: None,
tool_calls: vec![], tool_calls: vec![],
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
}) })
} }
fn ptype(&self) -> &str { "mock" } fn ptype(&self) -> &str {
fn name(&self) -> &str { "mock" } "mock"
fn model_id(&self) -> &str { "mock" } }
fn name(&self) -> &str {
"mock"
}
fn model_id(&self) -> &str {
"mock"
}
} }
fn mock_summarizer() -> Arc<dyn LLMProvider> { fn mock_summarizer() -> Arc<dyn LLMProvider> {
@ -704,11 +743,13 @@ mod tests {
MM.get_or_init(|| { MM.get_or_init(|| {
let rt = tokio::runtime::Runtime::new().unwrap(); let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async { rt.block_on(async {
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id())); let tmp = std::env::temp_dir()
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
Arc::new(MemoryManager::new(storage, "test".into(), "test".into())) Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
}) })
}).clone() })
.clone()
} }
#[test] #[test]
@ -724,7 +765,11 @@ mod tests {
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6 // "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7 // "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
// raw = 19, with 1.2x = ~23 // raw = 19, with 1.2x = ~23
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens); assert!(
tokens > 18 && tokens < 30,
"Expected ~23 tokens, got {}",
tokens
);
} }
#[test] #[test]
@ -733,7 +778,8 @@ mod tests {
tool_result_trim_chars: 50, tool_result_trim_chars: 50,
..Default::default() ..Default::default()
}; };
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager()); let compressor =
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
let mut messages = vec![ let mut messages = vec![
ChatMessage::user("Hello"), ChatMessage::user("Hello"),
@ -774,7 +820,11 @@ mod tests {
ChatMessage::tool("call1", "bash", &"x".repeat(3000)), ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
]; ];
let result = compressor.compress_if_needed(messages).await.unwrap().history; let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap(); let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
assert!( assert!(
@ -798,13 +848,14 @@ mod tests {
// - B2B (L275): last user message lost when it is the final history message // - B2B (L275): last user message lost when it is the final history message
// //
// context_window=200 → threshold=100. Large tool outputs force LLM summarization. // context_window=200 → threshold=100. Large tool outputs force LLM summarization.
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id())); let tmp =
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap()); let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into())); let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig { let config = ContextCompressionConfig {
tool_result_trim_chars: 2000, tool_result_trim_chars: 2000,
protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated protect_first_n: 1, // system/protected → B2A: first user (after skip) duplicated
protect_last_n: 2, protect_last_n: 2,
max_passes: 1, max_passes: 1,
..Default::default() ..Default::default()
@ -818,25 +869,43 @@ mod tests {
let big = "x".repeat(3000); let big = "x".repeat(3000);
let messages = vec![ let messages = vec![
ChatMessage::system("You are a helper."), // 0: protected ChatMessage::system("You are a helper."), // 0: protected
ChatMessage::user("Q1"), // 1: first user ChatMessage::user("Q1"), // 1: first user
ChatMessage::tool("t1", "bash", &big), // 2 ChatMessage::tool("t1", "bash", &big), // 2
ChatMessage::user("Q2"), // 3 ChatMessage::user("Q2"), // 3
ChatMessage::assistant("thinking"), // 4 ChatMessage::assistant("thinking"), // 4
ChatMessage::tool("t2", "bash", &big), // 5 ChatMessage::tool("t2", "bash", &big), // 5
ChatMessage::user("Q3"), // 6 ChatMessage::user("Q3"), // 6
ChatMessage::assistant("thinking"), // 7 ChatMessage::assistant("thinking"), // 7
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
]; ];
let result = compressor.compress_if_needed(messages).await.unwrap().history; let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
// B2A: "Q1" must appear exactly once // B2A: "Q1" must appear exactly once
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count(); let q1_count = result
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count); .iter()
.filter(|m| m.role == "user" && m.content == "Q1")
.count();
assert_eq!(
q1_count, 1,
"Q1 should appear exactly once, got {}",
q1_count
);
// B2B: "Q4" must NOT be lost // B2B: "Q4" must NOT be lost
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count(); let q4_count = result
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count); .iter()
.filter(|m| m.role == "user" && m.content == "Q4")
.count();
assert_eq!(
q4_count, 1,
"Q4 should appear exactly once (not lost), got {}",
q4_count
);
let _ = std::fs::remove_file(&tmp); let _ = std::fs::remove_file(&tmp);
} }
@ -850,10 +919,10 @@ mod tests {
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into())); let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
let config = ContextCompressionConfig { let config = ContextCompressionConfig {
tool_result_trim_chars: 500, // trim reduces but not enough tool_result_trim_chars: 500, // trim reduces but not enough
protect_first_n: 1, protect_first_n: 1,
protect_last_n: 2, protect_last_n: 2,
max_passes: 0, // no LLM summarization → will exceed danger max_passes: 0, // no LLM summarization → will exceed danger
..Default::default() ..Default::default()
}; };
// context_window=100, danger_threshold=90. // context_window=100, danger_threshold=90.
@ -872,13 +941,23 @@ mod tests {
ChatMessage::tool("t3", "bash", &big), ChatMessage::tool("t3", "bash", &big),
]; ];
let result = compressor.compress_if_needed(messages).await.unwrap().history; let result = compressor
.compress_if_needed(messages)
.await
.unwrap()
.history;
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages // After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len()); assert!(
result.len() < 7,
"expected truncation reduction, got {} messages",
result.len()
);
// Truncation notice should be present // Truncation notice should be present
let has_notice = result.iter().any(|m| m.content.contains("Context truncation")); let has_notice = result
.iter()
.any(|m| m.content.contains("Context truncation"));
assert!(has_notice, "hard truncation notice missing"); assert!(has_notice, "hard truncation notice missing");
let _ = std::fs::remove_file(&tmp); let _ = std::fs::remove_file(&tmp);
@ -893,9 +972,9 @@ mod tests {
let mut messages = vec![ let mut messages = vec![
ChatMessage::user("Q1"), ChatMessage::user("Q1"),
ChatMessage::user("[Context Summary]\n\nsummary of previous turn"), ChatMessage::user("[Context Summary]\n\nsummary of previous turn"),
ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared ChatMessage::tool("tc1", "bash", "orphan result"), // orphan — tc1 never declared
ChatMessage::assistant("done"), // declares tc2 ChatMessage::assistant("done"), // declares tc2
ChatMessage::tool("tc2", "bash", "legitimate result"), // legit ChatMessage::tool("tc2", "bash", "legitimate result"), // legit
]; ];
// Set tool_call_id on tool messages and tool_calls on assistant // Set tool_call_id on tool messages and tool_calls on assistant
messages[2].tool_call_id = Some("tc1".into()); messages[2].tool_call_id = Some("tc1".into());
@ -910,8 +989,16 @@ mod tests {
// orphan should be removed; legitimate should stay // orphan should be removed; legitimate should stay
assert_eq!(messages.len(), 4); assert_eq!(messages.len(), 4);
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into()))); assert!(
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into()))); messages
.iter()
.all(|m| m.tool_call_id != Some("tc1".into()))
);
assert!(
messages
.iter()
.any(|m| m.tool_call_id == Some("tc2".into()))
);
} }
#[test] #[test]

View File

@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
} }
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> { fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
use base64::{engine::general_purpose::STANDARD, Engine as _}; use base64::{Engine as _, engine::general_purpose::STANDARD};
let mut file = std::fs::File::open(path)?; let mut file = std::fs::File::open(path)?;
let mut buffer = Vec::new(); let mut buffer = Vec::new();

View File

@ -4,10 +4,13 @@ pub mod media_handler;
pub mod sub_agent; pub mod sub_agent;
pub mod system_prompt; pub mod system_prompt;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
pub use context_compressor::{ContextCompressor, estimate_tokens}; pub use context_compressor::{ContextCompressor, estimate_tokens};
pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus}; pub use sub_agent::{
pub use system_prompt::{ DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
build_system_prompt, build_sub_agent_system_prompt, PromptContext, PromptSection, TaskNotification, TaskStatus,
SystemPromptBuilder, };
pub use system_prompt::{
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
build_system_prompt,
}; };

View File

@ -6,12 +6,12 @@ use dashmap::DashMap;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
use crate::agent::system_prompt::build_sub_agent_system_prompt;
use crate::agent::AgentLoop;
use crate::agent::AgentError; use crate::agent::AgentError;
use crate::agent::AgentLoop;
use crate::agent::system_prompt::build_sub_agent_system_prompt;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, LLMProvider}; use crate::providers::{LLMProvider, create_provider};
use crate::skills::SkillsLoader; use crate::skills::SkillsLoader;
use crate::tools::ToolRegistry; use crate::tools::ToolRegistry;
@ -21,7 +21,8 @@ tokio::task_local! {
/// Read the delegate context from the current task. Returns an error if not set. /// Read the delegate context from the current task. Returns an error if not set.
pub fn get_delegate_context() -> Result<DelegateContext, String> { pub fn get_delegate_context() -> Result<DelegateContext, String> {
DELEGATE_CONTEXT.try_with(|ctx| ctx.clone()) DELEGATE_CONTEXT
.try_with(|ctx| ctx.clone())
.map_err(|_| "DELEGATE_CONTEXT not set".to_string()) .map_err(|_| "DELEGATE_CONTEXT not set".to_string())
} }
@ -207,7 +208,10 @@ impl SubAgentManager {
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS); let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs); let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none() let http_get_only = config.allowed_tools.is_none()
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request")); || config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools); let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt( let system_prompt = build_sub_agent_system_prompt(
&config.prompt, &config.prompt,
@ -219,7 +223,8 @@ impl SubAgentManager {
http_get_only, http_get_only,
); );
let agent = self.build_sub_agent(&config, tools) let agent = self
.build_sub_agent(&config, tools)
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?; .map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
let history = vec![ let history = vec![
@ -241,10 +246,14 @@ impl SubAgentManager {
Ok(Ok(agent_result)) => { Ok(Ok(agent_result)) => {
let (content, truncated) = let (content, truncated) =
truncate_sub_agent_result(&agent_result.final_response.content); truncate_sub_agent_result(&agent_result.final_response.content);
let tool_calls_count = agent_result.emitted_messages.iter() let tool_calls_count = agent_result
.emitted_messages
.iter()
.filter(|m| m.tool_calls.is_some()) .filter(|m| m.tool_calls.is_some())
.count(); .count();
let iterations = agent_result.emitted_messages.iter() let iterations = agent_result
.emitted_messages
.iter()
.filter(|m| m.role == "assistant" && m.tool_calls.is_some()) .filter(|m| m.role == "assistant" && m.tool_calls.is_some())
.count(); .count();
Ok(SubAgentResult { Ok(SubAgentResult {
@ -343,7 +352,10 @@ impl SubAgentManager {
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS); let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
let timeout_human = format_duration(timeout_secs); let timeout_human = format_duration(timeout_secs);
let http_get_only = config.allowed_tools.is_none() let http_get_only = config.allowed_tools.is_none()
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request")); || config
.allowed_tools
.as_ref()
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
let skills_prompt = self.get_skills_prompt(&tools); let skills_prompt = self.get_skills_prompt(&tools);
let system_prompt = build_sub_agent_system_prompt( let system_prompt = build_sub_agent_system_prompt(
&config.prompt, &config.prompt,
@ -372,8 +384,12 @@ impl SubAgentManager {
if let Some(ref s) = storage { if let Some(ref s) = storage {
let _ = s let _ = s
.update_background_task_status( .update_background_task_status(
&tid, "running", None, None, &tid,
Some(started_at), None, "running",
None,
None,
Some(started_at),
None,
) )
.await; .await;
} }
@ -384,8 +400,7 @@ impl SubAgentManager {
p.set_storage(s.clone()); p.set_storage(s.clone());
} }
} }
let provider_result: Option<Arc<dyn LLMProvider>> = let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
provider.map(|p| Arc::from(p));
let result = match provider_result { let result = match provider_result {
Some(provider) => { Some(provider) => {
@ -474,9 +489,12 @@ impl SubAgentManager {
if let Some(ref s) = storage { if let Some(ref s) = storage {
let _ = s let _ = s
.update_background_task_status( .update_background_task_status(
&tid, &status_str, &tid,
Some(&result.content), error_val.as_deref(), &status_str,
Some(started_at), Some(finished_at), Some(&result.content),
error_val.as_deref(),
Some(started_at),
Some(finished_at),
) )
.await; .await;
} }
@ -514,15 +532,13 @@ impl SubAgentManager {
Ok(true) Ok(true)
} else if let Some(ref s) = self.storage { } else if let Some(ref s) = self.storage {
match s.get_background_task(task_id).await { match s.get_background_task(task_id).await {
Ok(task) => { Ok(task) => match task.status.as_str() {
match task.status.as_str() { "pending" | "running" => {
"pending" | "running" => { tracing::warn!(task_id, "task in DB but not in active_tasks");
tracing::warn!(task_id, "task in DB but not in active_tasks"); Ok(false)
Ok(false)
}
_ => Ok(false),
} }
} _ => Ok(false),
},
Err(_) => Ok(false), Err(_) => Ok(false),
} }
} else { } else {
@ -530,10 +546,7 @@ impl SubAgentManager {
} }
} }
pub async fn check_task( pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
&self,
task_id: &str,
) -> Option<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage { if let Some(ref s) = self.storage {
s.get_background_task(task_id).await.ok() s.get_background_task(task_id).await.ok()
} else { } else {
@ -541,12 +554,11 @@ impl SubAgentManager {
} }
} }
pub async fn list_tasks( pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
&self,
session_id: &str,
) -> Vec<crate::storage::BackgroundTask> {
if let Some(ref s) = self.storage { if let Some(ref s) = self.storage {
s.list_background_tasks(session_id).await.unwrap_or_default() s.list_background_tasks(session_id)
.await
.unwrap_or_default()
} else { } else {
vec![] vec![]
} }

View File

@ -196,10 +196,10 @@ impl PromptSection for UserProfileSection {
if let Some(user_config_dir) = get_user_config_dir() if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) = && let Some(content) =
load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS) load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS)
{ {
output.push_str(&content); output.push_str(&content);
return output; return output;
} }
// No USER.md found, return empty // No USER.md found, return empty
String::new() String::new()
@ -220,10 +220,10 @@ impl PromptSection for AgentProfileSection {
if let Some(user_config_dir) = get_user_config_dir() if let Some(user_config_dir) = get_user_config_dir()
&& let Some(content) = && let Some(content) =
load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS) load_file_from_dir(&user_config_dir, "AGENTS.md", BOOTSTRAP_MAX_CHARS)
{ {
output.push_str(&content); output.push_str(&content);
return output; return output;
} }
String::new() String::new()
} }
@ -465,7 +465,9 @@ impl PromptSection for SubAgentToolsSection {
let mut s = String::from("## 可用工具\n\n"); let mut s = String::from("## 可用工具\n\n");
s.push_str(&ctx.tools.describe_for_prompt()); s.push_str(&ctx.tools.describe_for_prompt());
if self.http_get_only { if self.http_get_only {
s.push_str("\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。"); s.push_str(
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
);
} }
s s
} }
@ -560,13 +562,8 @@ pub fn build_sub_agent_system_prompt(
memory_context: None, memory_context: None,
has_compressed_history: false, has_compressed_history: false,
}; };
SystemPromptBuilder::with_sub_agent_defaults( SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
task, .build(&ctx)
timeout_human,
skills_prompt,
http_get_only,
)
.build(&ctx)
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,8 +1,8 @@
use std::sync::Arc; use std::sync::Arc;
use crate::bus::{MessageBus, OutboundMessage}; use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
use crate::channels::ChannelManager; use crate::channels::ChannelManager;
use crate::channels::base::{Channel, ChannelError};
/// OutboundDispatcher consumes outbound messages from the MessageBus /// OutboundDispatcher consumes outbound messages from the MessageBus
/// and dispatches them to the appropriate Channel /// and dispatches them to the appropriate Channel

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::providers::ToolCall; use crate::providers::ToolCall;
@ -23,7 +23,9 @@ pub struct ImageUrlBlock {
impl ContentBlock { impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self { pub fn text(content: impl Into<String>) -> Self {
Self::Text { text: content.into() } Self::Text {
text: content.into(),
}
} }
pub fn image_url(url: impl Into<String>) -> Self { pub fn image_url(url: impl Into<String>) -> Self {
@ -49,10 +51,10 @@ pub struct MediaRef {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MediaItem { pub struct MediaItem {
pub path: String, // Local file path pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video" pub media_type: String, // "image", "audio", "file", "video"
pub mime_type: Option<String>, pub mime_type: Option<String>,
pub original_key: Option<String>, // Feishu file_key for download pub original_key: Option<String>, // Feishu file_key for download
} }
impl MediaItem { impl MediaItem {
@ -161,7 +163,10 @@ impl ChatMessage {
} }
} }
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self { pub fn assistant_with_tool_calls(
content: impl Into<String>,
tool_calls: Vec<ToolCall>,
) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(), role: "assistant".to_string(),
@ -206,7 +211,11 @@ impl ChatMessage {
} }
} }
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self { pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(), role: "tool".to_string(),

View File

@ -2,10 +2,13 @@ pub mod dispatcher;
pub mod message; pub mod message;
pub use dispatcher::OutboundDispatcher; pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind}; pub use message::{
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
OutboundMessage, SourceKind,
};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{Mutex, mpsc};
// ============================================================================ // ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication // MessageBus - Async message queue for Channel <-> Agent communication
@ -49,7 +52,8 @@ impl MessageBus {
/// Consume an inbound message (Agent -> Bus) /// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage { pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self.inbound_rx let msg = self
.inbound_rx
.lock() .lock()
.await .await
.recv() .recv()

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,10 @@ impl ChannelManager {
} }
} }
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self { pub fn with_bus(
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
bus: Arc<MessageBus>,
) -> Self {
Self { Self {
channels: Arc::new(RwLock::new(HashMap::new())), channels: Arc::new(RwLock::new(HashMap::new())),
cli_chat_channel, cli_chat_channel,
@ -39,7 +42,10 @@ impl ChannelManager {
/// Register a channel with the manager /// Register a channel with the manager
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) { pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels.write().await.insert(name.to_string(), channel); self.channels
.write()
.await
.insert(name.to_string(), channel);
} }
/// Get CLI chat channel /// Get CLI chat channel
@ -56,14 +62,19 @@ impl ChannelManager {
// Initialize Feishu channel if enabled // Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") { if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled { if feishu_config.enabled {
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir) let channel =
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
})?;
self.channels self.channels
.write() .write()
.await .await
.insert("feishu".to_string(), Arc::new(channel)); .insert("feishu".to_string(), Arc::new(channel));
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display()); tracing::info!(
"Feishu channel registered (media_dir: {}/media/feishu)",
workspace_dir.display()
);
} else { } else {
tracing::info!("Feishu channel disabled in config"); tracing::info!("Feishu channel disabled in config");
} }
@ -118,7 +129,10 @@ impl ChannelManager {
if let Some(channel) = self.get_channel(channel_name).await { if let Some(channel) = self.get_channel(channel_name).await {
channel.send(msg).await channel.send(msg).await
} else { } else {
Err(ChannelError::Other(format!("Channel not found: {}", channel_name))) Err(ChannelError::Other(format!(
"Channel not found: {}",
channel_name
)))
} }
} }
} }

View File

@ -1,11 +1,11 @@
pub mod base; pub mod base;
pub mod feishu;
pub mod cli_chat; pub mod cli_chat;
pub mod feishu;
pub mod manager; pub mod manager;
pub mod slash_command; pub mod slash_command;
pub use base::{Channel, ChannelError}; pub use base::{Channel, ChannelError};
pub use manager::ChannelManager;
pub use feishu::FeishuChannel;
pub use cli_chat::CliChatChannel; pub use cli_chat::CliChatChannel;
pub use slash_command::{parse_slash_command, command_matches}; pub use feishu::FeishuChannel;
pub use manager::ChannelManager;
pub use slash_command::{command_matches, parse_slash_command};

View File

@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
/// 检查内容是否匹配指定命令 /// 检查内容是否匹配指定命令
pub fn command_matches(content: &str, aliases: &[&str]) -> bool { pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
let trimmed = content.trim(); let trimmed = content.trim();
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias))) aliases
.iter()
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
} }
#[cfg(test)] #[cfg(test)]
@ -27,7 +29,10 @@ mod tests {
fn test_parse_slash_command() { fn test_parse_slash_command() {
assert_eq!(parse_slash_command("/reset"), Some(("reset", ""))); assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg"))); assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world"))); assert_eq!(
parse_slash_command("/new hello world"),
Some(("new", "hello world"))
);
assert_eq!(parse_slash_command("/??"), Some(("??", ""))); assert_eq!(parse_slash_command("/??"), Some(("??", "")));
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg"))); assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
assert_eq!(parse_slash_command("/?"), Some(("?", ""))); assert_eq!(parse_slash_command("/?"), Some(("?", "")));

View File

@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
use crossterm::{ use crossterm::{
event::{self, Event}, event::{self, Event},
execute, execute,
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
}; };
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use ratatui::{prelude::CrosstermBackend, Terminal}; use ratatui::{Terminal, prelude::CrosstermBackend};
use std::io; use std::io;
use tokio_tungstenite::{connect_async, tungstenite::Message}; use tokio_tungstenite::{connect_async, tungstenite::Message};
@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
WsOutbound::SessionCreated { session_id, .. } => { WsOutbound::SessionCreated { session_id, .. } => {
app.set_current_session(Some(session_id)); app.set_current_session(Some(session_id));
} }
WsOutbound::SessionList { sessions, current_session_id } => { WsOutbound::SessionList {
sessions,
current_session_id,
} => {
app.set_sessions(sessions); app.set_sessions(sessions);
if let Some(id) = current_session_id { if let Some(id) = current_session_id {
app.set_current_session(Some(id)); app.set_current_session(Some(id));

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::{App, MessageRole}; use crate::client::tui::app::{App, MessageRole};
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
text::Line, text::Line,
widgets::{Block, Borders, List, ListItem}, widgets::{Block, Borders, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,10 +1,10 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
text::{Line, Span}, text::{Line, Span},
widgets::{Block, Borders, List, ListItem}, widgets::{Block, Borders, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,8 +1,8 @@
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
widgets::{Block, Borders, Clear, List, ListItem}, widgets::{Block, Borders, Clear, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect) { pub fn render(f: &mut Frame, area: Rect) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Style}, style::{Color, Style},
widgets::{Block, Borders, Paragraph}, widgets::{Block, Borders, Paragraph},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {

View File

@ -1,9 +1,9 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
widgets::{Block, Borders, List, ListItem}, widgets::{Block, Borders, List, ListItem},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {
@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
.sessions .sessions
.iter() .iter()
.map(|session| { .map(|session| {
let is_current = app let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
.current_session_id
.as_ref() == Some(&session.session_id);
let archived = session.archived_at.is_some(); let archived = session.archived_at.is_some();
let mut content = if is_current { let mut content = if is_current {

View File

@ -1,15 +1,18 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use ratatui::{ use ratatui::{
Frame,
layout::Rect, layout::Rect,
style::{Color, Modifier, Style}, style::{Color, Modifier, Style},
widgets::{Block, Borders, Paragraph}, widgets::{Block, Borders, Paragraph},
Frame,
}; };
pub fn render(f: &mut Frame, area: Rect, app: &App) { pub fn render(f: &mut Frame, area: Rect, app: &App) {
let (title, style) = if app.pending_quit { let (title, style) = if app.pending_quit {
let msg = if let Some(session_id) = &app.current_session_id { let msg = if let Some(session_id) = &app.current_session_id {
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id) format!(
"PicoBot | Session: {} | Press Ctrl+C again to quit",
session_id
)
} else { } else {
"PicoBot | Press Ctrl+C again to quit".to_string() "PicoBot | Press Ctrl+C again to quit".to_string()
}; };

View File

@ -1,6 +1,6 @@
use crate::client::tui::app::{App, MessageRole}; use crate::client::tui::app::{App, MessageRole};
use crate::protocol::serialize_inbound;
use crate::protocol::WsInbound; use crate::protocol::WsInbound;
use crate::protocol::serialize_inbound;
use crossterm::event::{KeyCode, KeyEvent}; use crossterm::event::{KeyCode, KeyEvent};
use futures_util::SinkExt; use futures_util::SinkExt;
@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
async fn handle_normal_input(app: &mut App, key: KeyEvent) { async fn handle_normal_input(app: &mut App, key: KeyEvent) {
// Handle Ctrl+C for quit (double press to exit) // Handle Ctrl+C for quit (double press to exit)
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL); let is_ctrl_c = key.code == KeyCode::Char('c')
&& key
.modifiers
.contains(crossterm::event::KeyModifiers::CONTROL);
if is_ctrl_c { if is_ctrl_c {
if app.handle_ctrl_c_for_quit() { if app.handle_ctrl_c_for_quit() {
return; return;
@ -65,7 +68,9 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
app.input_insert_char(c); app.input_insert_char(c);
// Show command menu when input starts with / // Show command menu when input starts with /
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) { if !app.show_command_menu
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
{
app.show_command_menu = true; app.show_command_menu = true;
app.selected_command_idx = 0; app.selected_command_idx = 0;
} else if app.show_command_menu && !app.input.starts_with('/') { } else if app.show_command_menu && !app.input.starts_with('/') {
@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) {
sender_id: None, sender_id: None,
}; };
if let Ok(text) = serialize_inbound(&inbound) { if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await; let _ = sender
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
.await;
} }
} }
} }

View File

@ -1,8 +1,8 @@
use crate::client::tui::app::App; use crate::client::tui::app::App;
use crate::client::tui::components::*; use crate::client::tui::components::*;
use ratatui::{ use ratatui::{
layout::{Constraint, Direction, Layout, Rect},
Frame, Frame,
layout::{Constraint, Direction, Layout, Rect},
}; };
pub fn render_ui(f: &mut Frame, app: &App) { pub fn render_ui(f: &mut Frame, app: &App) {

View File

@ -273,12 +273,16 @@ impl Default for MemoryConfig {
impl MemoryConfig { impl MemoryConfig {
/// Resolve consolidation provider name, falling back to the main agent's provider. /// Resolve consolidation provider name, falling back to the main agent's provider.
pub fn resolve_consolidation_provider(&self, default: &str) -> String { pub fn resolve_consolidation_provider(&self, default: &str) -> String {
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string()) self.consolidation_provider
.clone()
.unwrap_or_else(|| default.to_string())
} }
/// Resolve consolidation model name, falling back to the main agent's model. /// Resolve consolidation model name, falling back to the main agent's model.
pub fn resolve_consolidation_model(&self, default: &str) -> String { pub fn resolve_consolidation_model(&self, default: &str) -> String {
self.consolidation_model.clone().unwrap_or_else(|| default.to_string()) self.consolidation_model
.clone()
.unwrap_or_else(|| default.to_string())
} }
} }
@ -366,10 +370,18 @@ impl Default for BrowserConfig {
} }
} }
fn default_recall_limit() -> usize { 5 } fn default_recall_limit() -> usize {
fn default_idle_consolidation_minutes() -> u64 { 10 } 5
fn default_timeline_retention_days() -> u64 { 90 } }
fn default_max_failures_before_degrade() -> usize { 3 } fn default_idle_consolidation_minutes() -> u64 {
10
}
fn default_timeline_retention_days() -> u64 {
90
}
fn default_max_failures_before_degrade() -> usize {
3
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LLMProviderConfig { pub struct LLMProviderConfig {
@ -469,7 +481,11 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError { impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path), ConfigError::ConfigNotFound(path) => write!(
f,
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
path
),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name), ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name), ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name), ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),

View File

@ -1,12 +1,12 @@
use std::sync::Arc; use super::GatewayState;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use crate::protocol::WsOutbound;
use crate::protocol::serialize_outbound;
use axum::extract::State; use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response; use axum::response::Response;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::protocol::serialize_outbound;
use crate::protocol::WsOutbound;
use super::GatewayState;
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async move { ws.on_upgrade(|socket| async move {
@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await; let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
// Send session established message // Send session established message
let _ = sender.send(WsOutbound::SessionEstablished { let _ = sender
session_id: session_id.clone(), .send(WsOutbound::SessionEstablished {
}).await; session_id: session_id.clone(),
})
.await;
tracing::info!(session_id = %session_id, "CLI session established"); tracing::info!(session_id = %session_id, "CLI session established");
@ -37,9 +39,10 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
tokio::spawn(async move { tokio::spawn(async move {
while let Some(msg) = receiver.recv().await { while let Some(msg) = receiver.recv().await {
if let Ok(text) = serialize_outbound(&msg) if let Ok(text) = serialize_outbound(&msg)
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() { && ws_sender.send(WsMessage::Text(text.into())).await.is_err()
break; {
} break;
}
} }
}); });

View File

@ -1,17 +1,17 @@
pub mod config;
pub mod providers;
pub mod bus;
pub mod agent; pub mod agent;
pub mod gateway; pub mod bus;
pub mod session;
pub mod client;
pub mod protocol;
pub mod channels; pub mod channels;
pub mod client;
pub mod config;
pub mod gateway;
pub mod logging; pub mod logging;
pub mod mcp; pub mod mcp;
pub mod memory; pub mod memory;
pub mod observability; pub mod observability;
pub mod protocol;
pub mod providers;
pub mod scheduler; pub mod scheduler;
pub mod session;
pub mod skills; pub mod skills;
pub mod storage; pub mod storage;
pub mod tools; pub mod tools;

View File

@ -1,11 +1,7 @@
use std::path::PathBuf; use std::path::PathBuf;
use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{ use tracing_subscriber::{
fmt, EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
layer::SubscriberExt,
util::SubscriberInitExt,
fmt::time::LocalTime,
EnvFilter,
}; };
/// Get the default log directory path: ~/.picobot/logs /// Get the default log directory path: ~/.picobot/logs
@ -27,20 +23,20 @@ pub fn init_logging() {
// Create log directory if it doesn't exist // Create log directory if it doesn't exist
if !log_dir.exists() if !log_dir.exists()
&& let Err(e) = std::fs::create_dir_all(&log_dir) { && let Err(e) = std::fs::create_dir_all(&log_dir)
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e); {
} eprintln!(
"Warning: Failed to create log directory {}: {}",
log_dir.display(),
e
);
}
// Create file appender with daily rotation // Create file appender with daily rotation
let file_appender = RollingFileAppender::new( let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
Rotation::DAILY,
&log_dir,
"picobot.log",
);
// Build subscriber with both console and file output // Build subscriber with both console and file output
let env_filter = EnvFilter::try_from_default_env() let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
.unwrap_or_else(|_| EnvFilter::new("info"));
let file_layer = fmt::layer() let file_layer = fmt::layer()
.with_writer(file_appender) .with_writer(file_appender)
@ -66,8 +62,7 @@ pub fn init_logging() {
/// Initialize logging without file output (console only) /// Initialize logging without file output (console only)
pub fn init_logging_console_only() { pub fn init_logging_console_only() {
let env_filter = EnvFilter::try_from_default_env() let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
.unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer() let console_layer = fmt::layer()
.with_timer(LocalTime::rfc_3339()) .with_timer(LocalTime::rfc_3339())

View File

@ -1,4 +1,4 @@
use clap::{Parser, CommandFactory}; use clap::{CommandFactory, Parser};
#[derive(Parser)] #[derive(Parser)]
#[command(name = "picobot")] #[command(name = "picobot")]

View File

@ -92,24 +92,19 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
parts.push(text.text.clone()); parts.push(text.text.clone());
} }
RawContent::Image(image) => { RawContent::Image(image) => {
parts.push(format!( parts.push(format!("[image: {}]", image.mime_type,));
"[image: {}]",
image.mime_type,
));
} }
RawContent::Resource(resource) => { RawContent::Resource(resource) => match &resource.resource {
match &resource.resource { rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
rmcp::model::ResourceContents::TextResourceContents { text, .. } => { parts.push(format!(
parts.push(format!( "[resource text: {}]",
"[resource text: {}]", text.chars().take(200).collect::<String>(),
text.chars().take(200).collect::<String>(), ));
));
}
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
} }
} rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
parts.push(format!("[resource blob: {}]", uri));
}
},
_ => { _ => {
parts.push("[unsupported content]".to_string()); parts.push("[unsupported content]".to_string());
} }
@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
cmd.env(k, v); cmd.env(k, v);
} }
let service = () let service =
.serve( ().serve(
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?, TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
) )
.await .await
@ -261,14 +256,14 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
} else { } else {
StreamableHttpClientTransport::from_config( StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(url.to_string()) StreamableHttpClientTransportConfig::with_uri(url.to_string())
.custom_headers(headers_map) .custom_headers(headers_map),
) )
}; };
let service = () let service =
.serve(transport) ().serve(transport)
.await .await
.context("failed to connect to HTTP/SSE MCP server")?; .context("failed to connect to HTTP/SSE MCP server")?;
let peer = service.peer().clone(); let peer = service.peer().clone();

View File

@ -102,7 +102,11 @@ mod tests {
let dir = tempdir().unwrap(); let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db"); let db_path = dir.path().join("test.db");
let storage = Arc::new(Storage::new(&db_path).await.unwrap()); let storage = Arc::new(Storage::new(&db_path).await.unwrap());
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into())); let mm = Arc::new(MemoryManager::new(
storage,
"default".into(),
"default".into(),
));
(mm, dir) (mm, dir)
} }
@ -131,15 +135,9 @@ mod tests {
async fn test_upsert_overwrites() { async fn test_upsert_overwrites() {
let (mm, _dir) = setup_memory_manager().await; let (mm, _dir) = setup_memory_manager().await;
mm.store( mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
"dup_key", .await
"original", .unwrap();
MemoryCategory::Knowledge,
None,
None,
)
.await
.unwrap();
mm.store( mm.store(
"dup_key", "dup_key",
"updated", "updated",
@ -247,7 +245,12 @@ mod tests {
// Recall scoped to session A — should get only tl_a // Recall scoped to session A — should get only tl_a
let scoped = mm let scoped = mm
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a")) .recall(
"summary",
10,
Some(MemoryCategory::Timeline),
Some("chan:chat:dialog_a"),
)
.await .await
.unwrap(); .unwrap();
assert_eq!(scoped.len(), 1); assert_eq!(scoped.len(), 1);

View File

@ -20,10 +20,7 @@ pub enum ObserverEvent {
success: bool, success: bool,
}, },
/// Emitted when the agent starts processing. /// Emitted when the agent starts processing.
AgentStart { AgentStart { provider: String, model: String },
provider: String,
model: String,
},
/// Emitted when the agent finishes processing. /// Emitted when the agent finishes processing.
AgentEnd { AgentEnd {
provider: String, provider: String,
@ -94,7 +91,11 @@ impl ToolExecutionOutcome {
} }
/// Create a failed outcome with duration. /// Create a failed outcome with duration.
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self { pub fn failure_with_duration(
output: String,
error_reason: Option<String>,
duration: Duration,
) -> Self {
Self { Self {
output, output,
success: false, success: false,

View File

@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage; use crate::storage::Storage;
use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300; const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> { fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b { blocks
ContentBlock::Text { text } => { .iter()
serde_json::json!({ "type": "text", "text": text }) .map(|b| match b {
} ContentBlock::Text { text } => {
ContentBlock::ImageUrl { image_url } => { serde_json::json!({ "type": "text", "text": text })
convert_image_url_to_anthropic(&image_url.url) }
} ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
}).collect() })
.collect()
} }
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value { fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider {
}; };
let content = if let Some(ref tc_id) = m.tool_call_id { let content = if let Some(ref tc_id) = m.tool_call_id {
// Tool result: wrap as tool_result content block // Tool result: wrap as tool_result content block
let output = m.content.iter() let output = m
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None }) .content
.iter()
.filter_map(|b| match b {
ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(""); .join("");
vec![serde_json::json!({ vec![serde_json::json!({
@ -244,19 +250,18 @@ impl LLMProvider for AnthropicProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request"); tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await let resp = req_builder.json(&body).send().await.inspect_err(|e| {
.inspect_err(|e| { let is_timeout = e.is_timeout();
let is_timeout = e.is_timeout(); tracing::error!(
tracing::error!( provider = %self.name,
provider = %self.name, model = %self.model_id,
model = %self.model_id, url = %url,
url = %url, timeout = is_timeout,
timeout = is_timeout, error = %e,
error = %e, elapsed_ms = %start.elapsed().as_millis(),
elapsed_ms = %start.elapsed().as_millis(), "LLM API request failed"
"LLM API request failed" );
); })?;
})?;
let status = resp.status(); let status = resp.status();
let body_text = resp.text().await?; let body_text = resp.text().await?;
@ -281,32 +286,38 @@ impl LLMProvider for AnthropicProvider {
"LLM API returned error" "LLM API returned error"
); );
if let Some(ref storage) = self.storage { if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call( let _ = storage
&self.name, &self.model_id, &req_body_str, .append_llm_call(
Some(&body_text), Some(&error_msg), &self.name,
start.elapsed().as_millis() as u64, &self.model_id,
).await; &req_body_str,
Some(&body_text),
Some(&error_msg),
start.elapsed().as_millis() as u64,
)
.await;
} }
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
} }
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
.map_err(|e| { let err_msg = format!("decode error: {} | body: {}", e, &body_text);
let err_msg = format!("decode error: {} | body: {}", e, &body_text); if let Some(ref storage) = self.storage {
if let Some(ref storage) = self.storage { let name = self.name.clone();
let name = self.name.clone(); let model = self.model_id.clone();
let model = self.model_id.clone(); let req = req_body_str.clone();
let req = req_body_str.clone(); let resp_body = body_text.clone();
let resp_body = body_text.clone(); let dur = start.elapsed().as_millis() as u64;
let dur = start.elapsed().as_millis() as u64; let err = err_msg.clone();
let err = err_msg.clone(); let s = storage.clone();
let s = storage.clone(); tokio::spawn(async move {
tokio::spawn(async move { let _ = s
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await; .append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
}); .await;
} });
err_msg }
})?; err_msg
})?;
let mut content = String::new(); let mut content = String::new();
let mut reasoning = None; let mut reasoning = None;
@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider {
reasoning_content: reasoning, reasoning_content: reasoning,
tool_calls, tool_calls,
usage: Usage { usage: Usage {
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0), prompt_tokens: anthropic_resp
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), .usage
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), .as_ref()
.map(|u| u.input_tokens)
.unwrap_or(0),
completion_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.output_tokens)
.unwrap_or(0),
total_tokens: anthropic_resp
.usage
.as_ref()
.map(|u| u.input_tokens + u.output_tokens)
.unwrap_or(0),
}, },
}; };
if let Some(ref storage) = self.storage { if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call( let _ = storage
&self.name, .append_llm_call(
&self.model_id, &self.name,
&req_body_str, &self.model_id,
Some(&body_text), &req_body_str,
None, Some(&body_text),
start.elapsed().as_millis() as u64, None,
).await; start.elapsed().as_millis() as u64,
)
.await;
} }
Ok(response) Ok(response)

View File

@ -1,12 +1,15 @@
pub mod traits;
pub mod openai;
pub mod anthropic; pub mod anthropic;
pub mod openai;
pub mod traits;
pub use self::openai::OpenAIProvider;
pub use self::anthropic::AnthropicProvider; pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage}; pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
ToolFunction, Usage,
};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> { pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() { match config.provider_type.as_str() {

View File

@ -1,29 +1,35 @@
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::Client; use reqwest::Client;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{Value, json};
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::bus::message::ContentBlock;
use crate::storage::Storage; use crate::storage::Storage;
use std::sync::Arc;
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300; const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 if blocks.len() == 1
&& let ContentBlock::Text { text } = &blocks[0] { && let ContentBlock::Text { text } = &blocks[0]
return Value::String(text.clone()); {
} return Value::String(text.clone());
Value::Array(blocks.iter().map(|b| match b { }
ContentBlock::Text { text } => json!({ "type": "text", "text": text }), Value::Array(
ContentBlock::ImageUrl { image_url } => { blocks
json!({ "type": "image_url", "image_url": { "url": image_url.url } }) .iter()
} .map(|b| match b {
}).collect()) ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
})
.collect(),
)
} }
pub struct OpenAIProvider { pub struct OpenAIProvider {
@ -201,10 +207,14 @@ impl LLMProvider for OpenAIProvider {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) { if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() { for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) { && let Some(url_str) = item
let prefix: String = url_str.chars().take(20).collect(); .get("image_url")
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)"); .and_then(|u| u.get("url"))
} .and_then(|v| v.as_str())
{
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
} }
} }
} }
@ -224,19 +234,18 @@ impl LLMProvider for OpenAIProvider {
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request"); tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await let resp = req_builder.json(&body).send().await.inspect_err(|e| {
.inspect_err(|e| { let is_timeout = e.is_timeout();
let is_timeout = e.is_timeout(); tracing::error!(
tracing::error!( provider = %self.name,
provider = %self.name, model = %self.model_id,
model = %self.model_id, url = %url,
url = %url, timeout = is_timeout,
timeout = is_timeout, error = %e,
error = %e, elapsed_ms = %start.elapsed().as_millis(),
elapsed_ms = %start.elapsed().as_millis(), "LLM API request failed"
"LLM API request failed" );
); })?;
})?;
let status = resp.status(); let status = resp.status();
let text = resp.text().await?; let text = resp.text().await?;
@ -253,37 +262,48 @@ impl LLMProvider for OpenAIProvider {
"LLM API returned error" "LLM API returned error"
); );
if let Some(ref storage) = self.storage if let Some(ref storage) = self.storage
&& let Err(e) = storage.append_llm_call( && let Err(e) = storage
&self.name, &self.model_id, &req_body_str, .append_llm_call(
Some(&text), Some(&error), &self.name,
start.elapsed().as_millis() as u64, &self.model_id,
).await { &req_body_str,
tracing::warn!("failed to persist LLM call: {}", e); Some(&text),
} Some(&error),
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
return Err(error.into()); return Err(error.into());
} }
let openai_resp: OpenAIResponse = serde_json::from_str(&text) let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
.map_err(|e| { let err_msg = format!("decode error: {} | body: {}", e, &text);
let err_msg = format!("decode error: {} | body: {}", e, &text); if let Some(ref storage) = self.storage {
if let Some(ref storage) = self.storage { let name = self.name.clone();
let name = self.name.clone(); let model = self.model_id.clone();
let model = self.model_id.clone(); let req = req_body_str.clone();
let req = req_body_str.clone(); let resp = text.clone();
let resp = text.clone(); let dur = start.elapsed().as_millis() as u64;
let dur = start.elapsed().as_millis() as u64; let err = err_msg.clone();
let err = err_msg.clone(); let s = storage.clone();
let s = storage.clone(); tokio::spawn(async move {
tokio::spawn(async move { if let Err(e) = s
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await { .append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
tracing::warn!("failed to persist LLM call (decode error): {}", e); .await
} {
}); tracing::warn!("failed to persist LLM call (decode error): {}", e);
} }
err_msg });
})?; }
err_msg
})?;
let first_choice = openai_resp.choices.into_iter().next() let first_choice = openai_resp
.choices
.into_iter()
.next()
.ok_or("no choices in response")?; .ok_or("no choices in response")?;
let content = first_choice let content = first_choice
@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider {
.map(|tc| ToolCall { .map(|tc| ToolCall {
id: tc.id.clone(), id: tc.id.clone(),
name: tc.function.name.clone(), name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), arguments: serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Null),
}) })
.collect(); .collect();
@ -318,13 +339,19 @@ impl LLMProvider for OpenAIProvider {
}; };
if let Some(ref storage) = self.storage if let Some(ref storage) = self.storage
&& let Err(e) = storage.append_llm_call( && let Err(e) = storage
&self.name, &self.model_id, &req_body_str, .append_llm_call(
Some(&text), None, &self.name,
start.elapsed().as_millis() as u64, &self.model_id,
).await { &req_body_str,
tracing::warn!("failed to persist LLM call: {}", e); Some(&text),
} None,
start.elapsed().as_millis() as u64,
)
.await
{
tracing::warn!("failed to persist LLM call: {}", e);
}
Ok(response) Ok(response)
} }
@ -386,6 +413,9 @@ mod tests {
assert_eq!(tool_calls[0]["id"], "call_1"); assert_eq!(tool_calls[0]["id"], "call_1");
assert_eq!(tool_calls[0]["type"], "function"); assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "calculator"); assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); assert_eq!(
tool_calls[0]["function"]["arguments"],
"{\"expression\":\"1+1\"}"
);
} }
} }

View File

@ -1,6 +1,6 @@
use crate::bus::message::ContentBlock;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message { pub struct Message {
@ -61,7 +61,11 @@ impl Message {
} }
} }
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self { pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self { Self {
role: "tool".to_string(), role: "tool".to_string(),
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],

View File

@ -5,11 +5,11 @@ use std::time::Instant;
use tokio::time; use tokio::time;
use crate::config::SchedulerConfig; use crate::config::SchedulerConfig;
use crate::session::session::HandleResult;
use crate::session::SessionManager; use crate::session::SessionManager;
use crate::session::session::HandleResult;
use crate::storage::JobRun;
use crate::storage::ScheduledJob; use crate::storage::ScheduledJob;
use crate::storage::Storage; use crate::storage::Storage;
use crate::storage::JobRun;
pub use types::Schedule; pub use types::Schedule;
@ -89,7 +89,11 @@ impl Scheduler {
let now = now_ms(); let now = now_ms();
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await { let due = match self
.storage
.due_scheduled_jobs(now, self.config.max_concurrent)
.await
{
Ok(jobs) => jobs, Ok(jobs) => jobs,
Err(e) => { Err(e) => {
tracing::error!("scheduler: failed to query due jobs: {}", e); tracing::error!("scheduler: failed to query due jobs: {}", e);
@ -107,7 +111,11 @@ impl Scheduler {
let start = Instant::now(); let start = Instant::now();
let started_at = now_ms(); let started_at = now_ms();
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await { if let Err(e) = self
.storage
.touch_scheduled_job_last_run(&job.id, started_at)
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e); tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
continue; continue;
} }
@ -135,7 +143,10 @@ impl Scheduler {
match result { match result {
Ok(HandleResult::AgentResponse(output)) => { Ok(HandleResult::AgentResponse(output)) => {
let output_truncated = if output.len() > 8000 { let output_truncated = if output.len() > 8000 {
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)]) format!(
"{}...[truncated]",
&output[..output.ceil_char_boundary(8000)]
)
} else { } else {
output.clone() output.clone()
}; };
@ -155,7 +166,11 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e); tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
} }
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await { if let Err(e) = self
.storage
.set_scheduled_job_last_status(&job.id, "ok", None)
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e); tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
} }
@ -199,9 +214,11 @@ impl Scheduler {
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2); tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
} }
if let Err(e2) = self.storage.set_scheduled_job_last_status( if let Err(e2) = self
&job.id, "error", Some(&error_str), .storage
).await { .set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
.await
{
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2); tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
} }
@ -231,17 +248,23 @@ impl Scheduler {
self.storage.remove_scheduled_job(&job.id).await?; self.storage.remove_scheduled_job(&job.id).await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run"); tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
} else { } else {
self.storage.set_scheduled_job_enabled(&job.id, false).await?; self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run"); tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
} }
} }
Schedule::Every { .. } | Schedule::Cron { .. } => { Schedule::Every { .. } | Schedule::Cron { .. } => {
if let Some(next) = next_run_for_schedule(&job.schedule, now) { if let Some(next) = next_run_for_schedule(&job.schedule, now) {
self.storage.set_scheduled_job_next_run(&job.id, next).await?; self.storage
.set_scheduled_job_next_run(&job.id, next)
.await?;
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled"); tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
} else { } else {
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job"); tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
self.storage.set_scheduled_job_enabled(&job.id, false).await?; self.storage
.set_scheduled_job_enabled(&job.id, false)
.await?;
} }
} }
} }

View File

@ -22,32 +22,20 @@ pub enum SessionCommand {
dialog_id: String, dialog_id: String,
}, },
/// Get the current dialog for a chat /// Get the current dialog for a chat
GetCurrentDialog { GetCurrentDialog { channel: String, chat_id: String },
channel: String,
chat_id: String,
},
/// Rename a dialog /// Rename a dialog
RenameDialog { RenameDialog {
session_id: UnifiedSessionId, session_id: UnifiedSessionId,
title: String, title: String,
}, },
/// Archive a dialog /// Archive a dialog
ArchiveDialog { ArchiveDialog { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Delete a dialog /// Delete a dialog
DeleteDialog { DeleteDialog { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Clear dialog history /// Clear dialog history
ClearHistory { ClearHistory { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Get list of available slash commands /// Get list of available slash commands
GetSlashCommands { GetSlashCommands { channel: String, chat_id: String },
channel: String,
chat_id: String,
},
/// Execute a slash command /// Execute a slash command
ExecuteSlashCommand { ExecuteSlashCommand {
command: String, command: String,
@ -60,7 +48,11 @@ pub enum SessionCommand {
impl SessionCommand { impl SessionCommand {
/// Create a CreateDialog command /// Create a CreateDialog command
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self { pub fn create_dialog(
channel: impl Into<String>,
chat_id: impl Into<String>,
title: Option<String>,
) -> Self {
Self::CreateDialog { Self::CreateDialog {
channel: channel.into(), channel: channel.into(),
chat_id: chat_id.into(), chat_id: chat_id.into(),
@ -69,7 +61,11 @@ impl SessionCommand {
} }
/// Create a ListDialogs command /// Create a ListDialogs command
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self { pub fn list_dialogs(
channel: impl Into<String>,
chat_id: impl Into<String>,
include_archived: bool,
) -> Self {
Self::ListDialogs { Self::ListDialogs {
channel: channel.into(), channel: channel.into(),
chat_id: chat_id.into(), chat_id: chat_id.into(),

View File

@ -1,5 +1,5 @@
use super::session_id::UnifiedSessionId;
use super::session::SlashCommand; use super::session::SlashCommand;
use super::session_id::UnifiedSessionId;
/// Dialog information returned by SessionManager /// Dialog information returned by SessionManager
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -30,30 +30,20 @@ pub enum SessionEvent {
session_id: Option<UnifiedSessionId>, session_id: Option<UnifiedSessionId>,
}, },
/// Dialog switched successfully /// Dialog switched successfully
DialogSwitched { DialogSwitched { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Dialog renamed /// Dialog renamed
DialogRenamed { DialogRenamed {
session_id: UnifiedSessionId, session_id: UnifiedSessionId,
title: String, title: String,
}, },
/// Dialog archived /// Dialog archived
DialogArchived { DialogArchived { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Dialog deleted /// Dialog deleted
DialogDeleted { DialogDeleted { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// Dialog history cleared /// Dialog history cleared
HistoryCleared { HistoryCleared { session_id: UnifiedSessionId },
session_id: UnifiedSessionId,
},
/// List of available slash commands /// List of available slash commands
SlashCommandsList { SlashCommandsList { commands: Vec<SlashCommand> },
commands: Vec<SlashCommand>,
},
/// Slash command executed successfully /// Slash command executed successfully
SlashCommandExecuted { SlashCommandExecuted {
new_session_id: Option<UnifiedSessionId>, new_session_id: Option<UnifiedSessionId>,
@ -70,8 +60,5 @@ pub enum SessionEvent {
message_count: usize, message_count: usize,
}, },
/// Error occurred /// Error occurred
Error { Error { code: String, message: String },
code: String,
message: String,
},
} }

View File

@ -1,11 +1,11 @@
pub mod error;
pub mod commands; pub mod commands;
pub mod error;
pub mod events; pub mod events;
pub mod session; pub mod session;
pub mod session_id; pub mod session_id;
pub use error::SessionError;
pub use commands::SessionCommand; pub use commands::SessionCommand;
pub use events::{SessionEvent, DialogInfo}; pub use error::SessionError;
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS}; pub use events::{DialogInfo, SessionEvent};
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
pub use session_id::UnifiedSessionId; pub use session_id::UnifiedSessionId;

View File

@ -8,7 +8,6 @@
/// ///
/// For simple cases where only one dialog exists per chat: /// For simple cases where only one dialog exists per chat:
/// - `dialog_id` defaults to `"default"` /// - `dialog_id` defaults to `"default"`
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub const DEFAULT_DIALOG_ID: &str = "default"; pub const DEFAULT_DIALOG_ID: &str = "default";
@ -22,7 +21,11 @@ pub struct UnifiedSessionId {
impl UnifiedSessionId { impl UnifiedSessionId {
/// Create a new UnifiedSessionId /// Create a new UnifiedSessionId
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self { pub fn new(
channel: impl Into<String>,
chat_id: impl Into<String>,
dialog_id: impl Into<String>,
) -> Self {
Self { Self {
channel: channel.into(), channel: channel.into(),
chat_id: chat_id.into(), chat_id: chat_id.into(),

View File

@ -1,6 +1,6 @@
use std::path::Path; use std::path::Path;
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS}; use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
pub fn install_builtin_skills(target_dir: &Path) { pub fn install_builtin_skills(target_dir: &Path) {
for skill in EMBEDDED_SKILLS { for skill in EMBEDDED_SKILLS {
@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) {
} }
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> { fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
let decompressed = zstd::decode_all(skill.data) let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
.map_err(|e| format!("zstd decode: {}", e))?;
let mut archive = tar::Archive::new(decompressed.as_slice()); let mut archive = tar::Archive::new(decompressed.as_slice());
archive archive

View File

@ -120,7 +120,11 @@ impl SkillsLoader {
let count = loaded.len(); let count = loaded.len();
let mut replaced = 0usize; let mut replaced = 0usize;
for skill in loaded { for skill in loaded {
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) { if let Some(existing) = state
.loaded_skills
.iter_mut()
.find(|s| s.name == skill.name)
{
*existing = skill; *existing = skill;
replaced += 1; replaced += 1;
} else { } else {
@ -138,33 +142,42 @@ impl SkillsLoader {
// Load from workspace skills dir (highest priority) — replace same-name skills // Load from workspace skills dir (highest priority) — replace same-name skills
if let Some(ref ws_dir) = self.workspace_skills_dir if let Some(ref ws_dir) = self.workspace_skills_dir
&& ws_dir.exists() { && ws_dir.exists()
let loaded = self.load_skills_from_dir(ws_dir); {
let count = loaded.len(); let loaded = self.load_skills_from_dir(ws_dir);
let mut replaced = 0usize; let count = loaded.len();
for skill in loaded { let mut replaced = 0usize;
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) { for skill in loaded {
*existing = skill; if let Some(existing) = state
replaced += 1; .loaded_skills
} else { .iter_mut()
state.loaded_skills.push(skill); .find(|s| s.name == skill.name)
} {
*existing = skill;
replaced += 1;
} else {
state.loaded_skills.push(skill);
} }
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
} }
tracing::debug!(
dir = %ws_dir.display(),
count = count,
replaced = replaced,
"Loaded skills from workspace directory"
);
state.last_workspace_mtime = Self::get_dir_mtime(ws_dir);
}
state.last_load_time = SystemTime::now(); state.last_load_time = SystemTime::now();
if state.loaded_skills.is_empty() { if state.loaded_skills.is_empty() {
tracing::debug!("No skills found in any skills directory"); tracing::debug!("No skills found in any skills directory");
} else { } else {
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len()); tracing::info!(
count = state.loaded_skills.len(),
"Loaded {} skills total",
state.loaded_skills.len()
);
} }
} }
@ -215,18 +228,20 @@ impl SkillsLoader {
let mut max_mtime = None; let mut max_mtime = None;
if let Ok(metadata) = std::fs::metadata(dir) if let Ok(metadata) = std::fs::metadata(dir)
&& let Ok(mtime) = metadata.modified() { && let Ok(mtime) = metadata.modified()
max_mtime = Some(mtime); {
} max_mtime = Some(mtime);
}
if let Ok(entries) = std::fs::read_dir(dir) { if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() { for entry in entries.flatten() {
let path = entry.path(); let path = entry.path();
if let Ok(metadata) = std::fs::metadata(&path) if let Ok(metadata) = std::fs::metadata(&path)
&& let Ok(mtime) = metadata.modified() && let Ok(mtime) = metadata.modified()
&& max_mtime.is_none_or(|current| mtime > current) { && max_mtime.is_none_or(|current| mtime > current)
max_mtime = Some(mtime); {
} max_mtime = Some(mtime);
}
} }
} }
@ -244,7 +259,12 @@ impl SkillsLoader {
pub fn get_always_skills(&self) -> Vec<Skill> { pub fn get_always_skills(&self) -> Vec<Skill> {
self.reload_if_changed(); self.reload_if_changed();
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
state.loaded_skills.iter().filter(|s| s.always).cloned().collect() state
.loaded_skills
.iter()
.filter(|s| s.always)
.cloned()
.collect()
} }
/// Get a specific skill by name (checks for changes first) /// Get a specific skill by name (checks for changes first)
@ -258,7 +278,8 @@ impl SkillsLoader {
pub fn list_skills(&self) -> Vec<(String, String)> { pub fn list_skills(&self) -> Vec<(String, String)> {
self.reload_if_changed(); self.reload_if_changed();
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
state.loaded_skills state
.loaded_skills
.iter() .iter()
.map(|s| (s.name.clone(), s.description.clone())) .map(|s| (s.name.clone(), s.description.clone()))
.collect() .collect()
@ -279,15 +300,21 @@ impl SkillsLoader {
prompt.push_str("### 目录说明\n\n"); prompt.push_str("### 目录说明\n\n");
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill\n"); prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill\n");
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n"); prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n"); prompt.push_str(
prompt.push_str("安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n"); "- `{workspace}/skills/` — 工作目录下的 skillpicobot 自行创建的 skill 存放于此\n\n",
);
prompt.push_str(
"安装或创建 skill 时请按上述目录规范存放创建skill时不要和已有skill同名。\n\n",
);
// Always skills summary // Always skills summary
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect(); let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
if !always_skills.is_empty() { if !always_skills.is_empty() {
prompt.push_str("### 常用技能\n\n"); prompt.push_str("### 常用技能\n\n");
for skill in &always_skills { for skill in &always_skills {
let path_str = skill.path.as_ref() let path_str = skill
.path
.as_ref()
.map(|p| p.to_string_lossy().to_string()) .map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string()); .unwrap_or_else(|| "".to_string());
prompt.push_str(&format!( prompt.push_str(&format!(
@ -300,8 +327,12 @@ impl SkillsLoader {
// Usage instructions // Usage instructions
prompt.push_str("### 使用方法\n\n"); prompt.push_str("### 使用方法\n\n");
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n"); prompt.push_str(
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n"); "- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
);
prompt.push_str(
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
);
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n"); prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
// Always skills full content // Always skills full content
@ -338,25 +369,23 @@ impl SkillsLoader {
} }
match std::fs::read_to_string(&skill_file) { match std::fs::read_to_string(&skill_file) {
Ok(content) => { Ok(content) => match self.parse_skill(&path, &content) {
match self.parse_skill(&path, &content) { Some(skill) => {
Some(skill) => { tracing::debug!(
tracing::debug!( skill = %skill.name,
skill = %skill.name, path = %skill_file.display(),
path = %skill_file.display(), always = skill.always,
always = skill.always, "Loaded skill"
"Loaded skill" );
); skills.push(skill);
skills.push(skill);
}
None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
} }
} None => {
tracing::warn!(
path = %skill_file.display(),
"Failed to parse skill"
);
}
},
Err(e) => { Err(e) => {
tracing::warn!( tracing::warn!(
path = %skill_file.display(), path = %skill_file.display(),
@ -447,7 +476,6 @@ impl Default for SkillsLoader {
} }
} }
/// Extract first non-empty, non-heading line as description /// Extract first non-empty, non-heading line as description
fn extract_description(content: &str) -> String { fn extract_description(content: &str) -> String {
content content

View File

@ -241,12 +241,11 @@ impl super::Storage {
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64); let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
let cutoff_str = cutoff.to_rfc3339(); let cutoff_str = cutoff.to_rfc3339();
let result = sqlx::query( let result =
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?", sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
) .bind(&cutoff_str)
.bind(&cutoff_str) .execute(self.pool())
.execute(self.pool()) .await?;
.await?;
Ok(result.rows_affected()) Ok(result.rows_affected())
} }
@ -276,9 +275,7 @@ impl super::Storage {
} }
} }
fn parse_memory_rows( fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
rows: &[sqlx::sqlite::SqliteRow],
) -> Result<Vec<MemoryEntry>, StorageError> {
rows.iter() rows.iter()
.map(|row| { .map(|row| {
Ok(MemoryEntry { Ok(MemoryEntry {

View File

@ -165,7 +165,11 @@ impl crate::storage::Storage {
} }
/// Update next_run_at and last_run_at for a job. /// Update next_run_at and last_run_at for a job.
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> { pub async fn set_scheduled_job_next_run(
&self,
id: &str,
next_run_at: i64,
) -> anyhow::Result<()> {
let now = now_ms(); let now = now_ms();
sqlx::query( sqlx::query(
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?", "UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
@ -331,7 +335,9 @@ mod tests {
async fn setup_storage() -> Storage { async fn setup_storage() -> Storage {
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap(); let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
let storage = Storage { pool }; let storage = Storage { pool };
Storage::init_scheduler_schema(storage.pool()).await.unwrap(); Storage::init_scheduler_schema(storage.pool())
.await
.unwrap();
storage storage
} }
@ -450,7 +456,10 @@ mod tests {
updated_at: t, updated_at: t,
}; };
storage.add_scheduled_job(&job).await.unwrap(); storage.add_scheduled_job(&job).await.unwrap();
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap(); storage
.set_scheduled_job_enabled("job-toggle", false)
.await
.unwrap();
let got = storage.get_scheduled_job("job-toggle").await.unwrap(); let got = storage.get_scheduled_job("job-toggle").await.unwrap();
assert!(!got.enabled); assert!(!got.enabled);
} }
@ -461,31 +470,55 @@ mod tests {
let t = now(); let t = now();
let jobs = vec![ let jobs = vec![
ScheduledJob { ScheduledJob {
id: "due".into(), name: "due".into(), id: "due".into(),
schedule: Schedule::At { at: t }, prompt: "1".into(), name: "due".into(),
channel: "cli_chat".into(), chat_id: "c".into(), schedule: Schedule::At { at: t },
model: None, enabled: true, delete_after_run: false, prompt: "1".into(),
next_run_at: t - 1000, last_run_at: None, channel: "cli_chat".into(),
last_status: None, last_error: None, chat_id: "c".into(),
created_at: t, updated_at: t, model: None,
enabled: true,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}, },
ScheduledJob { ScheduledJob {
id: "future".into(), name: "future".into(), id: "future".into(),
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(), name: "future".into(),
channel: "cli_chat".into(), chat_id: "c".into(), schedule: Schedule::At { at: t + 99999999 },
model: None, enabled: true, delete_after_run: false, prompt: "2".into(),
next_run_at: t + 99999999, last_run_at: None, channel: "cli_chat".into(),
last_status: None, last_error: None, chat_id: "c".into(),
created_at: t, updated_at: t, model: None,
enabled: true,
delete_after_run: false,
next_run_at: t + 99999999,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}, },
ScheduledJob { ScheduledJob {
id: "disabled-due".into(), name: "disabled due".into(), id: "disabled-due".into(),
schedule: Schedule::At { at: t }, prompt: "3".into(), name: "disabled due".into(),
channel: "cli_chat".into(), chat_id: "c".into(), schedule: Schedule::At { at: t },
model: None, enabled: false, delete_after_run: false, prompt: "3".into(),
next_run_at: t - 1000, last_run_at: None, channel: "cli_chat".into(),
last_status: None, last_error: None, chat_id: "c".into(),
created_at: t, updated_at: t, model: None,
enabled: false,
delete_after_run: false,
next_run_at: t - 1000,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}, },
]; ];
for j in &jobs { for j in &jobs {
@ -501,24 +534,39 @@ mod tests {
let storage = setup_storage().await; let storage = setup_storage().await;
let t = now(); let t = now();
let job = ScheduledJob { let job = ScheduledJob {
id: "job-run".into(), name: "run test".into(), id: "job-run".into(),
name: "run test".into(),
schedule: Schedule::Every { every_ms: 1000 }, schedule: Schedule::Every { every_ms: 1000 },
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(), prompt: "hi".into(),
model: None, enabled: true, delete_after_run: false, channel: "cli_chat".into(),
next_run_at: t, last_run_at: None, chat_id: "c".into(),
last_status: None, last_error: None, model: None,
created_at: t, updated_at: t, enabled: true,
delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}; };
storage.add_scheduled_job(&job).await.unwrap(); storage.add_scheduled_job(&job).await.unwrap();
let run = super::JobRun { let run = super::JobRun {
id: 0, job_id: "job-run".into(), id: 0,
started_at: t, finished_at: t + 500, job_id: "job-run".into(),
status: "ok".into(), output: Some("hello".into()), started_at: t,
error: None, duration_ms: 500, finished_at: t + 500,
status: "ok".into(),
output: Some("hello".into()),
error: None,
duration_ms: 500,
}; };
storage.record_scheduled_job_run(&run).await.unwrap(); storage.record_scheduled_job_run(&run).await.unwrap();
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap(); let runs = storage
.list_scheduled_job_runs("job-run", 10)
.await
.unwrap();
assert_eq!(runs.len(), 1); assert_eq!(runs.len(), 1);
assert_eq!(runs[0].status, "ok"); assert_eq!(runs[0].status, "ok");
assert_eq!(runs[0].output.as_deref(), Some("hello")); assert_eq!(runs[0].output.as_deref(), Some("hello"));
@ -529,22 +577,34 @@ mod tests {
let storage = setup_storage().await; let storage = setup_storage().await;
let t = now(); let t = now();
let job = ScheduledJob { let job = ScheduledJob {
id: "job-update".into(), name: "old name".into(), id: "job-update".into(),
name: "old name".into(),
schedule: Schedule::Every { every_ms: 1000 }, schedule: Schedule::Every { every_ms: 1000 },
prompt: "old prompt".into(), channel: "feishu".into(), prompt: "old prompt".into(),
chat_id: "oc_1".into(), model: None, channel: "feishu".into(),
enabled: true, delete_after_run: false, chat_id: "oc_1".into(),
next_run_at: t, last_run_at: None, model: None,
last_status: None, last_error: None, enabled: true,
created_at: t, updated_at: t, delete_after_run: false,
next_run_at: t,
last_run_at: None,
last_status: None,
last_error: None,
created_at: t,
updated_at: t,
}; };
storage.add_scheduled_job(&job).await.unwrap(); storage.add_scheduled_job(&job).await.unwrap();
storage.update_scheduled_job( storage
"job-update", .update_scheduled_job(
Some("new prompt".into()), "job-update",
Some(Schedule::Every { every_ms: 60000 }), Some("new prompt".into()),
None, None, None, Some(Schedule::Every { every_ms: 60000 }),
).await.unwrap(); None,
None,
None,
)
.await
.unwrap();
let got = storage.get_scheduled_job("job-update").await.unwrap(); let got = storage.get_scheduled_job("job-update").await.unwrap();
assert_eq!(got.prompt, "new prompt"); assert_eq!(got.prompt, "new prompt");
} }

View File

@ -167,10 +167,7 @@ impl Tool for BashTool {
Err(_) => Ok(ToolResult { Err(_) => Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(format!( error: Some(format!("Command timed out after {} seconds", timeout_secs)),
"Command timed out after {} seconds",
timeout_secs
)),
}), }),
} }
} }
@ -249,10 +246,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_pwd_command() { async fn test_pwd_command() {
let tool = BashTool::new(); let tool = BashTool::new();
let result = tool let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
.execute(json!({ "command": "pwd" }))
.await
.unwrap();
assert!(result.success); assert!(result.success);
} }
@ -260,7 +254,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_ls_command() { async fn test_ls_command() {
let tool = BashTool::new(); let tool = BashTool::new();
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap(); let result = tool
.execute(json!({ "command": "ls -la /tmp" }))
.await
.unwrap();
assert!(result.success); assert!(result.success);
} }

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use base64::Engine; use base64::Engine;
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT}; use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
use fantoccini::key::Key; use fantoccini::key::Key;
use fantoccini::{Client, ClientBuilder, Locator}; use fantoccini::{Client, ClientBuilder, Locator};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -63,7 +63,9 @@ impl BrowserTool {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum BrowserAction { pub enum BrowserAction {
Open { url: String }, Open {
url: String,
},
Snapshot { Snapshot {
#[serde(default)] #[serde(default)]
interactive_only: bool, interactive_only: bool,
@ -72,10 +74,20 @@ pub enum BrowserAction {
#[serde(default)] #[serde(default)]
depth: Option<i64>, depth: Option<i64>,
}, },
Click { selector: String }, Click {
Fill { selector: String, value: String }, selector: String,
Type { selector: Option<String>, text: String }, },
GetText { selector: String }, Fill {
selector: String,
value: String,
},
Type {
selector: Option<String>,
text: String,
},
GetText {
selector: String,
},
GetTitle, GetTitle,
GetUrl, GetUrl,
Screenshot { Screenshot {
@ -84,7 +96,9 @@ pub enum BrowserAction {
#[serde(default)] #[serde(default)]
return_base64: bool, return_base64: bool,
}, },
Focus { selector: String }, Focus {
selector: String,
},
Wait { Wait {
#[serde(default)] #[serde(default)]
selector: Option<String>, selector: Option<String>,
@ -93,9 +107,16 @@ pub enum BrowserAction {
#[serde(default)] #[serde(default)]
text: Option<String>, text: Option<String>,
}, },
Press { key: String }, Press {
Hover { selector: String }, key: String,
ClickAt { x: u32, y: u32 }, },
Hover {
selector: String,
},
ClickAt {
x: u32,
y: u32,
},
Scroll { Scroll {
direction: String, direction: String,
#[serde(default)] #[serde(default)]
@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.get("interactive_only") .get("interactive_only")
.and_then(Value::as_bool) .and_then(Value::as_bool)
.unwrap_or(true), .unwrap_or(true),
compact: args compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
.get("compact") depth: args.get("depth").and_then(|v| v.as_i64()),
.and_then(Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(|v| v.as_i64()),
}), }),
"click" => { "click" => {
let selector = args let selector = args
@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(String::from), .map(String::from),
ms: args.get("ms").and_then(|v| v.as_u64()), ms: args.get("ms").and_then(|v| v.as_u64()),
text: args text: args.get("text").and_then(|v| v.as_str()).map(String::from),
.get("text")
.and_then(|v| v.as_str())
.map(String::from),
}), }),
"press" => { "press" => {
let key = args let key = args
@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
let x = args let x = args
.get("x") .get("x")
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32; .ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
as u32;
let y = args let y = args
.get("y") .get("y")
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32; .ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
as u32;
Ok(BrowserAction::ClickAt { x, y }) Ok(BrowserAction::ClickAt { x, y })
} }
other => anyhow::bail!("Unsupported browser action: {}", other), other => anyhow::bail!("Unsupported browser action: {}", other),
@ -488,7 +503,11 @@ impl BrowserState {
} }
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
} }
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed"); tracing::debug!(
action = "fill",
output_len = value.len(),
"Browser action completed"
);
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!("Filled {} with {}", selector, value), output: format!("Filled {} with {}", selector, value),
@ -573,7 +592,10 @@ impl BrowserState {
error: None, error: None,
}) })
} }
BrowserAction::Screenshot { path, return_base64 } => { BrowserAction::Screenshot {
path,
return_base64,
} => {
let client = self.active_client()?; let client = self.active_client()?;
let png = client.screenshot().await?; let png = client.screenshot().await?;
let save_path = path.unwrap_or_else(|| { let save_path = path.unwrap_or_else(|| {
@ -588,14 +610,25 @@ impl BrowserState {
tokio::fs::write(&save_path, &png).await?; tokio::fs::write(&save_path, &png).await?;
if return_base64 { if return_base64 {
let b64 = base64::engine::general_purpose::STANDARD.encode(&png); let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed"); tracing::debug!(
action = "screenshot",
output_len = b64.len(),
"Browser action completed"
);
return Ok(ToolResult { return Ok(ToolResult {
success: true, success: true,
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64), output: format!(
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
save_path, b64
),
error: None, error: None,
}); });
} }
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed"); tracing::debug!(
action = "screenshot",
output_len = save_path.len(),
"Browser action completed"
);
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!("Screenshot saved to {}", save_path), output: format!("Screenshot saved to {}", save_path),
@ -611,18 +644,18 @@ impl BrowserState {
vec![serde_json::to_value(el)?], vec![serde_json::to_value(el)?],
) )
.await?; .await?;
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed"); tracing::debug!(
action = "focus",
output_len = selector.len(),
"Browser action completed"
);
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!("Focused {}", selector), output: format!("Focused {}", selector),
error: None, error: None,
}) })
} }
BrowserAction::Wait { BrowserAction::Wait { selector, ms, text } => {
selector,
ms,
text,
} => {
if let Some(sel) = selector { if let Some(sel) = selector {
let client = self.active_client()?; let client = self.active_client()?;
wait_for_selector(client, &sel).await?; wait_for_selector(client, &sel).await?;
@ -719,9 +752,21 @@ impl BrowserState {
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or(""); let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or(""); let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or(""); let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") }; let id_str = if id.is_empty() {
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") }; String::new()
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") }; } else {
format!("#{id}")
};
let type_str = if el_type.is_empty() {
String::new()
} else {
format!("[type={el_type}]")
};
let text_str = if text.is_empty() {
String::new()
} else {
format!(" ({text})")
};
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}") format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
} }
None => format!("Clicked at ({}, {})", x, y), None => format!("Clicked at ({}, {})", x, y),
@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String {
} }
fn xpath_contains_text(text: &str) -> String { fn xpath_contains_text(text: &str) -> String {
format!( format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
"//*[contains(normalize-space(.), {})]",
xpath_literal(text)
)
} }
fn xpath_literal(input: &str) -> String { fn xpath_literal(input: &str) -> String {
@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String {
"pagedown" => Key::PageDown.to_string(), "pagedown" => Key::PageDown.to_string(),
"space" => " ".to_string(), "space" => " ".to_string(),
other => { other => {
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other); tracing::warn!(
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
other
);
other.to_string() other.to_string()
} }
} }

View File

@ -659,10 +659,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_evaluate_missing_expression() { async fn test_evaluate_missing_expression() {
let tool = CalculatorTool::new(); let tool = CalculatorTool::new();
let result = tool let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
.execute(json!({"function": "evaluate"}))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
} }

View File

@ -31,10 +31,7 @@ impl ContentSearchTool {
for (i, line) in lines.iter().enumerate() { for (i, line) in lines.iter().enumerate() {
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS { if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
let omitted = lines.len() - i; let omitted = lines.len() - i;
output.push_str(&format!( output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
"\n... ({} matches omitted) ...",
omitted
));
break; break;
} }
if !output.is_empty() { if !output.is_empty() {
@ -113,18 +110,40 @@ impl Tool for ContentSearchTool {
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str())); let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str()); let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false); let case_sensitive = args
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize; .get("case_sensitive")
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize; .and_then(|v| v.as_bool())
.unwrap_or(false);
let context_lines = args
.get("context_lines")
.and_then(|v| v.as_u64())
.unwrap_or(0) as usize;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await; let result = self
.run_search(
pattern,
&dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await;
match result { match result {
Ok(lines) => { Ok(lines) => {
let count = lines.len(); let count = lines.len();
let mut output = self.truncate_output(&lines); let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 条匹配", count)); output.push_str(&format!("\n\n---\n{} 条匹配", count));
Ok(ToolResult { success: true, output, error: None }) Ok(ToolResult {
success: true,
output,
error: None,
})
} }
Err(e) => Ok(ToolResult { Err(e) => Ok(ToolResult {
success: false, success: false,
@ -146,22 +165,52 @@ impl ContentSearchTool {
max_results: usize, max_results: usize,
) -> anyhow::Result<Vec<String>> { ) -> anyhow::Result<Vec<String>> {
if which::which("rg").is_ok() { if which::which("rg").is_ok() {
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await { match self
.search_with_rg(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
Ok(lines) => return Ok(lines), Ok(lines) => return Ok(lines),
Err(e) => tracing::warn!("rg failed: {}, falling back", e), Err(e) => tracing::warn!("rg failed: {}, falling back", e),
} }
} }
if which::which("grep").is_ok() { if which::which("grep").is_ok() {
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await { match self
.search_with_grep(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
{
Ok(lines) if !lines.is_empty() => return Ok(lines), Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}, Ok(_) => {}
Err(e) => tracing::warn!("grep failed: {}, falling back", e), Err(e) => tracing::warn!("grep failed: {}, falling back", e),
} }
} }
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."); tracing::warn!(
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await "No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
);
self.search_with_rust(
pattern,
dir,
file_pattern,
case_sensitive,
context_lines,
max_results,
)
.await
} }
async fn search_with_rg( async fn search_with_rg(
@ -176,8 +225,10 @@ impl ContentSearchTool {
let mut cmd = Command::new("rg"); let mut cmd = Command::new("rg");
cmd.arg("-n") cmd.arg("-n")
.arg("--no-heading") .arg("--no-heading")
.arg("--color").arg("never") .arg("--color")
.arg("--max-count").arg(max_results.to_string()) .arg("never")
.arg("--max-count")
.arg(max_results.to_string())
.arg(pattern) .arg(pattern)
.arg(dir) .arg(dir)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
@ -193,12 +244,9 @@ impl ContentSearchTool {
cmd.arg("--glob").arg(fp); cmd.arg("--glob").arg(fp);
} }
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() && output.status.code() != Some(1) { if !output.status.success() && output.status.code() != Some(1) {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
@ -206,7 +254,8 @@ impl ContentSearchTool {
} }
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines() let lines: Vec<String> = text
.lines()
.take(max_results) .take(max_results)
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@ -242,15 +291,13 @@ impl ContentSearchTool {
cmd.arg("--include").arg(fp); cmd.arg("--include").arg(fp);
} }
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines() let lines: Vec<String> = text
.lines()
.take(max_results) .take(max_results)
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@ -280,7 +327,9 @@ impl ContentSearchTool {
if case_sensitive { if case_sensitive {
regex::Regex::new(&re_str) regex::Regex::new(&re_str)
} else { } else {
regex::RegexBuilder::new(&re_str).case_insensitive(true).build() regex::RegexBuilder::new(&re_str)
.case_insensitive(true)
.build()
} }
}); });
@ -291,7 +340,14 @@ impl ContentSearchTool {
}; };
let mut results = Vec::new(); let mut results = Vec::new();
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?; grep_dir(
Path::new(dir),
Path::new(dir),
&re,
file_re.as_ref(),
&mut results,
max_results,
)?;
Ok(results) Ok(results)
} }
@ -350,16 +406,19 @@ fn grep_dir(
if path.is_dir() { if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str()) if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.') && name.len() > 1 { && name.starts_with('.')
continue; && name.len() > 1
} {
continue;
}
grep_dir(base, &path, re, file_re, results, max)?; grep_dir(base, &path, re, file_re, results, max)?;
} else if path.is_file() { } else if path.is_file() {
if let Some(file_re) = file_re if let Some(file_re) = file_re
&& let Some(name) = rel.file_name().and_then(|n| n.to_str()) && let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& !file_re.is_match(name) { && !file_re.is_match(name)
continue; {
} continue;
}
if let Ok(content) = std::fs::read_to_string(&path) { if let Ok(content) = std::fs::read_to_string(&path) {
for (line_num, line) in content.lines().enumerate() { for (line_num, line) in content.lines().enumerate() {
@ -391,8 +450,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_content_search_rust_fallback() { async fn test_content_search_rust_fallback() {
let dir = TempDir::new().unwrap(); let dir = TempDir::new().unwrap();
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap(); fs::write(
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap(); dir.path().join("main.rs"),
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
)
.unwrap();
fs::write(
dir.path().join("lib.rs"),
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
)
.unwrap();
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap(); fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
let tool = ContentSearchTool::new(); let tool = ContentSearchTool::new();

View File

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::{json, Value}; use serde_json::{Value, json};
use uuid::Uuid; use uuid::Uuid;
use crate::scheduler::{next_run_for_schedule, Schedule}; use crate::scheduler::{Schedule, next_run_for_schedule};
use crate::storage::{ScheduledJob, Storage}; use crate::storage::{ScheduledJob, Storage};
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
@ -229,10 +229,7 @@ impl Tool for CronListTool {
} }
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> { async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
let filter = args let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("all");
let jobs = self.storage.list_scheduled_jobs().await?; let jobs = self.storage.list_scheduled_jobs().await?;
let filtered: Vec<&ScheduledJob> = match filter { let filtered: Vec<&ScheduledJob> = match filter {
@ -397,7 +394,9 @@ impl Tool for CronEnableTool {
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?; .map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
let next = next_run_for_schedule(&job.schedule, now_ms()); let next = next_run_for_schedule(&job.schedule, now_ms());
self.storage.set_scheduled_job_enabled(&job_id, true).await?; self.storage
.set_scheduled_job_enabled(&job_id, true)
.await?;
if let Some(n) = next { if let Some(n) = next {
self.storage.set_scheduled_job_next_run(&job_id, n).await?; self.storage.set_scheduled_job_next_run(&job_id, n).await?;
} }
@ -464,7 +463,9 @@ impl Tool for CronDisableTool {
.get_scheduled_job(&job_id) .get_scheduled_job(&job_id)
.await .await
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?; .map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
self.storage.set_scheduled_job_enabled(&job_id, false).await?; self.storage
.set_scheduled_job_enabled(&job_id, false)
.await?;
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
@ -580,7 +581,9 @@ impl Tool for CronUpdateTool {
if args.get("schedule").is_some() { if args.get("schedule").is_some() {
let job = self.storage.get_scheduled_job(&job_id).await?; let job = self.storage.get_scheduled_job(&job_id).await?;
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) { if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
self.storage.set_scheduled_job_next_run(&job_id, next).await?; self.storage
.set_scheduled_job_next_run(&job_id, next)
.await?;
} }
} }
@ -765,9 +768,7 @@ mod tests {
let job = ScheduledJob { let job = ScheduledJob {
id: "job-update-tool".into(), id: "job-update-tool".into(),
name: "old".into(), name: "old".into(),
schedule: Schedule::Every { schedule: Schedule::Every { every_ms: 3600000 },
every_ms: 3600000,
},
prompt: "old prompt".into(), prompt: "old prompt".into(),
channel: "feishu".into(), channel: "feishu".into(),
chat_id: "oc_1".into(), chat_id: "oc_1".into(),

View File

@ -102,7 +102,10 @@ impl Tool for DelegateTool {
_ => Ok(ToolResult { _ => Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)), error: Some(format!(
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
action
)),
}), }),
} }
} }
@ -115,9 +118,11 @@ impl DelegateTool {
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))? .ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
.to_string(); .to_string();
let allowed_tools: Option<Vec<String>> = args["allowed_tools"] let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
.as_array() arr.iter()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); .filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize); let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
let timeout_secs = args["timeout_secs"].as_u64(); let timeout_secs = args["timeout_secs"].as_u64();
@ -141,15 +146,21 @@ impl DelegateTool {
return Ok(ToolResult { return Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)), error: Some(format!(
}) "unknown mode: {}. Supported: inline, background, parallel",
mode_str
)),
});
} }
}; };
match mode { match mode {
ExecutionMode::Inline => { ExecutionMode::Inline => {
let config = self.parse_config_from_args(args)?; let config = self.parse_config_from_args(args)?;
let result = self.sub_agent_manager.run_inline(config).await let result = self
.sub_agent_manager
.run_inline(config)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?; .map_err(|e| anyhow::anyhow!("{}", e))?;
match result.status { match result.status {
@ -177,10 +188,14 @@ impl DelegateTool {
} }
ExecutionMode::Background => { ExecutionMode::Background => {
let config = self.parse_config_from_args(args)?; let config = self.parse_config_from_args(args)?;
let ctx = crate::agent::sub_agent::get_delegate_context() let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
.map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?; anyhow::anyhow!("delegate context not available: not in an agent worker")
})?;
let task_id = self.sub_agent_manager.run_background(config, ctx).await let task_id = self
.sub_agent_manager
.run_background(config, ctx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?; .map_err(|e| anyhow::anyhow!("{}", e))?;
Ok(ToolResult { Ok(ToolResult {
@ -200,9 +215,12 @@ impl DelegateTool {
.as_str() .as_str()
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))? .ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
.to_string(); .to_string();
let allowed_tools: Option<Vec<String>> = task["allowed_tools"] let allowed_tools: Option<Vec<String>> =
.as_array() task["allowed_tools"].as_array().map(|arr| {
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
configs.push(SubAgentConfig { configs.push(SubAgentConfig {
prompt, prompt,
@ -216,13 +234,18 @@ impl DelegateTool {
let has_args_allowed = args["allowed_tools"].as_array().is_some(); let has_args_allowed = args["allowed_tools"].as_array().is_some();
for c in &mut configs { for c in &mut configs {
if c.allowed_tools.is_none() && has_args_allowed { if c.allowed_tools.is_none() && has_args_allowed {
c.allowed_tools = args["allowed_tools"] c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
.as_array() arr.iter()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()); .filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
} }
} }
let results = self.sub_agent_manager.run_parallel(configs).await let results = self
.sub_agent_manager
.run_parallel(configs)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?; .map_err(|e| anyhow::anyhow!("{}", e))?;
let mut output = String::new(); let mut output = String::new();
@ -243,7 +266,9 @@ impl DelegateTool {
} }
} }
let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed)); let all_success = results
.iter()
.all(|r| matches!(r.status, TaskStatus::Completed));
Ok(ToolResult { Ok(ToolResult {
success: all_success, success: all_success,
output: output.trim().to_string(), output: output.trim().to_string(),

View File

@ -243,8 +243,8 @@ impl Tool for FileEditTool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::NamedTempFile;
use std::io::Write; use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test] #[tokio::test]
async fn test_edit_simple() { async fn test_edit_simple() {

View File

@ -181,10 +181,7 @@ impl Tool for FileReadTool {
} }
result = lines[..end_idx].join("\n"); result = lines[..end_idx].join("\n");
let truncated = original_len - result.len(); let truncated = original_len - result.len();
result.push_str(&format!( result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
"\n\n... ({} chars truncated) ...",
truncated
));
} }
if end < total { if end < total {
@ -196,10 +193,7 @@ impl Tool for FileReadTool {
end + 1 end + 1
)); ));
} else { } else {
result.push_str(&format!( result.push_str(&format!("\n\n(End of file — {} lines total)", total));
"\n\n(End of file — {} lines total)",
total
));
} }
if let Some(label) = encoding_label { if let Some(label) = encoding_label {
@ -214,7 +208,7 @@ impl Tool for FileReadTool {
} }
None => { None => {
// Truly binary file — base64 encode // Truly binary file — base64 encode
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{Engine, engine::general_purpose::STANDARD};
let encoded = STANDARD.encode(&bytes); let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved) let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream() .first_or_octet_stream()
@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::NamedTempFile;
use std::io::Write; use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test] #[tokio::test]
async fn test_read_simple_file() { async fn test_read_simple_file() {
@ -338,10 +332,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_is_directory() { async fn test_is_directory() {
let tool = FileReadTool::new(); let tool = FileReadTool::new();
let result = tool let result = tool.execute(json!({ "path": "." })).await.unwrap();
.execute(json!({ "path": "." }))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file")); assert!(result.error.unwrap().contains("Not a file"));

View File

@ -101,17 +101,29 @@ impl Tool for FileSearchTool {
}; };
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str())); let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true); let case_sensitive = args
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize; .get("case_sensitive")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(MAX_RESULTS as u64) as usize;
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await; let result = self
.run_search(pattern, &dir, case_sensitive, max_results)
.await;
match result { match result {
Ok(lines) => { Ok(lines) => {
let count = lines.len(); let count = lines.len();
let mut output = self.truncate_output(&lines); let mut output = self.truncate_output(&lines);
output.push_str(&format!("\n\n---\n{} 个文件", count)); output.push_str(&format!("\n\n---\n{} 个文件", count));
Ok(ToolResult { success: true, output, error: None }) Ok(ToolResult {
success: true,
output,
error: None,
})
} }
Err(e) => Ok(ToolResult { Err(e) => Ok(ToolResult {
success: false, success: false,
@ -139,9 +151,12 @@ impl FileSearchTool {
}; };
if !fd_cmd.is_empty() { if !fd_cmd.is_empty() {
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await { match self
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
.await
{
Ok(lines) if !lines.is_empty() => return Ok(lines), Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}, Ok(_) => {}
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e), Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
} }
} }
@ -149,13 +164,14 @@ impl FileSearchTool {
if which::which("find").is_ok() { if which::which("find").is_ok() {
match self.search_with_find(pattern, dir, max_results).await { match self.search_with_find(pattern, dir, max_results).await {
Ok(lines) if !lines.is_empty() => return Ok(lines), Ok(lines) if !lines.is_empty() => return Ok(lines),
Ok(_) => {}, Ok(_) => {}
Err(e) => tracing::warn!("find failed: {}, falling back", e), Err(e) => tracing::warn!("find failed: {}, falling back", e),
} }
} }
tracing::warn!("No fd/find available, using built-in file search (slower)"); tracing::warn!("No fd/find available, using built-in file search (slower)");
self.search_with_rust(pattern, dir, case_sensitive, max_results).await self.search_with_rust(pattern, dir, case_sensitive, max_results)
.await
} }
async fn search_with_fd( async fn search_with_fd(
@ -167,11 +183,15 @@ impl FileSearchTool {
fd_cmd: &str, fd_cmd: &str,
) -> anyhow::Result<Vec<String>> { ) -> anyhow::Result<Vec<String>> {
let mut cmd = Command::new(fd_cmd); let mut cmd = Command::new(fd_cmd);
cmd.arg("--search-path").arg(dir) cmd.arg("--search-path")
.arg("--glob").arg(pattern) .arg(dir)
.arg("--color").arg("never") .arg("--glob")
.arg(pattern)
.arg("--color")
.arg("never")
.arg("--strip-cwd-prefix") .arg("--strip-cwd-prefix")
.arg("--max-results").arg(max_results.to_string()) .arg("--max-results")
.arg(max_results.to_string())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()); .stderr(Stdio::piped());
@ -179,12 +199,9 @@ impl FileSearchTool {
cmd.arg("--ignore-case"); cmd.arg("--ignore-case");
} }
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
if !output.status.success() { if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
@ -192,7 +209,8 @@ impl FileSearchTool {
} }
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = text.lines() let lines: Vec<String> = text
.lines()
.filter(|l| !l.is_empty()) .filter(|l| !l.is_empty())
.map(|l| l.to_string()) .map(|l| l.to_string())
.collect(); .collect();
@ -215,15 +233,13 @@ impl FileSearchTool {
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::null()); .stderr(Stdio::null());
let output = timeout( let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
std::time::Duration::from_secs(TIMEOUT_SECS), .await
cmd.output(), .map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
)
.await
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
let text = String::from_utf8_lossy(&output.stdout); let text = String::from_utf8_lossy(&output.stdout);
let mut lines: Vec<String> = text.lines() let mut lines: Vec<String> = text
.lines()
.filter(|l| !l.is_empty()) .filter(|l| !l.is_empty())
.map(|l| { .map(|l| {
let p = Path::new(l); let p = Path::new(l);
@ -254,7 +270,13 @@ impl FileSearchTool {
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?; .map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
let mut results = Vec::new(); let mut results = Vec::new();
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?; walk_dir(
Path::new(dir),
Path::new(dir),
&re,
&mut results,
max_results,
)?;
Ok(results) Ok(results)
} }
} }
@ -311,15 +333,18 @@ fn walk_dir(
if path.is_dir() { if path.is_dir() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str()) if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& name.starts_with('.') && name.len() > 1 { && name.starts_with('.')
continue; && name.len() > 1
} {
continue;
}
walk_dir(base, &path, re, results, max)?; walk_dir(base, &path, re, results, max)?;
} else if path.is_file() { } else if path.is_file() {
if let Some(name) = rel.file_name().and_then(|n| n.to_str()) if let Some(name) = rel.file_name().and_then(|n| n.to_str())
&& re.is_match(name) { && re.is_match(name)
results.push(rel.to_string_lossy().to_string()); {
} results.push(rel.to_string_lossy().to_string());
}
if results.len() >= max { if results.len() >= max {
return Ok(()); return Ok(());
} }

View File

@ -90,13 +90,14 @@ impl Tool for FileWriteTool {
// Create parent directories if needed // Create parent directories if needed
if let Some(parent) = resolved.parent() if let Some(parent) = resolved.parent()
&& !parent.exists() && !parent.exists()
&& let Err(e) = std::fs::create_dir_all(parent) { && let Err(e) = std::fs::create_dir_all(parent)
return Ok(ToolResult { {
success: false, return Ok(ToolResult {
output: String::new(), success: false,
error: Some(format!("Failed to create parent directory: {}", e)), output: String::new(),
}); error: Some(format!("Failed to create parent directory: {}", e)),
} });
}
match std::fs::write(&resolved, content) { match std::fs::write(&resolved, content) {
Ok(_) => Ok(ToolResult { Ok(_) => Ok(ToolResult {
@ -168,10 +169,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_write_missing_path() { async fn test_write_missing_path() {
let tool = FileWriteTool::new(); let tool = FileWriteTool::new();
let result = tool let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.unwrap().contains("path")); assert!(result.error.unwrap().contains("path"));

View File

@ -129,7 +129,9 @@ impl GetSkillTool {
let mut output = format!("可用 skill (共 {} 个):\n", skills.len()); let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
for s in &skills { for s in &skills {
let always_mark = if s.always { " [常驻]" } else { "" }; let always_mark = if s.always { " [常驻]" } else { "" };
let path_str = s.path.as_ref() let path_str = s
.path
.as_ref()
.map(|p| p.to_string_lossy().to_string()) .map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|| "".to_string()); .unwrap_or_else(|| "".to_string());
output.push_str(&format!( output.push_str(&format!(
@ -148,10 +150,10 @@ impl GetSkillTool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tempfile::tempdir;
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use std::path::PathBuf; use std::path::PathBuf;
use tempfile::tempdir;
#[tokio::test] #[tokio::test]
async fn test_get_existing_skill() { async fn test_get_existing_skill() {

View File

@ -50,10 +50,7 @@ impl HttpRequestTool {
} }
if !host_matches_allowlist(&host, &self.allowed_domains) { if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!( return Err(format!("Host '{}' is not in allowed_domains", host));
"Host '{}' is not in allowed_domains",
host
));
} }
Ok(url.to_string()) Ok(url.to_string())
@ -80,11 +77,10 @@ impl HttpRequestTool {
for (key, value) in obj { for (key, value) in obj {
if let Some(str_val) = value.as_str() if let Some(str_val) = value.as_str()
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) && let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
&& let Ok(val) = && let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
reqwest::header::HeaderValue::from_str(str_val) {
{ header_map.insert(name, val);
header_map.insert(name, val); }
}
} }
} }
@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
allowed_domains.iter().any(|domain| { allowed_domains.iter().any(|domain| {
host == domain host == domain
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.')) || host
.strip_suffix(domain)
.is_some_and(|prefix| prefix.ends_with('.'))
}) })
} }
@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool {
} }
// Check .local TLD // Check .local TLD
if host.rsplit('.').next().is_some_and(|label| label == "local") { if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true; return true;
} }
@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast() || v4.is_broadcast()
|| v4.is_multicast() || v4.is_multicast()
} }
std::net::IpAddr::V6(v6) => { std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
} }
} }
@ -278,10 +278,7 @@ impl Tool for HttpRequestTool {
} }
}; };
let method_str = args let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({})); let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str()); let body = args.get("body").and_then(|v| v.as_str());

View File

@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool {
.and_then(|v| v.as_i64()) .and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis()); .unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory self.memory
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None) .recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Knowledge),
None,
)
.await? .await?
} else { } else {
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await? self.memory
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
.await?
}; };
if entries.is_empty() { if entries.is_empty() {
@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool {
let formatted = entries let formatted = entries
.iter() .iter()
.map(|e| { .map(|e| {
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
format!( format!(
"- {} [{}]{} [importance: {:.1}]: {}", "- {} [{}]{} [importance: {:.1}]: {}",
e.key, e.key,
@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool {
.and_then(|v| v.as_i64()) .and_then(|v| v.as_i64())
.unwrap_or(chrono::Utc::now().timestamp_millis()); .unwrap_or(chrono::Utc::now().timestamp_millis());
self.memory self.memory
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id) .recall_by_time(
since,
until,
Some(query),
limit,
Some(MemoryCategory::Timeline),
session_id,
)
.await? .await?
} else { } else {
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await? self.memory
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
.await?
}; };
if entries.is_empty() { if entries.is_empty() {
@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool {
let formatted = entries let formatted = entries
.iter() .iter()
.map(|e| { .map(|e| {
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default(); let session = e
.session_id
.as_deref()
.map(|s| format!(" [session: {}]", s))
.unwrap_or_default();
format!( format!(
"- {} [{}]{} [importance: {:.1}]: {}", "- {} [{}]{} [importance: {:.1}]: {}",
e.key, e.key,

View File

@ -37,11 +37,11 @@ pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use traits::{OutboundMessenger, Tool, ToolResult};
pub use web_fetch::WebFetchTool; pub use web_fetch::WebFetchTool;
use std::sync::Arc;
use crate::agent::SubAgentManager; use crate::agent::SubAgentManager;
use crate::config::BrowserConfig; use crate::config::BrowserConfig;
use crate::memory::MemoryManager; use crate::memory::MemoryManager;
use crate::skills::SkillsLoader; use crate::skills::SkillsLoader;
use std::sync::Arc;
/// Create the base tool registry (without send_message). /// Create the base tool registry (without send_message).
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()` /// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`

View File

@ -17,7 +17,10 @@ impl ToolRegistry {
} }
pub fn register<T: ToolTrait + 'static>(&self, tool: T) { pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool)); self.tools
.lock()
.unwrap()
.insert(tool.name().to_string(), Arc::new(tool));
} }
/// Register an existing Arc-wrapped tool by name /// Register an existing Arc-wrapped tool by name

View File

@ -115,9 +115,11 @@ impl SchemaCleanr {
} }
if let Some(Value::String(t)) = obj.get("type") if let Some(Value::String(t)) = obj.get("type")
&& t == "object" && !obj.contains_key("properties") { && t == "object"
tracing::warn!("Object schema without 'properties' field may cause issues"); && !obj.contains_key("properties")
} {
tracing::warn!("Object schema without 'properties' field may cause issues");
}
Ok(()) Ok(())
} }
@ -173,9 +175,10 @@ impl SchemaCleanr {
// Handle anyOf/oneOf simplification // Handle anyOf/oneOf simplification
if (obj.contains_key("anyOf") || obj.contains_key("oneOf")) if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) { && let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
return simplified; {
} return simplified;
}
// Build cleaned object // Build cleaned object
let mut cleaned = Map::new(); let mut cleaned = Map::new();
@ -243,12 +246,13 @@ impl SchemaCleanr {
} }
if let Some(def_name) = Self::parse_local_ref(ref_value) if let Some(def_name) = Self::parse_local_ref(ref_value)
&& let Some(definition) = defs.get(def_name.as_str()) { && let Some(definition) = defs.get(def_name.as_str())
ref_stack.insert(ref_value.to_string()); {
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack); ref_stack.insert(ref_value.to_string());
ref_stack.remove(ref_value); let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
return Self::preserve_meta(obj, cleaned); ref_stack.remove(ref_value);
} return Self::preserve_meta(obj, cleaned);
}
tracing::warn!("Cannot resolve $ref: {}", ref_value); tracing::warn!("Cannot resolve $ref: {}", ref_value);
Self::preserve_meta(obj, Value::Object(Map::new())) Self::preserve_meta(obj, Value::Object(Map::new()))
@ -340,13 +344,16 @@ impl SchemaCleanr {
return true; return true;
} }
if let Some(Value::Array(arr)) = obj.get("enum") if let Some(Value::Array(arr)) = obj.get("enum")
&& arr.len() == 1 && matches!(arr[0], Value::Null) { && arr.len() == 1
return true; && matches!(arr[0], Value::Null)
} {
return true;
}
if let Some(Value::String(t)) = obj.get("type") if let Some(Value::String(t)) = obj.get("type")
&& t == "null" { && t == "null"
return true; {
} return true;
}
} }
false false
} }
@ -403,7 +410,10 @@ impl SchemaCleanr {
match non_null.len() { match non_null.len() {
0 => Value::String("null".to_string()), 0 => Value::String("null".to_string()),
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())), 1 => non_null
.into_iter()
.next()
.unwrap_or(Value::String("null".to_string())),
_ => Value::Array(non_null), _ => Value::Array(non_null),
} }
} else { } else {

View File

@ -1,5 +1,5 @@
use std::sync::Arc;
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use mime_guess::mime; use mime_guess::mime;
@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
match parts.len() { match parts.len() {
2 => { 2 => {
if parts[0].is_empty() || parts[1].is_empty() { if parts[0].is_empty() || parts[1].is_empty() {
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw)) Err(format!(
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
raw
))
} else { } else {
Ok((parts[0], parts[1], None)) Ok((parts[0], parts[1], None))
} }
} }
3 => { 3 => {
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() { if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw)) Err(format!(
"Invalid target_chat_id format '{}': all three parts must not be empty",
raw
))
} else { } else {
Ok((parts[0], parts[1], Some(parts[2]))) Ok((parts[0], parts[1], Some(parts[2])))
} }
@ -98,8 +104,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
.ok_or_else(|| anyhow::anyhow!("missing content"))?; .ok_or_else(|| anyhow::anyhow!("missing content"))?;
// 1. Parse target_chat_id // 1. Parse target_chat_id
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id) let (channel, chat_id, dialog_id) =
.map_err(|e| anyhow::anyhow!(e))?; parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
// 2. Validate channel // 2. Validate channel
if !self.available_channels.contains(channel) { if !self.available_channels.contains(channel) {
@ -109,7 +115,11 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
error: Some(format!( error: Some(format!(
"Channel '{}' is not available. Available channels: {}", "Channel '{}' is not available. Available channels: {}",
channel, channel,
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ") self.available_channels
.iter()
.cloned()
.collect::<Vec<_>>()
.join(", ")
)), )),
}); });
} }
@ -129,7 +139,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
let media = parse_files_arg(&args); let media = parse_files_arg(&args);
// 4. Send via messenger // 4. Send via messenger
match self.messenger match self
.messenger
.send_message(channel, chat_id, dialog_id, content, source, media) .send_message(channel, chat_id, dialog_id, content, source, media)
.await .await
{ {

View File

@ -1,5 +1,5 @@
use async_trait::async_trait;
use crate::bus::{MediaItem, MessageSource}; use crate::bus::{MediaItem, MessageSource};
use async_trait::async_trait;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ToolResult { pub struct ToolResult {

View File

@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
return true; return true;
} }
if host.rsplit('.').next().is_some_and(|label| label == "local") { if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true; return true;
} }
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
std::net::IpAddr::V4(v4) => { std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified() v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
} }
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(), std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
}; };
} }

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
use picobot::config::{Config, LLMProviderConfig}; use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
use std::collections::HashMap;
fn load_config() -> Option<LLMProviderConfig> { fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?; dotenv::from_filename("tests/test.env").ok()?;
@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_simple_completion() { async fn test_openai_simple_completion() {
let config = load_config() let config = load_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
let response = provider.chat(create_request("Say 'ok'")).await.unwrap(); let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
@ -57,8 +56,7 @@ async fn test_openai_simple_completion() {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_conversation() { async fn test_openai_conversation() {
let config = load_config() let config = load_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
@ -82,7 +80,9 @@ async fn test_openai_conversation() {
async fn test_config_load() { async fn test_config_load() {
// Test that config.json can be loaded and provider config created // Test that config.json can be loaded and provider config created
let config = Config::load("config.json").expect("Failed to load config.json"); let config = Config::load("config.json").expect("Failed to load config.json");
let provider_config = config.get_provider_config("default").expect("Failed to get provider config"); let provider_config = config
.get_provider_config("default")
.expect("Failed to get provider config");
assert_eq!(provider_config.provider_type, "openai"); assert_eq!(provider_config.provider_type, "openai");
assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.name, "aliyun");

View File

@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
/// Verify that next_run_for_schedule produces valid future timestamps. /// Verify that next_run_for_schedule produces valid future timestamps.
#[test] #[test]
fn test_next_run_always_future() { fn test_next_run_always_future() {
use picobot::scheduler::{next_run_for_schedule, Schedule}; use picobot::scheduler::{Schedule, next_run_for_schedule};
let now = 1700000000000_i64; let now = 1700000000000_i64;
@ -56,6 +56,10 @@ fn test_next_run_always_future() {
for s in &schedules { for s in &schedules {
let next = next_run_for_schedule(s, now); let next = next_run_for_schedule(s, now);
assert!(next.is_some(), "expected next run for {:?}", s); assert!(next.is_some(), "expected next run for {:?}", s);
assert!(next.unwrap() > now, "next run should be after now for {:?}", s); assert!(
next.unwrap() > now,
"next run should be after now for {:?}",
s
);
} }
} }

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use picobot::config::LLMProviderConfig; use picobot::config::LLMProviderConfig;
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
use std::collections::HashMap;
fn load_openai_config() -> Option<LLMProviderConfig> { fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?; dotenv::from_filename("tests/test.env").ok()?;
@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_tool_call() { async fn test_openai_tool_call() {
let config = load_openai_config() let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
@ -68,7 +67,11 @@ async fn test_openai_tool_call() {
let response = provider.chat(request).await.unwrap(); let response = provider.chat(request).await.unwrap();
// Should have tool calls // Should have tool calls
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content); assert!(
!response.tool_calls.is_empty(),
"Expected tool call, got: {}",
response.content
);
let tool_call = &response.tool_calls[0]; let tool_call = &response.tool_calls[0];
assert_eq!(tool_call.name, "get_weather"); assert_eq!(tool_call.name, "get_weather");
@ -78,8 +81,7 @@ async fn test_openai_tool_call() {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_tool_call_with_manual_execution() { async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config() let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");
@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() {
}; };
let response1 = provider.chat(request1).await.unwrap(); let response1 = provider.chat(request1).await.unwrap();
let tool_call = response1.tool_calls.first() let tool_call = response1.tool_calls.first().expect("Expected tool call");
.expect("Expected tool call");
assert_eq!(tool_call.name, "get_weather"); assert_eq!(tool_call.name, "get_weather");
// Second request with tool result // Second request with tool result
@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() {
#[tokio::test] #[tokio::test]
#[ignore] #[ignore]
async fn test_openai_no_tool_when_not_provided() { async fn test_openai_no_tool_when_not_provided() {
let config = load_openai_config() let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
.expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider"); let provider = create_provider(config).expect("Failed to create provider");