Format codebase with rustfmt
This commit is contained in:
parent
c6f4392e63
commit
8f4ee79d8d
@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt;
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::bus::{ChatMessage, MediaRef};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::observability::{
|
||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
||||
};
|
||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::collections::VecDeque;
|
||||
use std::hash::{Hash, Hasher};
|
||||
@ -256,7 +254,10 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with provider created from config and given tools.
|
||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||
pub fn with_tools(
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
let model_name = provider_config.model_id.clone();
|
||||
let workspace_dir = provider_config.workspace_dir.clone();
|
||||
@ -279,7 +280,13 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
/// Create a new AgentLoop with an existing shared provider.
|
||||
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
|
||||
pub fn with_provider(
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
max_iterations: usize,
|
||||
model_name: String,
|
||||
workspace_dir: PathBuf,
|
||||
input_types: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
@ -379,7 +386,12 @@ impl AgentLoop {
|
||||
let content = if m.media_refs.is_empty() {
|
||||
vec![ContentBlock::text(&m.content)]
|
||||
} else {
|
||||
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
|
||||
build_content_blocks(
|
||||
&m.content,
|
||||
&m.media_refs,
|
||||
&self.input_types,
|
||||
&self.media_registry,
|
||||
)
|
||||
};
|
||||
|
||||
Message {
|
||||
@ -399,14 +411,28 @@ impl AgentLoop {
|
||||
/// it loops back to the LLM with the tool results until either:
|
||||
/// - The LLM returns no more tool calls (final response)
|
||||
/// - Maximum iterations are reached
|
||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
||||
pub async fn process(
|
||||
&self,
|
||||
mut messages: Vec<ChatMessage>,
|
||||
) -> Result<AgentProcessResult, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||
tracing::debug!(
|
||||
history_len = messages.len(),
|
||||
max_iterations = self.max_iterations,
|
||||
"Starting agent process"
|
||||
);
|
||||
|
||||
// Build and inject system prompt if not present
|
||||
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
||||
if !has_system {
|
||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
|
||||
let system_prompt = build_system_prompt(
|
||||
&self.workspace_dir,
|
||||
&self.model_name,
|
||||
&self.tools,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||
messages.insert(0, ChatMessage::system(system_prompt));
|
||||
@ -427,9 +453,7 @@ impl AgentLoop {
|
||||
let estimated = estimate_tokens(&messages);
|
||||
let danger = (self.context_window as f64 * 0.8) as usize;
|
||||
if estimated > danger {
|
||||
let trimmed = self.preemptive_trim_old_tool_results(
|
||||
&mut messages, 2000, 4,
|
||||
);
|
||||
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
|
||||
if trimmed > 0 {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
@ -463,8 +487,7 @@ impl AgentLoop {
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = (*self.provider).chat(request).await
|
||||
.map_err(|e| {
|
||||
let response = (*self.provider).chat(request).await.map_err(|e| {
|
||||
tracing::error!(error = %e, "LLM request failed");
|
||||
AgentError::LlmError(e.to_string())
|
||||
})?;
|
||||
@ -493,7 +516,9 @@ impl AgentLoop {
|
||||
|
||||
// Execute tool calls — log and notify immediately
|
||||
{
|
||||
let tools_info: Vec<String> = response.tool_calls.iter()
|
||||
let tools_info: Vec<String> = response
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||
let s = format!("{}:{}", tc.name, args);
|
||||
@ -522,7 +547,9 @@ impl AgentLoop {
|
||||
// Log function call with name and arguments
|
||||
let args_str = match &tool_call.arguments {
|
||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||
other => {
|
||||
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
||||
}
|
||||
};
|
||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||
|
||||
@ -562,7 +589,11 @@ impl AgentLoop {
|
||||
|
||||
// Loop continues to next iteration with updated messages
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
||||
tracing::debug!(
|
||||
iteration,
|
||||
message_count = messages.len(),
|
||||
"Tool execution complete, continuing to next iteration"
|
||||
);
|
||||
}
|
||||
|
||||
// Max iterations reached - ask LLM for a summary based on completed work
|
||||
@ -571,7 +602,7 @@ impl AgentLoop {
|
||||
// Add a message asking for summary
|
||||
let summary_request = ChatMessage::user(
|
||||
"You have reached the maximum number of tool call iterations. \
|
||||
Please provide your best answer based on the work completed so far."
|
||||
Please provide your best answer based on the work completed so far.",
|
||||
);
|
||||
messages.push(summary_request);
|
||||
|
||||
@ -603,14 +634,19 @@ impl AgentLoop {
|
||||
Err(e) => {
|
||||
// Fallback if summary call fails
|
||||
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||
let final_message = ChatMessage::assistant(
|
||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||
);
|
||||
let final_message = ChatMessage::assistant(format!(
|
||||
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
|
||||
self.max_iterations
|
||||
));
|
||||
emitted_messages.push(final_message.clone());
|
||||
Ok(AgentProcessResult {
|
||||
final_response: final_message,
|
||||
emitted_messages,
|
||||
total_tokens: if accumulated_tokens > 0 { Some(accumulated_tokens) } else { None },
|
||||
total_tokens: if accumulated_tokens > 0 {
|
||||
Some(accumulated_tokens)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -698,10 +734,7 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
// Apply duration
|
||||
ToolExecutionOutcome {
|
||||
duration,
|
||||
..result
|
||||
}
|
||||
ToolExecutionOutcome { duration, ..result }
|
||||
}
|
||||
|
||||
/// Internal tool execution without event tracking.
|
||||
@ -723,18 +756,12 @@ impl AgentLoop {
|
||||
ToolExecutionOutcome::success(result.output)
|
||||
} else {
|
||||
let error = result.error.unwrap_or_default();
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", error),
|
||||
Some(error),
|
||||
)
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", e),
|
||||
Some(e.to_string()),
|
||||
)
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -822,8 +849,14 @@ mod tests {
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||
assert_eq!(
|
||||
provider_message.tool_calls.as_ref().unwrap()[0].id,
|
||||
"call_1"
|
||||
);
|
||||
assert_eq!(
|
||||
provider_message.tool_calls.as_ref().unwrap()[0].name,
|
||||
"calculator"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -234,27 +234,33 @@ impl ContextCompressor {
|
||||
}
|
||||
} else if messages[i].role == "tool"
|
||||
&& let Some(ref tid) = messages[i].tool_call_id
|
||||
&& !declared.contains(tid.as_str()) {
|
||||
&& !declared.contains(tid.as_str())
|
||||
{
|
||||
messages.remove(i);
|
||||
continue;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let broken: Vec<usize> = messages.iter().enumerate()
|
||||
let broken: Vec<usize> = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, msg)| {
|
||||
if msg.role == "assistant"
|
||||
&& let Some(ref tcs) = msg.tool_calls
|
||||
&& !tcs.is_empty() {
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
let all_present = tcs.iter().all(|tc| {
|
||||
messages.iter().any(|m| {
|
||||
m.role == "tool"
|
||||
&& m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
||||
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
||||
})
|
||||
});
|
||||
if !all_present { Some(idx) } else { None }
|
||||
} else { None }
|
||||
}).collect();
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for idx in broken {
|
||||
let msg = &mut messages[idx];
|
||||
@ -262,7 +268,8 @@ impl ContextCompressor {
|
||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
||||
msg.content = format!(
|
||||
"{}\n\n[Tool calls ({}) — results are no longer available]",
|
||||
msg.content, names.join(", ")
|
||||
msg.content,
|
||||
names.join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -275,7 +282,10 @@ impl ContextCompressor {
|
||||
// Check if compression is needed
|
||||
let tokens = self.token_estimate_with_history(&history);
|
||||
if tokens <= self.threshold() {
|
||||
return Ok(CompressionResult { history, created_timelines: false });
|
||||
return Ok(CompressionResult {
|
||||
history,
|
||||
created_timelines: false,
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@ -299,7 +309,10 @@ impl ContextCompressor {
|
||||
}
|
||||
if tokens_after <= self.threshold() {
|
||||
self.invalidate_token_cache();
|
||||
return Ok(CompressionResult { history, created_timelines: false });
|
||||
return Ok(CompressionResult {
|
||||
history,
|
||||
created_timelines: false,
|
||||
});
|
||||
}
|
||||
|
||||
// LLM summarization pass
|
||||
@ -312,11 +325,7 @@ impl ContextCompressor {
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
pass = pass + 1,
|
||||
tokens = tokens,
|
||||
"Compression pass"
|
||||
);
|
||||
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
|
||||
|
||||
match self.compress_once(¤t_history).await {
|
||||
Ok(Some(compressed)) => {
|
||||
@ -352,18 +361,24 @@ impl ContextCompressor {
|
||||
let m = ¤t_history[scan];
|
||||
if m.role == "assistant" {
|
||||
if let Some(tcs) = &m.tool_calls
|
||||
&& !tcs.is_empty() {
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
let has_post = current_history[scan + 1..]
|
||||
.iter()
|
||||
.filter(|r| r.role == "tool")
|
||||
.any(|r| tcs.iter().any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str())));
|
||||
.any(|r| {
|
||||
tcs.iter()
|
||||
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
|
||||
});
|
||||
if has_post {
|
||||
tail_start = scan;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
if scan == 0 { break; }
|
||||
if scan == 0 {
|
||||
break;
|
||||
}
|
||||
scan -= 1;
|
||||
}
|
||||
}
|
||||
@ -390,11 +405,13 @@ impl ContextCompressor {
|
||||
for msg in &mut truncated[..self.config.protect_first_n] {
|
||||
if msg.role == "assistant" {
|
||||
if let Some(ref tcs) = msg.tool_calls
|
||||
&& !tcs.is_empty() {
|
||||
&& !tcs.is_empty()
|
||||
{
|
||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
||||
msg.content = format!(
|
||||
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
|
||||
msg.content, names.join(", ")
|
||||
msg.content,
|
||||
names.join(", ")
|
||||
);
|
||||
msg.tool_calls = None;
|
||||
}
|
||||
@ -424,7 +441,10 @@ impl ContextCompressor {
|
||||
"Context compression completed"
|
||||
);
|
||||
|
||||
Ok(CompressionResult { history: current_history, created_timelines })
|
||||
Ok(CompressionResult {
|
||||
history: current_history,
|
||||
created_timelines,
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to extract the actual context token limit from an LLM error message.
|
||||
@ -447,7 +467,8 @@ impl ContextCompressor {
|
||||
// Look for a number in the vicinity (up to 10 chars after marker)
|
||||
if let Some(num_str) = find_number_nearby(after, 50)
|
||||
&& let Ok(n) = num_str.parse::<usize>()
|
||||
&& (1024..=10_000_000).contains(&n) {
|
||||
&& (1024..=10_000_000).contains(&n)
|
||||
{
|
||||
return Some(n);
|
||||
}
|
||||
}
|
||||
@ -509,19 +530,26 @@ impl ContextCompressor {
|
||||
|
||||
// Persist compressed summary as timeline memory entry
|
||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
||||
ts, between.len(), summary);
|
||||
let timeline_content = format!(
|
||||
"[{}] Compressed {} conversation segments:\n{}",
|
||||
ts,
|
||||
between.len(),
|
||||
summary
|
||||
);
|
||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||
let mm = self.memory.clone();
|
||||
let sid = self.session_id.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = mm.store(
|
||||
if let Err(e) = mm
|
||||
.store(
|
||||
&key,
|
||||
&timeline_content,
|
||||
crate::memory::MemoryCategory::Timeline,
|
||||
sid.as_deref(),
|
||||
Some(0.3),
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||
}
|
||||
});
|
||||
@ -552,10 +580,7 @@ impl ContextCompressor {
|
||||
}
|
||||
|
||||
/// Summarize a segment of messages using LLM.
|
||||
async fn summarize_segment(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
) -> Result<String, AgentError> {
|
||||
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
|
||||
if messages.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
@ -569,7 +594,8 @@ impl ContextCompressor {
|
||||
"tool" => "Tool",
|
||||
_ => m.role.as_str(),
|
||||
};
|
||||
let name = m.tool_name
|
||||
let name = m
|
||||
.tool_name
|
||||
.as_ref()
|
||||
.map(|n| format!(" ({})", n))
|
||||
.unwrap_or_default();
|
||||
@ -614,7 +640,10 @@ Be concise, aim for {} characters or less.
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
||||
messages: vec![
|
||||
Message::system("You are a helpful assistant."),
|
||||
Message::user(&prompt),
|
||||
],
|
||||
temperature: Some(0.3),
|
||||
max_tokens: Some(1000),
|
||||
tools: None,
|
||||
@ -686,13 +715,23 @@ mod tests {
|
||||
content: "[summarized]".into(),
|
||||
reasoning_content: None,
|
||||
tool_calls: vec![],
|
||||
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
||||
usage: Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str { "mock" }
|
||||
fn name(&self) -> &str { "mock" }
|
||||
fn model_id(&self) -> &str { "mock" }
|
||||
fn ptype(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
fn model_id(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
||||
@ -704,11 +743,13 @@ mod tests {
|
||||
MM.get_or_init(|| {
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
rt.block_on(async {
|
||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||
let tmp = std::env::temp_dir()
|
||||
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
||||
})
|
||||
}).clone()
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -724,7 +765,11 @@ mod tests {
|
||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||
// raw = 19, with 1.2x = ~23
|
||||
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
||||
assert!(
|
||||
tokens > 18 && tokens < 30,
|
||||
"Expected ~23 tokens, got {}",
|
||||
tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -733,7 +778,8 @@ mod tests {
|
||||
tool_result_trim_chars: 50,
|
||||
..Default::default()
|
||||
};
|
||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||
let compressor =
|
||||
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||
|
||||
let mut messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
@ -774,7 +820,11 @@ mod tests {
|
||||
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
let result = compressor
|
||||
.compress_if_needed(messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.history;
|
||||
|
||||
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
||||
assert!(
|
||||
@ -798,7 +848,8 @@ mod tests {
|
||||
// - B2B (L275): last user message lost when it is the final history message
|
||||
//
|
||||
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||
let tmp =
|
||||
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||
|
||||
@ -828,15 +879,33 @@ mod tests {
|
||||
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
let result = compressor
|
||||
.compress_if_needed(messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.history;
|
||||
|
||||
// B2A: "Q1" must appear exactly once
|
||||
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
|
||||
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
|
||||
let q1_count = result
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" && m.content == "Q1")
|
||||
.count();
|
||||
assert_eq!(
|
||||
q1_count, 1,
|
||||
"Q1 should appear exactly once, got {}",
|
||||
q1_count
|
||||
);
|
||||
|
||||
// B2B: "Q4" must NOT be lost
|
||||
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
|
||||
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
|
||||
let q4_count = result
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" && m.content == "Q4")
|
||||
.count();
|
||||
assert_eq!(
|
||||
q4_count, 1,
|
||||
"Q4 should appear exactly once (not lost), got {}",
|
||||
q4_count
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_file(&tmp);
|
||||
}
|
||||
@ -872,13 +941,23 @@ mod tests {
|
||||
ChatMessage::tool("t3", "bash", &big),
|
||||
];
|
||||
|
||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
||||
let result = compressor
|
||||
.compress_if_needed(messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.history;
|
||||
|
||||
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
||||
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
|
||||
assert!(
|
||||
result.len() < 7,
|
||||
"expected truncation reduction, got {} messages",
|
||||
result.len()
|
||||
);
|
||||
|
||||
// Truncation notice should be present
|
||||
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
|
||||
let has_notice = result
|
||||
.iter()
|
||||
.any(|m| m.content.contains("Context truncation"));
|
||||
assert!(has_notice, "hard truncation notice missing");
|
||||
|
||||
let _ = std::fs::remove_file(&tmp);
|
||||
@ -910,8 +989,16 @@ mod tests {
|
||||
|
||||
// orphan should be removed; legitimate should stay
|
||||
assert_eq!(messages.len(), 4);
|
||||
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
|
||||
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.all(|m| m.tool_call_id != Some("tc1".into()))
|
||||
);
|
||||
assert!(
|
||||
messages
|
||||
.iter()
|
||||
.any(|m| m.tool_call_id == Some("tc2".into()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
|
||||
}
|
||||
|
||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
@ -4,10 +4,13 @@ pub mod media_handler;
|
||||
pub mod sub_agent;
|
||||
pub mod system_prompt;
|
||||
|
||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
|
||||
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
||||
pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus};
|
||||
pub use system_prompt::{
|
||||
build_system_prompt, build_sub_agent_system_prompt, PromptContext, PromptSection,
|
||||
SystemPromptBuilder,
|
||||
pub use sub_agent::{
|
||||
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
|
||||
TaskNotification, TaskStatus,
|
||||
};
|
||||
pub use system_prompt::{
|
||||
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
|
||||
build_system_prompt,
|
||||
};
|
||||
|
||||
@ -6,12 +6,12 @@ use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::agent::system_prompt::build_sub_agent_system_prompt;
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::agent::AgentError;
|
||||
use crate::agent::AgentLoop;
|
||||
use crate::agent::system_prompt::build_sub_agent_system_prompt;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{create_provider, LLMProvider};
|
||||
use crate::providers::{LLMProvider, create_provider};
|
||||
use crate::skills::SkillsLoader;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
@ -21,7 +21,8 @@ tokio::task_local! {
|
||||
|
||||
/// Read the delegate context from the current task. Returns an error if not set.
|
||||
pub fn get_delegate_context() -> Result<DelegateContext, String> {
|
||||
DELEGATE_CONTEXT.try_with(|ctx| ctx.clone())
|
||||
DELEGATE_CONTEXT
|
||||
.try_with(|ctx| ctx.clone())
|
||||
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
|
||||
}
|
||||
|
||||
@ -207,7 +208,10 @@ impl SubAgentManager {
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||
let timeout_human = format_duration(timeout_secs);
|
||||
let http_get_only = config.allowed_tools.is_none()
|
||||
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||
|| config
|
||||
.allowed_tools
|
||||
.as_ref()
|
||||
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||
let skills_prompt = self.get_skills_prompt(&tools);
|
||||
let system_prompt = build_sub_agent_system_prompt(
|
||||
&config.prompt,
|
||||
@ -219,7 +223,8 @@ impl SubAgentManager {
|
||||
http_get_only,
|
||||
);
|
||||
|
||||
let agent = self.build_sub_agent(&config, tools)
|
||||
let agent = self
|
||||
.build_sub_agent(&config, tools)
|
||||
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
|
||||
|
||||
let history = vec![
|
||||
@ -241,10 +246,14 @@ impl SubAgentManager {
|
||||
Ok(Ok(agent_result)) => {
|
||||
let (content, truncated) =
|
||||
truncate_sub_agent_result(&agent_result.final_response.content);
|
||||
let tool_calls_count = agent_result.emitted_messages.iter()
|
||||
let tool_calls_count = agent_result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|m| m.tool_calls.is_some())
|
||||
.count();
|
||||
let iterations = agent_result.emitted_messages.iter()
|
||||
let iterations = agent_result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
|
||||
.count();
|
||||
Ok(SubAgentResult {
|
||||
@ -343,7 +352,10 @@ impl SubAgentManager {
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||
let timeout_human = format_duration(timeout_secs);
|
||||
let http_get_only = config.allowed_tools.is_none()
|
||||
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||
|| config
|
||||
.allowed_tools
|
||||
.as_ref()
|
||||
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||
let skills_prompt = self.get_skills_prompt(&tools);
|
||||
let system_prompt = build_sub_agent_system_prompt(
|
||||
&config.prompt,
|
||||
@ -372,8 +384,12 @@ impl SubAgentManager {
|
||||
if let Some(ref s) = storage {
|
||||
let _ = s
|
||||
.update_background_task_status(
|
||||
&tid, "running", None, None,
|
||||
Some(started_at), None,
|
||||
&tid,
|
||||
"running",
|
||||
None,
|
||||
None,
|
||||
Some(started_at),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -384,8 +400,7 @@ impl SubAgentManager {
|
||||
p.set_storage(s.clone());
|
||||
}
|
||||
}
|
||||
let provider_result: Option<Arc<dyn LLMProvider>> =
|
||||
provider.map(|p| Arc::from(p));
|
||||
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
|
||||
|
||||
let result = match provider_result {
|
||||
Some(provider) => {
|
||||
@ -474,9 +489,12 @@ impl SubAgentManager {
|
||||
if let Some(ref s) = storage {
|
||||
let _ = s
|
||||
.update_background_task_status(
|
||||
&tid, &status_str,
|
||||
Some(&result.content), error_val.as_deref(),
|
||||
Some(started_at), Some(finished_at),
|
||||
&tid,
|
||||
&status_str,
|
||||
Some(&result.content),
|
||||
error_val.as_deref(),
|
||||
Some(started_at),
|
||||
Some(finished_at),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -514,15 +532,13 @@ impl SubAgentManager {
|
||||
Ok(true)
|
||||
} else if let Some(ref s) = self.storage {
|
||||
match s.get_background_task(task_id).await {
|
||||
Ok(task) => {
|
||||
match task.status.as_str() {
|
||||
Ok(task) => match task.status.as_str() {
|
||||
"pending" | "running" => {
|
||||
tracing::warn!(task_id, "task in DB but not in active_tasks");
|
||||
Ok(false)
|
||||
}
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
} else {
|
||||
@ -530,10 +546,7 @@ impl SubAgentManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
) -> Option<crate::storage::BackgroundTask> {
|
||||
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
|
||||
if let Some(ref s) = self.storage {
|
||||
s.get_background_task(task_id).await.ok()
|
||||
} else {
|
||||
@ -541,12 +554,11 @@ impl SubAgentManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_tasks(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Vec<crate::storage::BackgroundTask> {
|
||||
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
|
||||
if let Some(ref s) = self.storage {
|
||||
s.list_background_tasks(session_id).await.unwrap_or_default()
|
||||
s.list_background_tasks(session_id)
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
|
||||
@ -465,7 +465,9 @@ impl PromptSection for SubAgentToolsSection {
|
||||
let mut s = String::from("## 可用工具\n\n");
|
||||
s.push_str(&ctx.tools.describe_for_prompt());
|
||||
if self.http_get_only {
|
||||
s.push_str("\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。");
|
||||
s.push_str(
|
||||
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
|
||||
);
|
||||
}
|
||||
s
|
||||
}
|
||||
@ -560,12 +562,7 @@ pub fn build_sub_agent_system_prompt(
|
||||
memory_context: None,
|
||||
has_compressed_history: false,
|
||||
};
|
||||
SystemPromptBuilder::with_sub_agent_defaults(
|
||||
task,
|
||||
timeout_human,
|
||||
skills_prompt,
|
||||
http_get_only,
|
||||
)
|
||||
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
|
||||
.build(&ctx)
|
||||
}
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::ChannelManager;
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
|
||||
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
||||
/// and dispatches them to the appropriate Channel
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::providers::ToolCall;
|
||||
|
||||
@ -23,7 +23,9 @@ pub struct ImageUrlBlock {
|
||||
|
||||
impl ContentBlock {
|
||||
pub fn text(content: impl Into<String>) -> Self {
|
||||
Self::Text { text: content.into() }
|
||||
Self::Text {
|
||||
text: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn image_url(url: impl Into<String>) -> Self {
|
||||
@ -161,7 +163,10 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||
pub fn assistant_with_tool_calls(
|
||||
content: impl Into<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
@ -206,7 +211,11 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
pub fn tool(
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "tool".to_string(),
|
||||
|
||||
@ -2,10 +2,13 @@ pub mod dispatcher;
|
||||
pub mod message;
|
||||
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
|
||||
pub use message::{
|
||||
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
|
||||
OutboundMessage, SourceKind,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
|
||||
// ============================================================================
|
||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||
@ -49,7 +52,8 @@ impl MessageBus {
|
||||
|
||||
/// Consume an inbound message (Agent -> Bus)
|
||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||
let msg = self.inbound_rx
|
||||
let msg = self
|
||||
.inbound_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -24,7 +24,10 @@ impl ChannelManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
|
||||
pub fn with_bus(
|
||||
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
|
||||
bus: Arc<MessageBus>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
cli_chat_channel,
|
||||
@ -39,7 +42,10 @@ impl ChannelManager {
|
||||
|
||||
/// Register a channel with the manager
|
||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||
self.channels.write().await.insert(name.to_string(), channel);
|
||||
self.channels
|
||||
.write()
|
||||
.await
|
||||
.insert(name.to_string(), channel);
|
||||
}
|
||||
|
||||
/// Get CLI chat channel
|
||||
@ -56,14 +62,19 @@ impl ChannelManager {
|
||||
// Initialize Feishu channel if enabled
|
||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||
if feishu_config.enabled {
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||
let channel =
|
||||
FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
|
||||
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
||||
})?;
|
||||
|
||||
self.channels
|
||||
.write()
|
||||
.await
|
||||
.insert("feishu".to_string(), Arc::new(channel));
|
||||
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display());
|
||||
tracing::info!(
|
||||
"Feishu channel registered (media_dir: {}/media/feishu)",
|
||||
workspace_dir.display()
|
||||
);
|
||||
} else {
|
||||
tracing::info!("Feishu channel disabled in config");
|
||||
}
|
||||
@ -118,7 +129,10 @@ impl ChannelManager {
|
||||
if let Some(channel) = self.get_channel(channel_name).await {
|
||||
channel.send(msg).await
|
||||
} else {
|
||||
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
||||
Err(ChannelError::Other(format!(
|
||||
"Channel not found: {}",
|
||||
channel_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
pub mod base;
|
||||
pub mod feishu;
|
||||
pub mod cli_chat;
|
||||
pub mod feishu;
|
||||
pub mod manager;
|
||||
pub mod slash_command;
|
||||
|
||||
pub use base::{Channel, ChannelError};
|
||||
pub use manager::ChannelManager;
|
||||
pub use feishu::FeishuChannel;
|
||||
pub use cli_chat::CliChatChannel;
|
||||
pub use slash_command::{parse_slash_command, command_matches};
|
||||
pub use feishu::FeishuChannel;
|
||||
pub use manager::ChannelManager;
|
||||
pub use slash_command::{command_matches, parse_slash_command};
|
||||
|
||||
@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
|
||||
/// 检查内容是否匹配指定命令
|
||||
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
|
||||
let trimmed = content.trim();
|
||||
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
||||
aliases
|
||||
.iter()
|
||||
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -27,7 +29,10 @@ mod tests {
|
||||
fn test_parse_slash_command() {
|
||||
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
|
||||
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
|
||||
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world")));
|
||||
assert_eq!(
|
||||
parse_slash_command("/new hello world"),
|
||||
Some(("new", "hello world"))
|
||||
);
|
||||
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
|
||||
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
|
||||
assert_eq!(parse_slash_command("/?"), Some(("?", "")));
|
||||
|
||||
@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
|
||||
use crossterm::{
|
||||
event::{self, Event},
|
||||
execute,
|
||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use ratatui::{prelude::CrosstermBackend, Terminal};
|
||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||
use std::io;
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
|
||||
WsOutbound::SessionCreated { session_id, .. } => {
|
||||
app.set_current_session(Some(session_id));
|
||||
}
|
||||
WsOutbound::SessionList { sessions, current_session_id } => {
|
||||
WsOutbound::SessionList {
|
||||
sessions,
|
||||
current_session_id,
|
||||
} => {
|
||||
app.set_sessions(sessions);
|
||||
if let Some(id) = current_session_id {
|
||||
app.set_current_session(Some(id));
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use crate::client::tui::app::{App, MessageRole};
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
text::Line,
|
||||
widgets::{Block, Borders, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
widgets::{Block, Borders, Clear, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect) {
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Style},
|
||||
widgets::{Block, Borders, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
widgets::{Block, Borders, List, ListItem},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
.sessions
|
||||
.iter()
|
||||
.map(|session| {
|
||||
let is_current = app
|
||||
.current_session_id
|
||||
.as_ref() == Some(&session.session_id);
|
||||
let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
|
||||
let archived = session.archived_at.is_some();
|
||||
|
||||
let mut content = if is_current {
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
use crate::client::tui::app::App;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
widgets::{Block, Borders, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||
let (title, style) = if app.pending_quit {
|
||||
let msg = if let Some(session_id) = &app.current_session_id {
|
||||
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id)
|
||||
format!(
|
||||
"PicoBot | Session: {} | Press Ctrl+C again to quit",
|
||||
session_id
|
||||
)
|
||||
} else {
|
||||
"PicoBot | Press Ctrl+C again to quit".to_string()
|
||||
};
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::client::tui::app::{App, MessageRole};
|
||||
use crate::protocol::serialize_inbound;
|
||||
use crate::protocol::WsInbound;
|
||||
use crate::protocol::serialize_inbound;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use futures_util::SinkExt;
|
||||
|
||||
@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
|
||||
|
||||
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||
// Handle Ctrl+C for quit (double press to exit)
|
||||
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL);
|
||||
let is_ctrl_c = key.code == KeyCode::Char('c')
|
||||
&& key
|
||||
.modifiers
|
||||
.contains(crossterm::event::KeyModifiers::CONTROL);
|
||||
if is_ctrl_c {
|
||||
if app.handle_ctrl_c_for_quit() {
|
||||
return;
|
||||
@ -65,7 +68,9 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||
app.input_insert_char(c);
|
||||
|
||||
// Show command menu when input starts with /
|
||||
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) {
|
||||
if !app.show_command_menu
|
||||
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
|
||||
{
|
||||
app.show_command_menu = true;
|
||||
app.selected_command_idx = 0;
|
||||
} else if app.show_command_menu && !app.input.starts_with('/') {
|
||||
@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) {
|
||||
sender_id: None,
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await;
|
||||
let _ = sender
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
use crate::client::tui::app::App;
|
||||
use crate::client::tui::components::*;
|
||||
use ratatui::{
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
Frame,
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
};
|
||||
|
||||
pub fn render_ui(f: &mut Frame, app: &App) {
|
||||
|
||||
@ -273,12 +273,16 @@ impl Default for MemoryConfig {
|
||||
impl MemoryConfig {
|
||||
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
||||
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
||||
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
|
||||
self.consolidation_provider
|
||||
.clone()
|
||||
.unwrap_or_else(|| default.to_string())
|
||||
}
|
||||
|
||||
/// Resolve consolidation model name, falling back to the main agent's model.
|
||||
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
||||
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
|
||||
self.consolidation_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| default.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@ -366,10 +370,18 @@ impl Default for BrowserConfig {
|
||||
}
|
||||
}
|
||||
|
||||
fn default_recall_limit() -> usize { 5 }
|
||||
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
||||
fn default_timeline_retention_days() -> u64 { 90 }
|
||||
fn default_max_failures_before_degrade() -> usize { 3 }
|
||||
fn default_recall_limit() -> usize {
|
||||
5
|
||||
}
|
||||
fn default_idle_consolidation_minutes() -> u64 {
|
||||
10
|
||||
}
|
||||
fn default_timeline_retention_days() -> u64 {
|
||||
90
|
||||
}
|
||||
fn default_max_failures_before_degrade() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LLMProviderConfig {
|
||||
@ -469,7 +481,11 @@ pub enum ConfigError {
|
||||
impl std::fmt::Display for ConfigError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
||||
ConfigError::ConfigNotFound(path) => write!(
|
||||
f,
|
||||
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
|
||||
path
|
||||
),
|
||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
use std::sync::Arc;
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||
use super::GatewayState;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::protocol::serialize_outbound;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use crate::protocol::serialize_outbound;
|
||||
use crate::protocol::WsOutbound;
|
||||
use super::GatewayState;
|
||||
|
||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||
ws.on_upgrade(|socket| async move {
|
||||
@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
|
||||
|
||||
// Send session established message
|
||||
let _ = sender.send(WsOutbound::SessionEstablished {
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.clone(),
|
||||
}).await;
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::info!(session_id = %session_id, "CLI session established");
|
||||
|
||||
@ -37,7 +39,8 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
if let Ok(text) = serialize_outbound(&msg)
|
||||
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
||||
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
14
src/lib.rs
14
src/lib.rs
@ -1,17 +1,17 @@
|
||||
pub mod config;
|
||||
pub mod providers;
|
||||
pub mod bus;
|
||||
pub mod agent;
|
||||
pub mod gateway;
|
||||
pub mod session;
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
pub mod bus;
|
||||
pub mod channels;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod gateway;
|
||||
pub mod logging;
|
||||
pub mod mcp;
|
||||
pub mod memory;
|
||||
pub mod observability;
|
||||
pub mod protocol;
|
||||
pub mod providers;
|
||||
pub mod scheduler;
|
||||
pub mod session;
|
||||
pub mod skills;
|
||||
pub mod storage;
|
||||
pub mod tools;
|
||||
|
||||
@ -1,11 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||
use tracing_subscriber::{
|
||||
fmt,
|
||||
layer::SubscriberExt,
|
||||
util::SubscriberInitExt,
|
||||
fmt::time::LocalTime,
|
||||
EnvFilter,
|
||||
EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
|
||||
};
|
||||
|
||||
/// Get the default log directory path: ~/.picobot/logs
|
||||
@ -27,20 +23,20 @@ pub fn init_logging() {
|
||||
|
||||
// Create log directory if it doesn't exist
|
||||
if !log_dir.exists()
|
||||
&& let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
||||
&& let Err(e) = std::fs::create_dir_all(&log_dir)
|
||||
{
|
||||
eprintln!(
|
||||
"Warning: Failed to create log directory {}: {}",
|
||||
log_dir.display(),
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
// Create file appender with daily rotation
|
||||
let file_appender = RollingFileAppender::new(
|
||||
Rotation::DAILY,
|
||||
&log_dir,
|
||||
"picobot.log",
|
||||
);
|
||||
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
|
||||
|
||||
// Build subscriber with both console and file output
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let file_layer = fmt::layer()
|
||||
.with_writer(file_appender)
|
||||
@ -66,8 +62,7 @@ pub fn init_logging() {
|
||||
|
||||
/// Initialize logging without file output (console only)
|
||||
pub fn init_logging_console_only() {
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let console_layer = fmt::layer()
|
||||
.with_timer(LocalTime::rfc_3339())
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use clap::{Parser, CommandFactory};
|
||||
use clap::{CommandFactory, Parser};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "picobot")]
|
||||
|
||||
@ -92,13 +92,9 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
||||
parts.push(text.text.clone());
|
||||
}
|
||||
RawContent::Image(image) => {
|
||||
parts.push(format!(
|
||||
"[image: {}]",
|
||||
image.mime_type,
|
||||
));
|
||||
parts.push(format!("[image: {}]", image.mime_type,));
|
||||
}
|
||||
RawContent::Resource(resource) => {
|
||||
match &resource.resource {
|
||||
RawContent::Resource(resource) => match &resource.resource {
|
||||
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
||||
parts.push(format!(
|
||||
"[resource text: {}]",
|
||||
@ -108,8 +104,7 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
||||
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
||||
parts.push(format!("[resource blob: {}]", uri));
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
parts.push("[unsupported content]".to_string());
|
||||
}
|
||||
@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
||||
cmd.env(k, v);
|
||||
}
|
||||
|
||||
let service = ()
|
||||
.serve(
|
||||
let service =
|
||||
().serve(
|
||||
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
|
||||
)
|
||||
.await
|
||||
@ -261,12 +256,12 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
||||
} else {
|
||||
StreamableHttpClientTransport::from_config(
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string())
|
||||
.custom_headers(headers_map)
|
||||
.custom_headers(headers_map),
|
||||
)
|
||||
};
|
||||
|
||||
let service = ()
|
||||
.serve(transport)
|
||||
let service =
|
||||
().serve(transport)
|
||||
.await
|
||||
.context("failed to connect to HTTP/SSE MCP server")?;
|
||||
|
||||
|
||||
@ -102,7 +102,11 @@ mod tests {
|
||||
let dir = tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
|
||||
let mm = Arc::new(MemoryManager::new(
|
||||
storage,
|
||||
"default".into(),
|
||||
"default".into(),
|
||||
));
|
||||
(mm, dir)
|
||||
}
|
||||
|
||||
@ -131,13 +135,7 @@ mod tests {
|
||||
async fn test_upsert_overwrites() {
|
||||
let (mm, _dir) = setup_memory_manager().await;
|
||||
|
||||
mm.store(
|
||||
"dup_key",
|
||||
"original",
|
||||
MemoryCategory::Knowledge,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mm.store(
|
||||
@ -247,7 +245,12 @@ mod tests {
|
||||
|
||||
// Recall scoped to session A — should get only tl_a
|
||||
let scoped = mm
|
||||
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
|
||||
.recall(
|
||||
"summary",
|
||||
10,
|
||||
Some(MemoryCategory::Timeline),
|
||||
Some("chan:chat:dialog_a"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(scoped.len(), 1);
|
||||
|
||||
@ -20,10 +20,7 @@ pub enum ObserverEvent {
|
||||
success: bool,
|
||||
},
|
||||
/// Emitted when the agent starts processing.
|
||||
AgentStart {
|
||||
provider: String,
|
||||
model: String,
|
||||
},
|
||||
AgentStart { provider: String, model: String },
|
||||
/// Emitted when the agent finishes processing.
|
||||
AgentEnd {
|
||||
provider: String,
|
||||
@ -94,7 +91,11 @@ impl ToolExecutionOutcome {
|
||||
}
|
||||
|
||||
/// Create a failed outcome with duration.
|
||||
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
||||
pub fn failure_with_duration(
|
||||
output: String,
|
||||
error_reason: Option<String>,
|
||||
duration: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
output,
|
||||
success: false,
|
||||
|
||||
@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use super::traits::Usage;
|
||||
use std::sync::Arc;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::storage::Storage;
|
||||
use std::sync::Arc;
|
||||
|
||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
blocks.iter().map(|b| match b {
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
ContentBlock::Text { text } => {
|
||||
serde_json::json!({ "type": "text", "text": text })
|
||||
}
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
convert_image_url_to_anthropic(&image_url.url)
|
||||
}
|
||||
}).collect()
|
||||
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||
@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider {
|
||||
};
|
||||
let content = if let Some(ref tc_id) = m.tool_call_id {
|
||||
// Tool result: wrap as tool_result content block
|
||||
let output = m.content.iter()
|
||||
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
|
||||
let output = m
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|b| match b {
|
||||
ContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
vec![serde_json::json!({
|
||||
@ -244,8 +250,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||
|
||||
let resp = req_builder.json(&body).send().await
|
||||
.inspect_err(|e| {
|
||||
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
||||
let is_timeout = e.is_timeout();
|
||||
tracing::error!(
|
||||
provider = %self.name,
|
||||
@ -281,17 +286,21 @@ impl LLMProvider for AnthropicProvider {
|
||||
"LLM API returned error"
|
||||
);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&body_text), Some(&error_msg),
|
||||
let _ = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&body_text),
|
||||
Some(&error_msg),
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||
}
|
||||
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
||||
.map_err(|e| {
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
@ -302,7 +311,9 @@ impl LLMProvider for AnthropicProvider {
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
||||
let _ = s
|
||||
.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider {
|
||||
reasoning_content: reasoning,
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
||||
prompt_tokens: anthropic_resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.input_tokens)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: anthropic_resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.output_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: anthropic_resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.input_tokens + u.output_tokens)
|
||||
.unwrap_or(0),
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
let _ = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&body_text),
|
||||
None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
pub mod traits;
|
||||
pub mod openai;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod traits;
|
||||
|
||||
pub use self::openai::OpenAIProvider;
|
||||
pub use self::anthropic::AnthropicProvider;
|
||||
pub use self::openai::OpenAIProvider;
|
||||
|
||||
use crate::config::LLMProviderConfig;
|
||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
||||
pub use traits::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
|
||||
ToolFunction, Usage,
|
||||
};
|
||||
|
||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||
match config.provider_type.as_str() {
|
||||
|
||||
@ -1,29 +1,35 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use super::traits::Usage;
|
||||
use std::sync::Arc;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::storage::Storage;
|
||||
use std::sync::Arc;
|
||||
|
||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
if blocks.len() == 1
|
||||
&& let ContentBlock::Text { text } = &blocks[0] {
|
||||
&& let ContentBlock::Text { text } = &blocks[0]
|
||||
{
|
||||
return Value::String(text.clone());
|
||||
}
|
||||
Value::Array(blocks.iter().map(|b| match b {
|
||||
Value::Array(
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||
}
|
||||
}).collect())
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
pub struct OpenAIProvider {
|
||||
@ -201,7 +207,11 @@ impl LLMProvider for OpenAIProvider {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||
for (j, item) in content.iter().enumerate() {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
|
||||
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
||||
&& let Some(url_str) = item
|
||||
.get("image_url")
|
||||
.and_then(|u| u.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
let prefix: String = url_str.chars().take(20).collect();
|
||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||
}
|
||||
@ -224,8 +234,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||
|
||||
let resp = req_builder.json(&body).send().await
|
||||
.inspect_err(|e| {
|
||||
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
||||
let is_timeout = e.is_timeout();
|
||||
tracing::error!(
|
||||
provider = %self.name,
|
||||
@ -253,18 +262,23 @@ impl LLMProvider for OpenAIProvider {
|
||||
"LLM API returned error"
|
||||
);
|
||||
if let Some(ref storage) = self.storage
|
||||
&& let Err(e) = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&text), Some(&error),
|
||||
&& let Err(e) = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&text),
|
||||
Some(&error),
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("failed to persist LLM call: {}", e);
|
||||
}
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
||||
.map_err(|e| {
|
||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
@ -275,7 +289,10 @@ impl LLMProvider for OpenAIProvider {
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
|
||||
if let Err(e) = s
|
||||
.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
||||
}
|
||||
});
|
||||
@ -283,7 +300,10 @@ impl LLMProvider for OpenAIProvider {
|
||||
err_msg
|
||||
})?;
|
||||
|
||||
let first_choice = openai_resp.choices.into_iter().next()
|
||||
let first_choice = openai_resp
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or("no choices in response")?;
|
||||
|
||||
let content = first_choice
|
||||
@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider {
|
||||
.map(|tc| ToolCall {
|
||||
id: tc.id.clone(),
|
||||
name: tc.function.name.clone(),
|
||||
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
||||
arguments: serde_json::from_str(&tc.function.arguments)
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@ -318,11 +339,17 @@ impl LLMProvider for OpenAIProvider {
|
||||
};
|
||||
|
||||
if let Some(ref storage) = self.storage
|
||||
&& let Err(e) = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&text), None,
|
||||
&& let Err(e) = storage
|
||||
.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&text),
|
||||
None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("failed to persist LLM call: {}", e);
|
||||
}
|
||||
|
||||
@ -386,6 +413,9 @@ mod tests {
|
||||
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||
assert_eq!(tool_calls[0]["type"], "function");
|
||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||
assert_eq!(
|
||||
tool_calls[0]["function"]["arguments"],
|
||||
"{\"expression\":\"1+1\"}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::bus::message::ContentBlock;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::bus::message::ContentBlock;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
@ -61,7 +61,11 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
pub fn tool(
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
role: "tool".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
|
||||
@ -5,11 +5,11 @@ use std::time::Instant;
|
||||
use tokio::time;
|
||||
|
||||
use crate::config::SchedulerConfig;
|
||||
use crate::session::session::HandleResult;
|
||||
use crate::session::SessionManager;
|
||||
use crate::session::session::HandleResult;
|
||||
use crate::storage::JobRun;
|
||||
use crate::storage::ScheduledJob;
|
||||
use crate::storage::Storage;
|
||||
use crate::storage::JobRun;
|
||||
|
||||
pub use types::Schedule;
|
||||
|
||||
@ -89,7 +89,11 @@ impl Scheduler {
|
||||
|
||||
let now = now_ms();
|
||||
|
||||
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
|
||||
let due = match self
|
||||
.storage
|
||||
.due_scheduled_jobs(now, self.config.max_concurrent)
|
||||
.await
|
||||
{
|
||||
Ok(jobs) => jobs,
|
||||
Err(e) => {
|
||||
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
||||
@ -107,7 +111,11 @@ impl Scheduler {
|
||||
let start = Instant::now();
|
||||
let started_at = now_ms();
|
||||
|
||||
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
|
||||
if let Err(e) = self
|
||||
.storage
|
||||
.touch_scheduled_job_last_run(&job.id, started_at)
|
||||
.await
|
||||
{
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
||||
continue;
|
||||
}
|
||||
@ -135,7 +143,10 @@ impl Scheduler {
|
||||
match result {
|
||||
Ok(HandleResult::AgentResponse(output)) => {
|
||||
let output_truncated = if output.len() > 8000 {
|
||||
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
||||
format!(
|
||||
"{}...[truncated]",
|
||||
&output[..output.ceil_char_boundary(8000)]
|
||||
)
|
||||
} else {
|
||||
output.clone()
|
||||
};
|
||||
@ -155,7 +166,11 @@ impl Scheduler {
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
|
||||
if let Err(e) = self
|
||||
.storage
|
||||
.set_scheduled_job_last_status(&job.id, "ok", None)
|
||||
.await
|
||||
{
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
||||
}
|
||||
|
||||
@ -199,9 +214,11 @@ impl Scheduler {
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
||||
}
|
||||
|
||||
if let Err(e2) = self.storage.set_scheduled_job_last_status(
|
||||
&job.id, "error", Some(&error_str),
|
||||
).await {
|
||||
if let Err(e2) = self
|
||||
.storage
|
||||
.set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
|
||||
.await
|
||||
{
|
||||
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
||||
}
|
||||
|
||||
@ -231,17 +248,23 @@ impl Scheduler {
|
||||
self.storage.remove_scheduled_job(&job.id).await?;
|
||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
||||
} else {
|
||||
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job.id, false)
|
||||
.await?;
|
||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
||||
}
|
||||
}
|
||||
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
||||
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
||||
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_next_run(&job.id, next)
|
||||
.await?;
|
||||
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
||||
} else {
|
||||
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
||||
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job.id, false)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,32 +22,20 @@ pub enum SessionCommand {
|
||||
dialog_id: String,
|
||||
},
|
||||
/// Get the current dialog for a chat
|
||||
GetCurrentDialog {
|
||||
channel: String,
|
||||
chat_id: String,
|
||||
},
|
||||
GetCurrentDialog { channel: String, chat_id: String },
|
||||
/// Rename a dialog
|
||||
RenameDialog {
|
||||
session_id: UnifiedSessionId,
|
||||
title: String,
|
||||
},
|
||||
/// Archive a dialog
|
||||
ArchiveDialog {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
ArchiveDialog { session_id: UnifiedSessionId },
|
||||
/// Delete a dialog
|
||||
DeleteDialog {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DeleteDialog { session_id: UnifiedSessionId },
|
||||
/// Clear dialog history
|
||||
ClearHistory {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
ClearHistory { session_id: UnifiedSessionId },
|
||||
/// Get list of available slash commands
|
||||
GetSlashCommands {
|
||||
channel: String,
|
||||
chat_id: String,
|
||||
},
|
||||
GetSlashCommands { channel: String, chat_id: String },
|
||||
/// Execute a slash command
|
||||
ExecuteSlashCommand {
|
||||
command: String,
|
||||
@ -60,7 +48,11 @@ pub enum SessionCommand {
|
||||
|
||||
impl SessionCommand {
|
||||
/// Create a CreateDialog command
|
||||
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self {
|
||||
pub fn create_dialog(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
title: Option<String>,
|
||||
) -> Self {
|
||||
Self::CreateDialog {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
@ -69,7 +61,11 @@ impl SessionCommand {
|
||||
}
|
||||
|
||||
/// Create a ListDialogs command
|
||||
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self {
|
||||
pub fn list_dialogs(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
include_archived: bool,
|
||||
) -> Self {
|
||||
Self::ListDialogs {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use super::session_id::UnifiedSessionId;
|
||||
use super::session::SlashCommand;
|
||||
use super::session_id::UnifiedSessionId;
|
||||
|
||||
/// Dialog information returned by SessionManager
|
||||
#[derive(Debug, Clone)]
|
||||
@ -30,30 +30,20 @@ pub enum SessionEvent {
|
||||
session_id: Option<UnifiedSessionId>,
|
||||
},
|
||||
/// Dialog switched successfully
|
||||
DialogSwitched {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DialogSwitched { session_id: UnifiedSessionId },
|
||||
/// Dialog renamed
|
||||
DialogRenamed {
|
||||
session_id: UnifiedSessionId,
|
||||
title: String,
|
||||
},
|
||||
/// Dialog archived
|
||||
DialogArchived {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DialogArchived { session_id: UnifiedSessionId },
|
||||
/// Dialog deleted
|
||||
DialogDeleted {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
DialogDeleted { session_id: UnifiedSessionId },
|
||||
/// Dialog history cleared
|
||||
HistoryCleared {
|
||||
session_id: UnifiedSessionId,
|
||||
},
|
||||
HistoryCleared { session_id: UnifiedSessionId },
|
||||
/// List of available slash commands
|
||||
SlashCommandsList {
|
||||
commands: Vec<SlashCommand>,
|
||||
},
|
||||
SlashCommandsList { commands: Vec<SlashCommand> },
|
||||
/// Slash command executed successfully
|
||||
SlashCommandExecuted {
|
||||
new_session_id: Option<UnifiedSessionId>,
|
||||
@ -70,8 +60,5 @@ pub enum SessionEvent {
|
||||
message_count: usize,
|
||||
},
|
||||
/// Error occurred
|
||||
Error {
|
||||
code: String,
|
||||
message: String,
|
||||
},
|
||||
Error { code: String, message: String },
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
pub mod error;
|
||||
pub mod commands;
|
||||
pub mod error;
|
||||
pub mod events;
|
||||
pub mod session;
|
||||
pub mod session_id;
|
||||
|
||||
pub use error::SessionError;
|
||||
pub use commands::SessionCommand;
|
||||
pub use events::{SessionEvent, DialogInfo};
|
||||
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS};
|
||||
pub use error::SessionError;
|
||||
pub use events::{DialogInfo, SessionEvent};
|
||||
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
|
||||
pub use session_id::UnifiedSessionId;
|
||||
|
||||
@ -8,7 +8,6 @@
|
||||
///
|
||||
/// For simple cases where only one dialog exists per chat:
|
||||
/// - `dialog_id` defaults to `"default"`
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const DEFAULT_DIALOG_ID: &str = "default";
|
||||
@ -22,7 +21,11 @@ pub struct UnifiedSessionId {
|
||||
|
||||
impl UnifiedSessionId {
|
||||
/// Create a new UnifiedSessionId
|
||||
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self {
|
||||
pub fn new(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
dialog_id: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::path::Path;
|
||||
|
||||
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS};
|
||||
use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
|
||||
|
||||
pub fn install_builtin_skills(target_dir: &Path) {
|
||||
for skill in EMBEDDED_SKILLS {
|
||||
@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) {
|
||||
}
|
||||
|
||||
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
|
||||
let decompressed = zstd::decode_all(skill.data)
|
||||
.map_err(|e| format!("zstd decode: {}", e))?;
|
||||
let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
|
||||
|
||||
let mut archive = tar::Archive::new(decompressed.as_slice());
|
||||
archive
|
||||
|
||||
@ -120,7 +120,11 @@ impl SkillsLoader {
|
||||
let count = loaded.len();
|
||||
let mut replaced = 0usize;
|
||||
for skill in loaded {
|
||||
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
||||
if let Some(existing) = state
|
||||
.loaded_skills
|
||||
.iter_mut()
|
||||
.find(|s| s.name == skill.name)
|
||||
{
|
||||
*existing = skill;
|
||||
replaced += 1;
|
||||
} else {
|
||||
@ -138,12 +142,17 @@ impl SkillsLoader {
|
||||
|
||||
// Load from workspace skills dir (highest priority) — replace same-name skills
|
||||
if let Some(ref ws_dir) = self.workspace_skills_dir
|
||||
&& ws_dir.exists() {
|
||||
&& ws_dir.exists()
|
||||
{
|
||||
let loaded = self.load_skills_from_dir(ws_dir);
|
||||
let count = loaded.len();
|
||||
let mut replaced = 0usize;
|
||||
for skill in loaded {
|
||||
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
||||
if let Some(existing) = state
|
||||
.loaded_skills
|
||||
.iter_mut()
|
||||
.find(|s| s.name == skill.name)
|
||||
{
|
||||
*existing = skill;
|
||||
replaced += 1;
|
||||
} else {
|
||||
@ -164,7 +173,11 @@ impl SkillsLoader {
|
||||
if state.loaded_skills.is_empty() {
|
||||
tracing::debug!("No skills found in any skills directory");
|
||||
} else {
|
||||
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
|
||||
tracing::info!(
|
||||
count = state.loaded_skills.len(),
|
||||
"Loaded {} skills total",
|
||||
state.loaded_skills.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,7 +228,8 @@ impl SkillsLoader {
|
||||
let mut max_mtime = None;
|
||||
|
||||
if let Ok(metadata) = std::fs::metadata(dir)
|
||||
&& let Ok(mtime) = metadata.modified() {
|
||||
&& let Ok(mtime) = metadata.modified()
|
||||
{
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
|
||||
@ -224,7 +238,8 @@ impl SkillsLoader {
|
||||
let path = entry.path();
|
||||
if let Ok(metadata) = std::fs::metadata(&path)
|
||||
&& let Ok(mtime) = metadata.modified()
|
||||
&& max_mtime.is_none_or(|current| mtime > current) {
|
||||
&& max_mtime.is_none_or(|current| mtime > current)
|
||||
{
|
||||
max_mtime = Some(mtime);
|
||||
}
|
||||
}
|
||||
@ -244,7 +259,12 @@ impl SkillsLoader {
|
||||
pub fn get_always_skills(&self) -> Vec<Skill> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
|
||||
state
|
||||
.loaded_skills
|
||||
.iter()
|
||||
.filter(|s| s.always)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a specific skill by name (checks for changes first)
|
||||
@ -258,7 +278,8 @@ impl SkillsLoader {
|
||||
pub fn list_skills(&self) -> Vec<(String, String)> {
|
||||
self.reload_if_changed();
|
||||
let state = self.state.lock().unwrap();
|
||||
state.loaded_skills
|
||||
state
|
||||
.loaded_skills
|
||||
.iter()
|
||||
.map(|s| (s.name.clone(), s.description.clone()))
|
||||
.collect()
|
||||
@ -279,15 +300,21 @@ impl SkillsLoader {
|
||||
prompt.push_str("### 目录说明\n\n");
|
||||
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n");
|
||||
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
|
||||
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n");
|
||||
prompt.push_str("安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n");
|
||||
prompt.push_str(
|
||||
"- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n",
|
||||
);
|
||||
prompt.push_str(
|
||||
"安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n",
|
||||
);
|
||||
|
||||
// Always skills summary
|
||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
||||
if !always_skills.is_empty() {
|
||||
prompt.push_str("### 常用技能\n\n");
|
||||
for skill in &always_skills {
|
||||
let path_str = skill.path.as_ref()
|
||||
let path_str = skill
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "—".to_string());
|
||||
prompt.push_str(&format!(
|
||||
@ -300,8 +327,12 @@ impl SkillsLoader {
|
||||
|
||||
// Usage instructions
|
||||
prompt.push_str("### 使用方法\n\n");
|
||||
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n");
|
||||
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n");
|
||||
prompt.push_str(
|
||||
"- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
|
||||
);
|
||||
prompt.push_str(
|
||||
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
|
||||
);
|
||||
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
|
||||
|
||||
// Always skills full content
|
||||
@ -338,8 +369,7 @@ impl SkillsLoader {
|
||||
}
|
||||
|
||||
match std::fs::read_to_string(&skill_file) {
|
||||
Ok(content) => {
|
||||
match self.parse_skill(&path, &content) {
|
||||
Ok(content) => match self.parse_skill(&path, &content) {
|
||||
Some(skill) => {
|
||||
tracing::debug!(
|
||||
skill = %skill.name,
|
||||
@ -355,8 +385,7 @@ impl SkillsLoader {
|
||||
"Failed to parse skill"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %skill_file.display(),
|
||||
@ -447,7 +476,6 @@ impl Default for SkillsLoader {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Extract first non-empty, non-heading line as description
|
||||
fn extract_description(content: &str) -> String {
|
||||
content
|
||||
|
||||
@ -241,9 +241,8 @@ impl super::Storage {
|
||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
||||
let cutoff_str = cutoff.to_rfc3339();
|
||||
|
||||
let result = sqlx::query(
|
||||
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
|
||||
)
|
||||
let result =
|
||||
sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
|
||||
.bind(&cutoff_str)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
@ -276,9 +275,7 @@ impl super::Storage {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_memory_rows(
|
||||
rows: &[sqlx::sqlite::SqliteRow],
|
||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||
rows.iter()
|
||||
.map(|row| {
|
||||
Ok(MemoryEntry {
|
||||
|
||||
@ -165,7 +165,11 @@ impl crate::storage::Storage {
|
||||
}
|
||||
|
||||
/// Update next_run_at and last_run_at for a job.
|
||||
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
|
||||
pub async fn set_scheduled_job_next_run(
|
||||
&self,
|
||||
id: &str,
|
||||
next_run_at: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
let now = now_ms();
|
||||
sqlx::query(
|
||||
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
||||
@ -331,7 +335,9 @@ mod tests {
|
||||
async fn setup_storage() -> Storage {
|
||||
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||
let storage = Storage { pool };
|
||||
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
|
||||
Storage::init_scheduler_schema(storage.pool())
|
||||
.await
|
||||
.unwrap();
|
||||
storage
|
||||
}
|
||||
|
||||
@ -450,7 +456,10 @@ mod tests {
|
||||
updated_at: t,
|
||||
};
|
||||
storage.add_scheduled_job(&job).await.unwrap();
|
||||
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
|
||||
storage
|
||||
.set_scheduled_job_enabled("job-toggle", false)
|
||||
.await
|
||||
.unwrap();
|
||||
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
||||
assert!(!got.enabled);
|
||||
}
|
||||
@ -461,31 +470,55 @@ mod tests {
|
||||
let t = now();
|
||||
let jobs = vec![
|
||||
ScheduledJob {
|
||||
id: "due".into(), name: "due".into(),
|
||||
schedule: Schedule::At { at: t }, prompt: "1".into(),
|
||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: true, delete_after_run: false,
|
||||
next_run_at: t - 1000, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
id: "due".into(),
|
||||
name: "due".into(),
|
||||
schedule: Schedule::At { at: t },
|
||||
prompt: "1".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t - 1000,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
},
|
||||
ScheduledJob {
|
||||
id: "future".into(), name: "future".into(),
|
||||
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
|
||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: true, delete_after_run: false,
|
||||
next_run_at: t + 99999999, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
id: "future".into(),
|
||||
name: "future".into(),
|
||||
schedule: Schedule::At { at: t + 99999999 },
|
||||
prompt: "2".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t + 99999999,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
},
|
||||
ScheduledJob {
|
||||
id: "disabled-due".into(), name: "disabled due".into(),
|
||||
schedule: Schedule::At { at: t }, prompt: "3".into(),
|
||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: false, delete_after_run: false,
|
||||
next_run_at: t - 1000, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
id: "disabled-due".into(),
|
||||
name: "disabled due".into(),
|
||||
schedule: Schedule::At { at: t },
|
||||
prompt: "3".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: false,
|
||||
delete_after_run: false,
|
||||
next_run_at: t - 1000,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
},
|
||||
];
|
||||
for j in &jobs {
|
||||
@ -501,24 +534,39 @@ mod tests {
|
||||
let storage = setup_storage().await;
|
||||
let t = now();
|
||||
let job = ScheduledJob {
|
||||
id: "job-run".into(), name: "run test".into(),
|
||||
id: "job-run".into(),
|
||||
name: "run test".into(),
|
||||
schedule: Schedule::Every { every_ms: 1000 },
|
||||
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
|
||||
model: None, enabled: true, delete_after_run: false,
|
||||
next_run_at: t, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
prompt: "hi".into(),
|
||||
channel: "cli_chat".into(),
|
||||
chat_id: "c".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
};
|
||||
storage.add_scheduled_job(&job).await.unwrap();
|
||||
|
||||
let run = super::JobRun {
|
||||
id: 0, job_id: "job-run".into(),
|
||||
started_at: t, finished_at: t + 500,
|
||||
status: "ok".into(), output: Some("hello".into()),
|
||||
error: None, duration_ms: 500,
|
||||
id: 0,
|
||||
job_id: "job-run".into(),
|
||||
started_at: t,
|
||||
finished_at: t + 500,
|
||||
status: "ok".into(),
|
||||
output: Some("hello".into()),
|
||||
error: None,
|
||||
duration_ms: 500,
|
||||
};
|
||||
storage.record_scheduled_job_run(&run).await.unwrap();
|
||||
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
|
||||
let runs = storage
|
||||
.list_scheduled_job_runs("job-run", 10)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(runs.len(), 1);
|
||||
assert_eq!(runs[0].status, "ok");
|
||||
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
||||
@ -529,22 +577,34 @@ mod tests {
|
||||
let storage = setup_storage().await;
|
||||
let t = now();
|
||||
let job = ScheduledJob {
|
||||
id: "job-update".into(), name: "old name".into(),
|
||||
id: "job-update".into(),
|
||||
name: "old name".into(),
|
||||
schedule: Schedule::Every { every_ms: 1000 },
|
||||
prompt: "old prompt".into(), channel: "feishu".into(),
|
||||
chat_id: "oc_1".into(), model: None,
|
||||
enabled: true, delete_after_run: false,
|
||||
next_run_at: t, last_run_at: None,
|
||||
last_status: None, last_error: None,
|
||||
created_at: t, updated_at: t,
|
||||
prompt: "old prompt".into(),
|
||||
channel: "feishu".into(),
|
||||
chat_id: "oc_1".into(),
|
||||
model: None,
|
||||
enabled: true,
|
||||
delete_after_run: false,
|
||||
next_run_at: t,
|
||||
last_run_at: None,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
created_at: t,
|
||||
updated_at: t,
|
||||
};
|
||||
storage.add_scheduled_job(&job).await.unwrap();
|
||||
storage.update_scheduled_job(
|
||||
storage
|
||||
.update_scheduled_job(
|
||||
"job-update",
|
||||
Some("new prompt".into()),
|
||||
Some(Schedule::Every { every_ms: 60000 }),
|
||||
None, None, None,
|
||||
).await.unwrap();
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
||||
assert_eq!(got.prompt, "new prompt");
|
||||
}
|
||||
|
||||
@ -167,10 +167,7 @@ impl Tool for BashTool {
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Command timed out after {} seconds",
|
||||
timeout_secs
|
||||
)),
|
||||
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@ -249,10 +246,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_pwd_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "pwd" }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
@ -260,7 +254,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_ls_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "ls -la /tmp" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine;
|
||||
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT};
|
||||
use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
|
||||
use fantoccini::key::Key;
|
||||
use fantoccini::{Client, ClientBuilder, Locator};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -63,7 +63,9 @@ impl BrowserTool {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum BrowserAction {
|
||||
Open { url: String },
|
||||
Open {
|
||||
url: String,
|
||||
},
|
||||
Snapshot {
|
||||
#[serde(default)]
|
||||
interactive_only: bool,
|
||||
@ -72,10 +74,20 @@ pub enum BrowserAction {
|
||||
#[serde(default)]
|
||||
depth: Option<i64>,
|
||||
},
|
||||
Click { selector: String },
|
||||
Fill { selector: String, value: String },
|
||||
Type { selector: Option<String>, text: String },
|
||||
GetText { selector: String },
|
||||
Click {
|
||||
selector: String,
|
||||
},
|
||||
Fill {
|
||||
selector: String,
|
||||
value: String,
|
||||
},
|
||||
Type {
|
||||
selector: Option<String>,
|
||||
text: String,
|
||||
},
|
||||
GetText {
|
||||
selector: String,
|
||||
},
|
||||
GetTitle,
|
||||
GetUrl,
|
||||
Screenshot {
|
||||
@ -84,7 +96,9 @@ pub enum BrowserAction {
|
||||
#[serde(default)]
|
||||
return_base64: bool,
|
||||
},
|
||||
Focus { selector: String },
|
||||
Focus {
|
||||
selector: String,
|
||||
},
|
||||
Wait {
|
||||
#[serde(default)]
|
||||
selector: Option<String>,
|
||||
@ -93,9 +107,16 @@ pub enum BrowserAction {
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
},
|
||||
Press { key: String },
|
||||
Hover { selector: String },
|
||||
ClickAt { x: u32, y: u32 },
|
||||
Press {
|
||||
key: String,
|
||||
},
|
||||
Hover {
|
||||
selector: String,
|
||||
},
|
||||
ClickAt {
|
||||
x: u32,
|
||||
y: u32,
|
||||
},
|
||||
Scroll {
|
||||
direction: String,
|
||||
#[serde(default)]
|
||||
@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
||||
.get("interactive_only")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
compact: args
|
||||
.get("compact")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
depth: args
|
||||
.get("depth")
|
||||
.and_then(|v| v.as_i64()),
|
||||
compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
|
||||
depth: args.get("depth").and_then(|v| v.as_i64()),
|
||||
}),
|
||||
"click" => {
|
||||
let selector = args
|
||||
@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
ms: args.get("ms").and_then(|v| v.as_u64()),
|
||||
text: args
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||
}),
|
||||
"press" => {
|
||||
let key = args
|
||||
@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
||||
let x = args
|
||||
.get("x")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32;
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
|
||||
as u32;
|
||||
let y = args
|
||||
.get("y")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32;
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
|
||||
as u32;
|
||||
Ok(BrowserAction::ClickAt { x, y })
|
||||
}
|
||||
other => anyhow::bail!("Unsupported browser action: {}", other),
|
||||
@ -488,7 +503,11 @@ impl BrowserState {
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "fill",
|
||||
output_len = value.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Filled {} with {}", selector, value),
|
||||
@ -573,7 +592,10 @@ impl BrowserState {
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
BrowserAction::Screenshot { path, return_base64 } => {
|
||||
BrowserAction::Screenshot {
|
||||
path,
|
||||
return_base64,
|
||||
} => {
|
||||
let client = self.active_client()?;
|
||||
let png = client.screenshot().await?;
|
||||
let save_path = path.unwrap_or_else(|| {
|
||||
@ -588,14 +610,25 @@ impl BrowserState {
|
||||
tokio::fs::write(&save_path, &png).await?;
|
||||
if return_base64 {
|
||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
|
||||
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "screenshot",
|
||||
output_len = b64.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64),
|
||||
output: format!(
|
||||
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
|
||||
save_path, b64
|
||||
),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "screenshot",
|
||||
output_len = save_path.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Screenshot saved to {}", save_path),
|
||||
@ -611,18 +644,18 @@ impl BrowserState {
|
||||
vec![serde_json::to_value(el)?],
|
||||
)
|
||||
.await?;
|
||||
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed");
|
||||
tracing::debug!(
|
||||
action = "focus",
|
||||
output_len = selector.len(),
|
||||
"Browser action completed"
|
||||
);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Focused {}", selector),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
BrowserAction::Wait {
|
||||
selector,
|
||||
ms,
|
||||
text,
|
||||
} => {
|
||||
BrowserAction::Wait { selector, ms, text } => {
|
||||
if let Some(sel) = selector {
|
||||
let client = self.active_client()?;
|
||||
wait_for_selector(client, &sel).await?;
|
||||
@ -719,9 +752,21 @@ impl BrowserState {
|
||||
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") };
|
||||
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") };
|
||||
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") };
|
||||
let id_str = if id.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("#{id}")
|
||||
};
|
||||
let type_str = if el_type.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("[type={el_type}]")
|
||||
};
|
||||
let text_str = if text.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" ({text})")
|
||||
};
|
||||
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
|
||||
}
|
||||
None => format!("Clicked at ({}, {})", x, y),
|
||||
@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String {
|
||||
}
|
||||
|
||||
fn xpath_contains_text(text: &str) -> String {
|
||||
format!(
|
||||
"//*[contains(normalize-space(.), {})]",
|
||||
xpath_literal(text)
|
||||
)
|
||||
format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
|
||||
}
|
||||
|
||||
fn xpath_literal(input: &str) -> String {
|
||||
@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String {
|
||||
"pagedown" => Key::PageDown.to_string(),
|
||||
"space" => " ".to_string(),
|
||||
other => {
|
||||
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other);
|
||||
tracing::warn!(
|
||||
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
|
||||
other
|
||||
);
|
||||
other.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
@ -659,10 +659,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_missing_expression() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate"}))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
|
||||
@ -31,10 +31,7 @@ impl ContentSearchTool {
|
||||
for (i, line) in lines.iter().enumerate() {
|
||||
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
||||
let omitted = lines.len() - i;
|
||||
output.push_str(&format!(
|
||||
"\n... ({} matches omitted) ...",
|
||||
omitted
|
||||
));
|
||||
output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
|
||||
break;
|
||||
}
|
||||
if !output.is_empty() {
|
||||
@ -113,18 +110,40 @@ impl Tool for ContentSearchTool {
|
||||
|
||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
|
||||
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
|
||||
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
let case_sensitive = args
|
||||
.get("case_sensitive")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let context_lines = args
|
||||
.get("context_lines")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as usize;
|
||||
let max_results = args
|
||||
.get("max_results")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
|
||||
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
|
||||
let result = self
|
||||
.run_search(
|
||||
pattern,
|
||||
&dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(lines) => {
|
||||
let count = lines.len();
|
||||
let mut output = self.truncate_output(&lines);
|
||||
output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
|
||||
Ok(ToolResult { success: true, output, error: None })
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
@ -146,22 +165,52 @@ impl ContentSearchTool {
|
||||
max_results: usize,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
if which::which("rg").is_ok() {
|
||||
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
||||
match self
|
||||
.search_with_rg(
|
||||
pattern,
|
||||
dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(lines) => return Ok(lines),
|
||||
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
|
||||
}
|
||||
}
|
||||
|
||||
if which::which("grep").is_ok() {
|
||||
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
||||
match self
|
||||
.search_with_grep(
|
||||
pattern,
|
||||
dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
|
||||
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
|
||||
tracing::warn!(
|
||||
"No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
|
||||
);
|
||||
self.search_with_rust(
|
||||
pattern,
|
||||
dir,
|
||||
file_pattern,
|
||||
case_sensitive,
|
||||
context_lines,
|
||||
max_results,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn search_with_rg(
|
||||
@ -176,8 +225,10 @@ impl ContentSearchTool {
|
||||
let mut cmd = Command::new("rg");
|
||||
cmd.arg("-n")
|
||||
.arg("--no-heading")
|
||||
.arg("--color").arg("never")
|
||||
.arg("--max-count").arg(max_results.to_string())
|
||||
.arg("--color")
|
||||
.arg("never")
|
||||
.arg("--max-count")
|
||||
.arg(max_results.to_string())
|
||||
.arg(pattern)
|
||||
.arg(dir)
|
||||
.stdout(Stdio::piped())
|
||||
@ -193,10 +244,7 @@ impl ContentSearchTool {
|
||||
cmd.arg("--glob").arg(fp);
|
||||
}
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
@ -206,7 +254,8 @@ impl ContentSearchTool {
|
||||
}
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = text.lines()
|
||||
let lines: Vec<String> = text
|
||||
.lines()
|
||||
.take(max_results)
|
||||
.map(|l| l.to_string())
|
||||
.collect();
|
||||
@ -242,15 +291,13 @@ impl ContentSearchTool {
|
||||
cmd.arg("--include").arg(fp);
|
||||
}
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = text.lines()
|
||||
let lines: Vec<String> = text
|
||||
.lines()
|
||||
.take(max_results)
|
||||
.map(|l| l.to_string())
|
||||
.collect();
|
||||
@ -280,7 +327,9 @@ impl ContentSearchTool {
|
||||
if case_sensitive {
|
||||
regex::Regex::new(&re_str)
|
||||
} else {
|
||||
regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
|
||||
regex::RegexBuilder::new(&re_str)
|
||||
.case_insensitive(true)
|
||||
.build()
|
||||
}
|
||||
});
|
||||
|
||||
@ -291,7 +340,14 @@ impl ContentSearchTool {
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
|
||||
grep_dir(
|
||||
Path::new(dir),
|
||||
Path::new(dir),
|
||||
&re,
|
||||
file_re.as_ref(),
|
||||
&mut results,
|
||||
max_results,
|
||||
)?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
@ -350,14 +406,17 @@ fn grep_dir(
|
||||
|
||||
if path.is_dir() {
|
||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& name.starts_with('.') && name.len() > 1 {
|
||||
&& name.starts_with('.')
|
||||
&& name.len() > 1
|
||||
{
|
||||
continue;
|
||||
}
|
||||
grep_dir(base, &path, re, file_re, results, max)?;
|
||||
} else if path.is_file() {
|
||||
if let Some(file_re) = file_re
|
||||
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& !file_re.is_match(name) {
|
||||
&& !file_re.is_match(name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -391,8 +450,16 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_content_search_rust_fallback() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
|
||||
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
|
||||
fs::write(
|
||||
dir.path().join("main.rs"),
|
||||
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
|
||||
)
|
||||
.unwrap();
|
||||
fs::write(
|
||||
dir.path().join("lib.rs"),
|
||||
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
|
||||
)
|
||||
.unwrap();
|
||||
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
|
||||
|
||||
let tool = ContentSearchTool::new();
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::{Value, json};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::scheduler::{next_run_for_schedule, Schedule};
|
||||
use crate::scheduler::{Schedule, next_run_for_schedule};
|
||||
use crate::storage::{ScheduledJob, Storage};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
@ -229,10 +229,7 @@ impl Tool for CronListTool {
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||
let filter = args
|
||||
.get("status")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("all");
|
||||
let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
|
||||
let jobs = self.storage.list_scheduled_jobs().await?;
|
||||
|
||||
let filtered: Vec<&ScheduledJob> = match filter {
|
||||
@ -397,7 +394,9 @@ impl Tool for CronEnableTool {
|
||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||
|
||||
let next = next_run_for_schedule(&job.schedule, now_ms());
|
||||
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job_id, true)
|
||||
.await?;
|
||||
if let Some(n) = next {
|
||||
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
||||
}
|
||||
@ -464,7 +463,9 @@ impl Tool for CronDisableTool {
|
||||
.get_scheduled_job(&job_id)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_enabled(&job_id, false)
|
||||
.await?;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
@ -580,7 +581,9 @@ impl Tool for CronUpdateTool {
|
||||
if args.get("schedule").is_some() {
|
||||
let job = self.storage.get_scheduled_job(&job_id).await?;
|
||||
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
||||
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
|
||||
self.storage
|
||||
.set_scheduled_job_next_run(&job_id, next)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
@ -765,9 +768,7 @@ mod tests {
|
||||
let job = ScheduledJob {
|
||||
id: "job-update-tool".into(),
|
||||
name: "old".into(),
|
||||
schedule: Schedule::Every {
|
||||
every_ms: 3600000,
|
||||
},
|
||||
schedule: Schedule::Every { every_ms: 3600000 },
|
||||
prompt: "old prompt".into(),
|
||||
channel: "feishu".into(),
|
||||
chat_id: "oc_1".into(),
|
||||
|
||||
@ -102,7 +102,10 @@ impl Tool for DelegateTool {
|
||||
_ => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)),
|
||||
error: Some(format!(
|
||||
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
|
||||
action
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
@ -115,9 +118,11 @@ impl DelegateTool {
|
||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
|
||||
.to_string();
|
||||
|
||||
let allowed_tools: Option<Vec<String>> = args["allowed_tools"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
||||
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
|
||||
let timeout_secs = args["timeout_secs"].as_u64();
|
||||
@ -141,15 +146,21 @@ impl DelegateTool {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)),
|
||||
})
|
||||
error: Some(format!(
|
||||
"unknown mode: {}. Supported: inline, background, parallel",
|
||||
mode_str
|
||||
)),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match mode {
|
||||
ExecutionMode::Inline => {
|
||||
let config = self.parse_config_from_args(args)?;
|
||||
let result = self.sub_agent_manager.run_inline(config).await
|
||||
let result = self
|
||||
.sub_agent_manager
|
||||
.run_inline(config)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
match result.status {
|
||||
@ -177,10 +188,14 @@ impl DelegateTool {
|
||||
}
|
||||
ExecutionMode::Background => {
|
||||
let config = self.parse_config_from_args(args)?;
|
||||
let ctx = crate::agent::sub_agent::get_delegate_context()
|
||||
.map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?;
|
||||
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
|
||||
anyhow::anyhow!("delegate context not available: not in an agent worker")
|
||||
})?;
|
||||
|
||||
let task_id = self.sub_agent_manager.run_background(config, ctx).await
|
||||
let task_id = self
|
||||
.sub_agent_manager
|
||||
.run_background(config, ctx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
Ok(ToolResult {
|
||||
@ -200,9 +215,12 @@ impl DelegateTool {
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
|
||||
.to_string();
|
||||
let allowed_tools: Option<Vec<String>> = task["allowed_tools"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
||||
let allowed_tools: Option<Vec<String>> =
|
||||
task["allowed_tools"].as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
configs.push(SubAgentConfig {
|
||||
prompt,
|
||||
@ -216,13 +234,18 @@ impl DelegateTool {
|
||||
let has_args_allowed = args["allowed_tools"].as_array().is_some();
|
||||
for c in &mut configs {
|
||||
if c.allowed_tools.is_none() && has_args_allowed {
|
||||
c.allowed_tools = args["allowed_tools"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
||||
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let results = self.sub_agent_manager.run_parallel(configs).await
|
||||
let results = self
|
||||
.sub_agent_manager
|
||||
.run_parallel(configs)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||
|
||||
let mut output = String::new();
|
||||
@ -243,7 +266,9 @@ impl DelegateTool {
|
||||
}
|
||||
}
|
||||
|
||||
let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed));
|
||||
let all_success = results
|
||||
.iter()
|
||||
.all(|r| matches!(r.status, TaskStatus::Completed));
|
||||
Ok(ToolResult {
|
||||
success: all_success,
|
||||
output: output.trim().to_string(),
|
||||
|
||||
@ -243,8 +243,8 @@ impl Tool for FileEditTool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_edit_simple() {
|
||||
|
||||
@ -181,10 +181,7 @@ impl Tool for FileReadTool {
|
||||
}
|
||||
result = lines[..end_idx].join("\n");
|
||||
let truncated = original_len - result.len();
|
||||
result.push_str(&format!(
|
||||
"\n\n... ({} chars truncated) ...",
|
||||
truncated
|
||||
));
|
||||
result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
|
||||
}
|
||||
|
||||
if end < total {
|
||||
@ -196,10 +193,7 @@ impl Tool for FileReadTool {
|
||||
end + 1
|
||||
));
|
||||
} else {
|
||||
result.push_str(&format!(
|
||||
"\n\n(End of file — {} lines total)",
|
||||
total
|
||||
));
|
||||
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
|
||||
}
|
||||
|
||||
if let Some(label) = encoding_label {
|
||||
@ -214,7 +208,7 @@ impl Tool for FileReadTool {
|
||||
}
|
||||
None => {
|
||||
// Truly binary file — base64 encode
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||
let encoded = STANDARD.encode(&bytes);
|
||||
let mime = mime_guess::from_path(&resolved)
|
||||
.first_or_octet_stream()
|
||||
@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_simple_file() {
|
||||
@ -338,10 +332,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_is_directory() {
|
||||
let tool = FileReadTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "path": "." }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "path": "." })).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Not a file"));
|
||||
|
||||
@ -101,17 +101,29 @@ impl Tool for FileSearchTool {
|
||||
};
|
||||
|
||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
|
||||
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
let case_sensitive = args
|
||||
.get("case_sensitive")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true);
|
||||
let max_results = args
|
||||
.get("max_results")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(MAX_RESULTS as u64) as usize;
|
||||
|
||||
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
|
||||
let result = self
|
||||
.run_search(pattern, &dir, case_sensitive, max_results)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(lines) => {
|
||||
let count = lines.len();
|
||||
let mut output = self.truncate_output(&lines);
|
||||
output.push_str(&format!("\n\n---\n共 {} 个文件", count));
|
||||
Ok(ToolResult { success: true, output, error: None })
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
@ -139,9 +151,12 @@ impl FileSearchTool {
|
||||
};
|
||||
|
||||
if !fd_cmd.is_empty() {
|
||||
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await {
|
||||
match self
|
||||
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
|
||||
.await
|
||||
{
|
||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
|
||||
}
|
||||
}
|
||||
@ -149,13 +164,14 @@ impl FileSearchTool {
|
||||
if which::which("find").is_ok() {
|
||||
match self.search_with_find(pattern, dir, max_results).await {
|
||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||
Ok(_) => {},
|
||||
Ok(_) => {}
|
||||
Err(e) => tracing::warn!("find failed: {}, falling back", e),
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!("No fd/find available, using built-in file search (slower)");
|
||||
self.search_with_rust(pattern, dir, case_sensitive, max_results).await
|
||||
self.search_with_rust(pattern, dir, case_sensitive, max_results)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn search_with_fd(
|
||||
@ -167,11 +183,15 @@ impl FileSearchTool {
|
||||
fd_cmd: &str,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
let mut cmd = Command::new(fd_cmd);
|
||||
cmd.arg("--search-path").arg(dir)
|
||||
.arg("--glob").arg(pattern)
|
||||
.arg("--color").arg("never")
|
||||
cmd.arg("--search-path")
|
||||
.arg(dir)
|
||||
.arg("--glob")
|
||||
.arg(pattern)
|
||||
.arg("--color")
|
||||
.arg("never")
|
||||
.arg("--strip-cwd-prefix")
|
||||
.arg("--max-results").arg(max_results.to_string())
|
||||
.arg("--max-results")
|
||||
.arg(max_results.to_string())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
@ -179,10 +199,7 @@ impl FileSearchTool {
|
||||
cmd.arg("--ignore-case");
|
||||
}
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
@ -192,7 +209,8 @@ impl FileSearchTool {
|
||||
}
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = text.lines()
|
||||
let lines: Vec<String> = text
|
||||
.lines()
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(|l| l.to_string())
|
||||
.collect();
|
||||
@ -215,15 +233,13 @@ impl FileSearchTool {
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::null());
|
||||
|
||||
let output = timeout(
|
||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
||||
cmd.output(),
|
||||
)
|
||||
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
||||
|
||||
let text = String::from_utf8_lossy(&output.stdout);
|
||||
let mut lines: Vec<String> = text.lines()
|
||||
let mut lines: Vec<String> = text
|
||||
.lines()
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(|l| {
|
||||
let p = Path::new(l);
|
||||
@ -254,7 +270,13 @@ impl FileSearchTool {
|
||||
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?;
|
||||
walk_dir(
|
||||
Path::new(dir),
|
||||
Path::new(dir),
|
||||
&re,
|
||||
&mut results,
|
||||
max_results,
|
||||
)?;
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
@ -311,13 +333,16 @@ fn walk_dir(
|
||||
|
||||
if path.is_dir() {
|
||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& name.starts_with('.') && name.len() > 1 {
|
||||
&& name.starts_with('.')
|
||||
&& name.len() > 1
|
||||
{
|
||||
continue;
|
||||
}
|
||||
walk_dir(base, &path, re, results, max)?;
|
||||
} else if path.is_file() {
|
||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||
&& re.is_match(name) {
|
||||
&& re.is_match(name)
|
||||
{
|
||||
results.push(rel.to_string_lossy().to_string());
|
||||
}
|
||||
if results.len() >= max {
|
||||
|
||||
@ -90,7 +90,8 @@ impl Tool for FileWriteTool {
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = resolved.parent()
|
||||
&& !parent.exists()
|
||||
&& let Err(e) = std::fs::create_dir_all(parent) {
|
||||
&& let Err(e) = std::fs::create_dir_all(parent)
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
@ -168,10 +169,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_write_missing_path() {
|
||||
let tool = FileWriteTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "content": "Hello" }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("path"));
|
||||
|
||||
@ -129,7 +129,9 @@ impl GetSkillTool {
|
||||
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
|
||||
for s in &skills {
|
||||
let always_mark = if s.always { " [常驻]" } else { "" };
|
||||
let path_str = s.path.as_ref()
|
||||
let path_str = s
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "—".to_string());
|
||||
output.push_str(&format!(
|
||||
@ -148,10 +150,10 @@ impl GetSkillTool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_existing_skill() {
|
||||
|
||||
@ -50,10 +50,7 @@ impl HttpRequestTool {
|
||||
}
|
||||
|
||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||
return Err(format!(
|
||||
"Host '{}' is not in allowed_domains",
|
||||
host
|
||||
));
|
||||
return Err(format!("Host '{}' is not in allowed_domains", host));
|
||||
}
|
||||
|
||||
Ok(url.to_string())
|
||||
@ -80,8 +77,7 @@ impl HttpRequestTool {
|
||||
for (key, value) in obj {
|
||||
if let Some(str_val) = value.as_str()
|
||||
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
|
||||
&& let Ok(val) =
|
||||
reqwest::header::HeaderValue::from_str(str_val)
|
||||
&& let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
|
||||
{
|
||||
header_map.insert(name, val);
|
||||
}
|
||||
@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
||||
|
||||
allowed_domains.iter().any(|domain| {
|
||||
host == domain
|
||||
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
||||
|| host
|
||||
.strip_suffix(domain)
|
||||
.is_some_and(|prefix| prefix.ends_with('.'))
|
||||
})
|
||||
}
|
||||
|
||||
@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool {
|
||||
}
|
||||
|
||||
// Check .local TLD
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
if host
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|label| label == "local")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||
|| v4.is_broadcast()
|
||||
|| v4.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,10 +278,7 @@ impl Tool for HttpRequestTool {
|
||||
}
|
||||
};
|
||||
|
||||
let method_str = args
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
||||
|
||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||
let body = args.get("body").and_then(|v| v.as_str());
|
||||
|
||||
@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool {
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||
self.memory
|
||||
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
|
||||
.recall_by_time(
|
||||
since,
|
||||
until,
|
||||
Some(query),
|
||||
limit,
|
||||
Some(MemoryCategory::Knowledge),
|
||||
None,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
|
||||
self.memory
|
||||
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
|
||||
.await?
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool {
|
||||
let formatted = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||
let session = e
|
||||
.session_id
|
||||
.as_deref()
|
||||
.map(|s| format!(" [session: {}]", s))
|
||||
.unwrap_or_default();
|
||||
format!(
|
||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||
e.key,
|
||||
@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool {
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||
self.memory
|
||||
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
|
||||
.recall_by_time(
|
||||
since,
|
||||
until,
|
||||
Some(query),
|
||||
limit,
|
||||
Some(MemoryCategory::Timeline),
|
||||
session_id,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
|
||||
self.memory
|
||||
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
|
||||
.await?
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool {
|
||||
let formatted = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
||||
let session = e
|
||||
.session_id
|
||||
.as_deref()
|
||||
.map(|s| format!(" [session: {}]", s))
|
||||
.unwrap_or_default();
|
||||
format!(
|
||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||
e.key,
|
||||
|
||||
@ -37,11 +37,11 @@ pub use send_message::SendMessageTool;
|
||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
use std::sync::Arc;
|
||||
use crate::agent::SubAgentManager;
|
||||
use crate::config::BrowserConfig;
|
||||
use crate::memory::MemoryManager;
|
||||
use crate::skills::SkillsLoader;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Create the base tool registry (without send_message).
|
||||
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
||||
|
||||
@ -17,7 +17,10 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
||||
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
|
||||
self.tools
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(tool.name().to_string(), Arc::new(tool));
|
||||
}
|
||||
|
||||
/// Register an existing Arc-wrapped tool by name
|
||||
|
||||
@ -115,7 +115,9 @@ impl SchemaCleanr {
|
||||
}
|
||||
|
||||
if let Some(Value::String(t)) = obj.get("type")
|
||||
&& t == "object" && !obj.contains_key("properties") {
|
||||
&& t == "object"
|
||||
&& !obj.contains_key("properties")
|
||||
{
|
||||
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||
}
|
||||
|
||||
@ -173,7 +175,8 @@ impl SchemaCleanr {
|
||||
|
||||
// Handle anyOf/oneOf simplification
|
||||
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
|
||||
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
||||
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
|
||||
{
|
||||
return simplified;
|
||||
}
|
||||
|
||||
@ -243,7 +246,8 @@ impl SchemaCleanr {
|
||||
}
|
||||
|
||||
if let Some(def_name) = Self::parse_local_ref(ref_value)
|
||||
&& let Some(definition) = defs.get(def_name.as_str()) {
|
||||
&& let Some(definition) = defs.get(def_name.as_str())
|
||||
{
|
||||
ref_stack.insert(ref_value.to_string());
|
||||
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||
ref_stack.remove(ref_value);
|
||||
@ -340,11 +344,14 @@ impl SchemaCleanr {
|
||||
return true;
|
||||
}
|
||||
if let Some(Value::Array(arr)) = obj.get("enum")
|
||||
&& arr.len() == 1 && matches!(arr[0], Value::Null) {
|
||||
&& arr.len() == 1
|
||||
&& matches!(arr[0], Value::Null)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if let Some(Value::String(t)) = obj.get("type")
|
||||
&& t == "null" {
|
||||
&& t == "null"
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -403,7 +410,10 @@ impl SchemaCleanr {
|
||||
|
||||
match non_null.len() {
|
||||
0 => Value::String("null".to_string()),
|
||||
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
||||
1 => non_null
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or(Value::String("null".to_string())),
|
||||
_ => Value::Array(non_null),
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mime_guess::mime;
|
||||
@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
|
||||
match parts.len() {
|
||||
2 => {
|
||||
if parts[0].is_empty() || parts[1].is_empty() {
|
||||
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
|
||||
Err(format!(
|
||||
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
|
||||
raw
|
||||
))
|
||||
} else {
|
||||
Ok((parts[0], parts[1], None))
|
||||
}
|
||||
}
|
||||
3 => {
|
||||
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
||||
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
|
||||
Err(format!(
|
||||
"Invalid target_chat_id format '{}': all three parts must not be empty",
|
||||
raw
|
||||
))
|
||||
} else {
|
||||
Ok((parts[0], parts[1], Some(parts[2])))
|
||||
}
|
||||
@ -98,8 +104,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
||||
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
||||
|
||||
// 1. Parse target_chat_id
|
||||
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
let (channel, chat_id, dialog_id) =
|
||||
parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
// 2. Validate channel
|
||||
if !self.available_channels.contains(channel) {
|
||||
@ -109,7 +115,11 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
||||
error: Some(format!(
|
||||
"Channel '{}' is not available. Available channels: {}",
|
||||
channel,
|
||||
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||
self.available_channels
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
@ -129,7 +139,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
||||
let media = parse_files_arg(&args);
|
||||
|
||||
// 4. Send via messenger
|
||||
match self.messenger
|
||||
match self
|
||||
.messenger
|
||||
.send_message(channel, chat_id, dialog_id, content, source, media)
|
||||
.await
|
||||
{
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use async_trait::async_trait;
|
||||
use crate::bus::{MediaItem, MessageSource};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolResult {
|
||||
|
||||
@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
if host
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|label| label == "local")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
use picobot::config::{Config, LLMProviderConfig};
|
||||
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn load_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_simple_completion() {
|
||||
let config = load_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||
@ -57,8 +56,7 @@ async fn test_openai_simple_completion() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_conversation() {
|
||||
let config = load_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -82,7 +80,9 @@ async fn test_openai_conversation() {
|
||||
async fn test_config_load() {
|
||||
// Test that config.json can be loaded and provider config created
|
||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||
let provider_config = config
|
||||
.get_provider_config("default")
|
||||
.expect("Failed to get provider config");
|
||||
|
||||
assert_eq!(provider_config.provider_type, "openai");
|
||||
assert_eq!(provider_config.name, "aliyun");
|
||||
|
||||
@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
|
||||
/// Verify that next_run_for_schedule produces valid future timestamps.
|
||||
#[test]
|
||||
fn test_next_run_always_future() {
|
||||
use picobot::scheduler::{next_run_for_schedule, Schedule};
|
||||
use picobot::scheduler::{Schedule, next_run_for_schedule};
|
||||
|
||||
let now = 1700000000000_i64;
|
||||
|
||||
@ -56,6 +56,10 @@ fn test_next_run_always_future() {
|
||||
for s in &schedules {
|
||||
let next = next_run_for_schedule(s, now);
|
||||
assert!(next.is_some(), "expected next run for {:?}", s);
|
||||
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
|
||||
assert!(
|
||||
next.unwrap() > now,
|
||||
"next run should be after now for {:?}",
|
||||
s
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
||||
use picobot::config::LLMProviderConfig;
|
||||
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_tool_call() {
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -68,7 +67,11 @@ async fn test_openai_tool_call() {
|
||||
let response = provider.chat(request).await.unwrap();
|
||||
|
||||
// Should have tool calls
|
||||
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
||||
assert!(
|
||||
!response.tool_calls.is_empty(),
|
||||
"Expected tool call, got: {}",
|
||||
response.content
|
||||
);
|
||||
|
||||
let tool_call = &response.tool_calls[0];
|
||||
assert_eq!(tool_call.name, "get_weather");
|
||||
@ -78,8 +81,7 @@ async fn test_openai_tool_call() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_tool_call_with_manual_execution() {
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
};
|
||||
|
||||
let response1 = provider.chat(request1).await.unwrap();
|
||||
let tool_call = response1.tool_calls.first()
|
||||
.expect("Expected tool call");
|
||||
let tool_call = response1.tool_calls.first().expect("Expected tool call");
|
||||
assert_eq!(tool_call.name, "get_weather");
|
||||
|
||||
// Second request with tool result
|
||||
@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_no_tool_when_not_provided() {
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user