Format codebase with rustfmt
This commit is contained in:
parent
c6f4392e63
commit
8f4ee79d8d
@ -4,10 +4,8 @@ use crate::agent::system_prompt::build_system_prompt;
|
|||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::{ChatMessage, MediaRef};
|
use crate::bus::{ChatMessage, MediaRef};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::observability::{
|
use crate::observability::{Observer, ObserverEvent, ToolExecutionOutcome, truncate_args};
|
||||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome,
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
|
||||||
};
|
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
@ -256,7 +254,10 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new AgentLoop with provider created from config and given tools.
|
/// Create a new AgentLoop with provider created from config and given tools.
|
||||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
pub fn with_tools(
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
let model_name = provider_config.model_id.clone();
|
let model_name = provider_config.model_id.clone();
|
||||||
let workspace_dir = provider_config.workspace_dir.clone();
|
let workspace_dir = provider_config.workspace_dir.clone();
|
||||||
@ -279,7 +280,13 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new AgentLoop with an existing shared provider.
|
/// Create a new AgentLoop with an existing shared provider.
|
||||||
pub fn with_provider(provider: Arc<dyn LLMProvider>, max_iterations: usize, model_name: String, workspace_dir: PathBuf, input_types: Vec<String>) -> Self {
|
pub fn with_provider(
|
||||||
|
provider: Arc<dyn LLMProvider>,
|
||||||
|
max_iterations: usize,
|
||||||
|
model_name: String,
|
||||||
|
workspace_dir: PathBuf,
|
||||||
|
input_types: Vec<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
provider,
|
provider,
|
||||||
tools: Arc::new(ToolRegistry::new()),
|
tools: Arc::new(ToolRegistry::new()),
|
||||||
@ -379,7 +386,12 @@ impl AgentLoop {
|
|||||||
let content = if m.media_refs.is_empty() {
|
let content = if m.media_refs.is_empty() {
|
||||||
vec![ContentBlock::text(&m.content)]
|
vec![ContentBlock::text(&m.content)]
|
||||||
} else {
|
} else {
|
||||||
build_content_blocks(&m.content, &m.media_refs, &self.input_types, &self.media_registry)
|
build_content_blocks(
|
||||||
|
&m.content,
|
||||||
|
&m.media_refs,
|
||||||
|
&self.input_types,
|
||||||
|
&self.media_registry,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
Message {
|
Message {
|
||||||
@ -399,14 +411,28 @@ impl AgentLoop {
|
|||||||
/// it loops back to the LLM with the tool results until either:
|
/// it loops back to the LLM with the tool results until either:
|
||||||
/// - The LLM returns no more tool calls (final response)
|
/// - The LLM returns no more tool calls (final response)
|
||||||
/// - Maximum iterations are reached
|
/// - Maximum iterations are reached
|
||||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
pub async fn process(
|
||||||
|
&self,
|
||||||
|
mut messages: Vec<ChatMessage>,
|
||||||
|
) -> Result<AgentProcessResult, AgentError> {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
tracing::debug!(
|
||||||
|
history_len = messages.len(),
|
||||||
|
max_iterations = self.max_iterations,
|
||||||
|
"Starting agent process"
|
||||||
|
);
|
||||||
|
|
||||||
// Build and inject system prompt if not present
|
// Build and inject system prompt if not present
|
||||||
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
let has_system = messages.first().is_some_and(|m| m.role == "system");
|
||||||
if !has_system {
|
if !has_system {
|
||||||
let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools, None, None, false);
|
let system_prompt = build_system_prompt(
|
||||||
|
&self.workspace_dir,
|
||||||
|
&self.model_name,
|
||||||
|
&self.tools,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
);
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
tracing::debug!("System prompt injected:\n{}", system_prompt);
|
||||||
messages.insert(0, ChatMessage::system(system_prompt));
|
messages.insert(0, ChatMessage::system(system_prompt));
|
||||||
@ -427,9 +453,7 @@ impl AgentLoop {
|
|||||||
let estimated = estimate_tokens(&messages);
|
let estimated = estimate_tokens(&messages);
|
||||||
let danger = (self.context_window as f64 * 0.8) as usize;
|
let danger = (self.context_window as f64 * 0.8) as usize;
|
||||||
if estimated > danger {
|
if estimated > danger {
|
||||||
let trimmed = self.preemptive_trim_old_tool_results(
|
let trimmed = self.preemptive_trim_old_tool_results(&mut messages, 2000, 4);
|
||||||
&mut messages, 2000, 4,
|
|
||||||
);
|
|
||||||
if trimmed > 0 {
|
if trimmed > 0 {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
@ -463,8 +487,7 @@ impl AgentLoop {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Call LLM
|
// Call LLM
|
||||||
let response = (*self.provider).chat(request).await
|
let response = (*self.provider).chat(request).await.map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
tracing::error!(error = %e, "LLM request failed");
|
tracing::error!(error = %e, "LLM request failed");
|
||||||
AgentError::LlmError(e.to_string())
|
AgentError::LlmError(e.to_string())
|
||||||
})?;
|
})?;
|
||||||
@ -493,7 +516,9 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// Execute tool calls — log and notify immediately
|
// Execute tool calls — log and notify immediately
|
||||||
{
|
{
|
||||||
let tools_info: Vec<String> = response.tool_calls.iter()
|
let tools_info: Vec<String> = response
|
||||||
|
.tool_calls
|
||||||
|
.iter()
|
||||||
.map(|tc| {
|
.map(|tc| {
|
||||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||||
let s = format!("{}:{}", tc.name, args);
|
let s = format!("{}:{}", tc.name, args);
|
||||||
@ -522,7 +547,9 @@ impl AgentLoop {
|
|||||||
// Log function call with name and arguments
|
// Log function call with name and arguments
|
||||||
let args_str = match &tool_call.arguments {
|
let args_str = match &tool_call.arguments {
|
||||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
other => {
|
||||||
|
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||||
|
|
||||||
@ -562,7 +589,11 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// Loop continues to next iteration with updated messages
|
// Loop continues to next iteration with updated messages
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
tracing::debug!(
|
||||||
|
iteration,
|
||||||
|
message_count = messages.len(),
|
||||||
|
"Tool execution complete, continuing to next iteration"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max iterations reached - ask LLM for a summary based on completed work
|
// Max iterations reached - ask LLM for a summary based on completed work
|
||||||
@ -571,7 +602,7 @@ impl AgentLoop {
|
|||||||
// Add a message asking for summary
|
// Add a message asking for summary
|
||||||
let summary_request = ChatMessage::user(
|
let summary_request = ChatMessage::user(
|
||||||
"You have reached the maximum number of tool call iterations. \
|
"You have reached the maximum number of tool call iterations. \
|
||||||
Please provide your best answer based on the work completed so far."
|
Please provide your best answer based on the work completed so far.",
|
||||||
);
|
);
|
||||||
messages.push(summary_request);
|
messages.push(summary_request);
|
||||||
|
|
||||||
@ -603,14 +634,19 @@ impl AgentLoop {
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback if summary call fails
|
// Fallback if summary call fails
|
||||||
tracing::error!(error = %e, "Failed to get summary from LLM");
|
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||||
let final_message = ChatMessage::assistant(
|
let final_message = ChatMessage::assistant(format!(
|
||||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
"I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.",
|
||||||
);
|
self.max_iterations
|
||||||
|
));
|
||||||
emitted_messages.push(final_message.clone());
|
emitted_messages.push(final_message.clone());
|
||||||
Ok(AgentProcessResult {
|
Ok(AgentProcessResult {
|
||||||
final_response: final_message,
|
final_response: final_message,
|
||||||
emitted_messages,
|
emitted_messages,
|
||||||
total_tokens: if accumulated_tokens > 0 { Some(accumulated_tokens) } else { None },
|
total_tokens: if accumulated_tokens > 0 {
|
||||||
|
Some(accumulated_tokens)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -698,10 +734,7 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply duration
|
// Apply duration
|
||||||
ToolExecutionOutcome {
|
ToolExecutionOutcome { duration, ..result }
|
||||||
duration,
|
|
||||||
..result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal tool execution without event tracking.
|
/// Internal tool execution without event tracking.
|
||||||
@ -723,18 +756,12 @@ impl AgentLoop {
|
|||||||
ToolExecutionOutcome::success(result.output)
|
ToolExecutionOutcome::success(result.output)
|
||||||
} else {
|
} else {
|
||||||
let error = result.error.unwrap_or_default();
|
let error = result.error.unwrap_or_default();
|
||||||
ToolExecutionOutcome::failure(
|
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
||||||
format!("Error: {}", error),
|
|
||||||
Some(error),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
tracing::error!(tool = %tool_call.name, error = %e, "Tool execution failed");
|
||||||
ToolExecutionOutcome::failure(
|
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
||||||
format!("Error: {}", e),
|
|
||||||
Some(e.to_string()),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -822,8 +849,14 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(provider_message.role, "assistant");
|
assert_eq!(provider_message.role, "assistant");
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
assert_eq!(
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
provider_message.tool_calls.as_ref().unwrap()[0].id,
|
||||||
|
"call_1"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
provider_message.tool_calls.as_ref().unwrap()[0].name,
|
||||||
|
"calculator"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -234,27 +234,33 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
} else if messages[i].role == "tool"
|
} else if messages[i].role == "tool"
|
||||||
&& let Some(ref tid) = messages[i].tool_call_id
|
&& let Some(ref tid) = messages[i].tool_call_id
|
||||||
&& !declared.contains(tid.as_str()) {
|
&& !declared.contains(tid.as_str())
|
||||||
|
{
|
||||||
messages.remove(i);
|
messages.remove(i);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let broken: Vec<usize> = messages.iter().enumerate()
|
let broken: Vec<usize> = messages
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
.filter_map(|(idx, msg)| {
|
.filter_map(|(idx, msg)| {
|
||||||
if msg.role == "assistant"
|
if msg.role == "assistant"
|
||||||
&& let Some(ref tcs) = msg.tool_calls
|
&& let Some(ref tcs) = msg.tool_calls
|
||||||
&& !tcs.is_empty() {
|
&& !tcs.is_empty()
|
||||||
|
{
|
||||||
let all_present = tcs.iter().all(|tc| {
|
let all_present = tcs.iter().all(|tc| {
|
||||||
messages.iter().any(|m| {
|
messages.iter().any(|m| {
|
||||||
m.role == "tool"
|
m.role == "tool" && m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
||||||
&& m.tool_call_id.as_deref() == Some(tc.id.as_str())
|
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
if !all_present { Some(idx) } else { None }
|
if !all_present { Some(idx) } else { None }
|
||||||
} else { None }
|
} else {
|
||||||
}).collect();
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
for idx in broken {
|
for idx in broken {
|
||||||
let msg = &mut messages[idx];
|
let msg = &mut messages[idx];
|
||||||
@ -262,7 +268,8 @@ impl ContextCompressor {
|
|||||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
||||||
msg.content = format!(
|
msg.content = format!(
|
||||||
"{}\n\n[Tool calls ({}) — results are no longer available]",
|
"{}\n\n[Tool calls ({}) — results are no longer available]",
|
||||||
msg.content, names.join(", ")
|
msg.content,
|
||||||
|
names.join(", ")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -275,7 +282,10 @@ impl ContextCompressor {
|
|||||||
// Check if compression is needed
|
// Check if compression is needed
|
||||||
let tokens = self.token_estimate_with_history(&history);
|
let tokens = self.token_estimate_with_history(&history);
|
||||||
if tokens <= self.threshold() {
|
if tokens <= self.threshold() {
|
||||||
return Ok(CompressionResult { history, created_timelines: false });
|
return Ok(CompressionResult {
|
||||||
|
history,
|
||||||
|
created_timelines: false,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -299,7 +309,10 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
if tokens_after <= self.threshold() {
|
if tokens_after <= self.threshold() {
|
||||||
self.invalidate_token_cache();
|
self.invalidate_token_cache();
|
||||||
return Ok(CompressionResult { history, created_timelines: false });
|
return Ok(CompressionResult {
|
||||||
|
history,
|
||||||
|
created_timelines: false,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLM summarization pass
|
// LLM summarization pass
|
||||||
@ -312,11 +325,7 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(pass = pass + 1, tokens = tokens, "Compression pass");
|
||||||
pass = pass + 1,
|
|
||||||
tokens = tokens,
|
|
||||||
"Compression pass"
|
|
||||||
);
|
|
||||||
|
|
||||||
match self.compress_once(¤t_history).await {
|
match self.compress_once(¤t_history).await {
|
||||||
Ok(Some(compressed)) => {
|
Ok(Some(compressed)) => {
|
||||||
@ -352,18 +361,24 @@ impl ContextCompressor {
|
|||||||
let m = ¤t_history[scan];
|
let m = ¤t_history[scan];
|
||||||
if m.role == "assistant" {
|
if m.role == "assistant" {
|
||||||
if let Some(tcs) = &m.tool_calls
|
if let Some(tcs) = &m.tool_calls
|
||||||
&& !tcs.is_empty() {
|
&& !tcs.is_empty()
|
||||||
|
{
|
||||||
let has_post = current_history[scan + 1..]
|
let has_post = current_history[scan + 1..]
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|r| r.role == "tool")
|
.filter(|r| r.role == "tool")
|
||||||
.any(|r| tcs.iter().any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str())));
|
.any(|r| {
|
||||||
|
tcs.iter()
|
||||||
|
.any(|tc| r.tool_call_id.as_deref() == Some(tc.id.as_str()))
|
||||||
|
});
|
||||||
if has_post {
|
if has_post {
|
||||||
tail_start = scan;
|
tail_start = scan;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if scan == 0 { break; }
|
if scan == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
scan -= 1;
|
scan -= 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -390,11 +405,13 @@ impl ContextCompressor {
|
|||||||
for msg in &mut truncated[..self.config.protect_first_n] {
|
for msg in &mut truncated[..self.config.protect_first_n] {
|
||||||
if msg.role == "assistant" {
|
if msg.role == "assistant" {
|
||||||
if let Some(ref tcs) = msg.tool_calls
|
if let Some(ref tcs) = msg.tool_calls
|
||||||
&& !tcs.is_empty() {
|
&& !tcs.is_empty()
|
||||||
|
{
|
||||||
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
let names: Vec<&str> = tcs.iter().map(|tc| tc.name.as_str()).collect();
|
||||||
msg.content = format!(
|
msg.content = format!(
|
||||||
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
|
"{}\n\n[Tool calls ({}) — results dropped during truncation]",
|
||||||
msg.content, names.join(", ")
|
msg.content,
|
||||||
|
names.join(", ")
|
||||||
);
|
);
|
||||||
msg.tool_calls = None;
|
msg.tool_calls = None;
|
||||||
}
|
}
|
||||||
@ -424,7 +441,10 @@ impl ContextCompressor {
|
|||||||
"Context compression completed"
|
"Context compression completed"
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(CompressionResult { history: current_history, created_timelines })
|
Ok(CompressionResult {
|
||||||
|
history: current_history,
|
||||||
|
created_timelines,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Try to extract the actual context token limit from an LLM error message.
|
/// Try to extract the actual context token limit from an LLM error message.
|
||||||
@ -447,7 +467,8 @@ impl ContextCompressor {
|
|||||||
// Look for a number in the vicinity (up to 10 chars after marker)
|
// Look for a number in the vicinity (up to 10 chars after marker)
|
||||||
if let Some(num_str) = find_number_nearby(after, 50)
|
if let Some(num_str) = find_number_nearby(after, 50)
|
||||||
&& let Ok(n) = num_str.parse::<usize>()
|
&& let Ok(n) = num_str.parse::<usize>()
|
||||||
&& (1024..=10_000_000).contains(&n) {
|
&& (1024..=10_000_000).contains(&n)
|
||||||
|
{
|
||||||
return Some(n);
|
return Some(n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -509,19 +530,26 @@ impl ContextCompressor {
|
|||||||
|
|
||||||
// Persist compressed summary as timeline memory entry
|
// Persist compressed summary as timeline memory entry
|
||||||
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M").to_string();
|
||||||
let timeline_content = format!("[{}] Compressed {} conversation segments:\n{}",
|
let timeline_content = format!(
|
||||||
ts, between.len(), summary);
|
"[{}] Compressed {} conversation segments:\n{}",
|
||||||
|
ts,
|
||||||
|
between.len(),
|
||||||
|
summary
|
||||||
|
);
|
||||||
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
let key = format!("ctx_compressed_{}", uuid::Uuid::new_v4());
|
||||||
let mm = self.memory.clone();
|
let mm = self.memory.clone();
|
||||||
let sid = self.session_id.clone();
|
let sid = self.session_id.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = mm.store(
|
if let Err(e) = mm
|
||||||
|
.store(
|
||||||
&key,
|
&key,
|
||||||
&timeline_content,
|
&timeline_content,
|
||||||
crate::memory::MemoryCategory::Timeline,
|
crate::memory::MemoryCategory::Timeline,
|
||||||
sid.as_deref(),
|
sid.as_deref(),
|
||||||
Some(0.3),
|
Some(0.3),
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
tracing::warn!(error = %e, "Failed to store compressed context as timeline");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -552,10 +580,7 @@ impl ContextCompressor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Summarize a segment of messages using LLM.
|
/// Summarize a segment of messages using LLM.
|
||||||
async fn summarize_segment(
|
async fn summarize_segment(&self, messages: &[ChatMessage]) -> Result<String, AgentError> {
|
||||||
&self,
|
|
||||||
messages: &[ChatMessage],
|
|
||||||
) -> Result<String, AgentError> {
|
|
||||||
if messages.is_empty() {
|
if messages.is_empty() {
|
||||||
return Ok(String::new());
|
return Ok(String::new());
|
||||||
}
|
}
|
||||||
@ -569,7 +594,8 @@ impl ContextCompressor {
|
|||||||
"tool" => "Tool",
|
"tool" => "Tool",
|
||||||
_ => m.role.as_str(),
|
_ => m.role.as_str(),
|
||||||
};
|
};
|
||||||
let name = m.tool_name
|
let name = m
|
||||||
|
.tool_name
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|n| format!(" ({})", n))
|
.map(|n| format!(" ({})", n))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
@ -614,7 +640,10 @@ Be concise, aim for {} characters or less.
|
|||||||
);
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![Message::system("You are a helpful assistant."), Message::user(&prompt)],
|
messages: vec![
|
||||||
|
Message::system("You are a helpful assistant."),
|
||||||
|
Message::user(&prompt),
|
||||||
|
],
|
||||||
temperature: Some(0.3),
|
temperature: Some(0.3),
|
||||||
max_tokens: Some(1000),
|
max_tokens: Some(1000),
|
||||||
tools: None,
|
tools: None,
|
||||||
@ -686,13 +715,23 @@ mod tests {
|
|||||||
content: "[summarized]".into(),
|
content: "[summarized]".into(),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
tool_calls: vec![],
|
tool_calls: vec![],
|
||||||
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
usage: Usage {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ptype(&self) -> &str { "mock" }
|
fn ptype(&self) -> &str {
|
||||||
fn name(&self) -> &str { "mock" }
|
"mock"
|
||||||
fn model_id(&self) -> &str { "mock" }
|
}
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
|
fn model_id(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
fn mock_summarizer() -> Arc<dyn LLMProvider> {
|
||||||
@ -704,11 +743,13 @@ mod tests {
|
|||||||
MM.get_or_init(|| {
|
MM.get_or_init(|| {
|
||||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||||
rt.block_on(async {
|
rt.block_on(async {
|
||||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
let tmp = std::env::temp_dir()
|
||||||
|
.join(format!("picobot_ctx_test_{}.db", std::process::id()));
|
||||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
Arc::new(MemoryManager::new(storage, "test".into(), "test".into()))
|
||||||
})
|
})
|
||||||
}).clone()
|
})
|
||||||
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -724,7 +765,11 @@ mod tests {
|
|||||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||||
// raw = 19, with 1.2x = ~23
|
// raw = 19, with 1.2x = ~23
|
||||||
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
assert!(
|
||||||
|
tokens > 18 && tokens < 30,
|
||||||
|
"Expected ~23 tokens, got {}",
|
||||||
|
tokens
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -733,7 +778,8 @@ mod tests {
|
|||||||
tool_result_trim_chars: 50,
|
tool_result_trim_chars: 50,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let compressor = ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
let compressor =
|
||||||
|
ContextCompressor::with_config(mock_provider(), 100_000, config, test_memory_manager());
|
||||||
|
|
||||||
let mut messages = vec![
|
let mut messages = vec![
|
||||||
ChatMessage::user("Hello"),
|
ChatMessage::user("Hello"),
|
||||||
@ -774,7 +820,11 @@ mod tests {
|
|||||||
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
ChatMessage::tool("call1", "bash", &"x".repeat(3000)),
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
let result = compressor
|
||||||
|
.compress_if_needed(messages)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.history;
|
||||||
|
|
||||||
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
let tool_msg = result.iter().find(|m| m.role == "tool").unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
@ -798,7 +848,8 @@ mod tests {
|
|||||||
// - B2B (L275): last user message lost when it is the final history message
|
// - B2B (L275): last user message lost when it is the final history message
|
||||||
//
|
//
|
||||||
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
// context_window=200 → threshold=100. Large tool outputs force LLM summarization.
|
||||||
let tmp = std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
let tmp =
|
||||||
|
std::env::temp_dir().join(format!("picobot_ctx_boundary_{}.db", std::process::id()));
|
||||||
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
let storage = Arc::new(crate::storage::Storage::new(&tmp).await.unwrap());
|
||||||
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
let mm = Arc::new(MemoryManager::new(storage, "test".into(), "test".into()));
|
||||||
|
|
||||||
@ -828,15 +879,33 @@ mod tests {
|
|||||||
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
ChatMessage::user("Q4"), // 8: LAST, is user → B2B triggers
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
let result = compressor
|
||||||
|
.compress_if_needed(messages)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.history;
|
||||||
|
|
||||||
// B2A: "Q1" must appear exactly once
|
// B2A: "Q1" must appear exactly once
|
||||||
let q1_count = result.iter().filter(|m| m.role == "user" && m.content == "Q1").count();
|
let q1_count = result
|
||||||
assert_eq!(q1_count, 1, "Q1 should appear exactly once, got {}", q1_count);
|
.iter()
|
||||||
|
.filter(|m| m.role == "user" && m.content == "Q1")
|
||||||
|
.count();
|
||||||
|
assert_eq!(
|
||||||
|
q1_count, 1,
|
||||||
|
"Q1 should appear exactly once, got {}",
|
||||||
|
q1_count
|
||||||
|
);
|
||||||
|
|
||||||
// B2B: "Q4" must NOT be lost
|
// B2B: "Q4" must NOT be lost
|
||||||
let q4_count = result.iter().filter(|m| m.role == "user" && m.content == "Q4").count();
|
let q4_count = result
|
||||||
assert_eq!(q4_count, 1, "Q4 should appear exactly once (not lost), got {}", q4_count);
|
.iter()
|
||||||
|
.filter(|m| m.role == "user" && m.content == "Q4")
|
||||||
|
.count();
|
||||||
|
assert_eq!(
|
||||||
|
q4_count, 1,
|
||||||
|
"Q4 should appear exactly once (not lost), got {}",
|
||||||
|
q4_count
|
||||||
|
);
|
||||||
|
|
||||||
let _ = std::fs::remove_file(&tmp);
|
let _ = std::fs::remove_file(&tmp);
|
||||||
}
|
}
|
||||||
@ -872,13 +941,23 @@ mod tests {
|
|||||||
ChatMessage::tool("t3", "bash", &big),
|
ChatMessage::tool("t3", "bash", &big),
|
||||||
];
|
];
|
||||||
|
|
||||||
let result = compressor.compress_if_needed(messages).await.unwrap().history;
|
let result = compressor
|
||||||
|
.compress_if_needed(messages)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.history;
|
||||||
|
|
||||||
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
// After hard truncation: head (1) + trunc_note (1) + tail (2) = 4 messages
|
||||||
assert!(result.len() < 7, "expected truncation reduction, got {} messages", result.len());
|
assert!(
|
||||||
|
result.len() < 7,
|
||||||
|
"expected truncation reduction, got {} messages",
|
||||||
|
result.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Truncation notice should be present
|
// Truncation notice should be present
|
||||||
let has_notice = result.iter().any(|m| m.content.contains("Context truncation"));
|
let has_notice = result
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.content.contains("Context truncation"));
|
||||||
assert!(has_notice, "hard truncation notice missing");
|
assert!(has_notice, "hard truncation notice missing");
|
||||||
|
|
||||||
let _ = std::fs::remove_file(&tmp);
|
let _ = std::fs::remove_file(&tmp);
|
||||||
@ -910,8 +989,16 @@ mod tests {
|
|||||||
|
|
||||||
// orphan should be removed; legitimate should stay
|
// orphan should be removed; legitimate should stay
|
||||||
assert_eq!(messages.len(), 4);
|
assert_eq!(messages.len(), 4);
|
||||||
assert!(messages.iter().all(|m| m.tool_call_id != Some("tc1".into())));
|
assert!(
|
||||||
assert!(messages.iter().any(|m| m.tool_call_id == Some("tc2".into())));
|
messages
|
||||||
|
.iter()
|
||||||
|
.all(|m| m.tool_call_id != Some("tc1".into()))
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.tool_call_id == Some("tc2".into()))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -49,7 +49,7 @@ impl MediaHandler for ImageHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
fn encode_image_to_base64(path: &str) -> Result<(String, String), std::io::Error> {
|
||||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
use base64::{Engine as _, engine::general_purpose::STANDARD};
|
||||||
|
|
||||||
let mut file = std::fs::File::open(path)?;
|
let mut file = std::fs::File::open(path)?;
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
|
|||||||
@ -4,10 +4,13 @@ pub mod media_handler;
|
|||||||
pub mod sub_agent;
|
pub mod sub_agent;
|
||||||
pub mod system_prompt;
|
pub mod system_prompt;
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult};
|
||||||
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
pub use context_compressor::{ContextCompressor, estimate_tokens};
|
||||||
pub use sub_agent::{DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult, TaskNotification, TaskStatus};
|
pub use sub_agent::{
|
||||||
pub use system_prompt::{
|
DelegateContext, ExecutionMode, SubAgentConfig, SubAgentError, SubAgentManager, SubAgentResult,
|
||||||
build_system_prompt, build_sub_agent_system_prompt, PromptContext, PromptSection,
|
TaskNotification, TaskStatus,
|
||||||
SystemPromptBuilder,
|
};
|
||||||
|
pub use system_prompt::{
|
||||||
|
PromptContext, PromptSection, SystemPromptBuilder, build_sub_agent_system_prompt,
|
||||||
|
build_system_prompt,
|
||||||
};
|
};
|
||||||
|
|||||||
@ -6,12 +6,12 @@ use dashmap::DashMap;
|
|||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::agent::system_prompt::build_sub_agent_system_prompt;
|
|
||||||
use crate::agent::AgentLoop;
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
|
use crate::agent::AgentLoop;
|
||||||
|
use crate::agent::system_prompt::build_sub_agent_system_prompt;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::providers::{create_provider, LLMProvider};
|
use crate::providers::{LLMProvider, create_provider};
|
||||||
use crate::skills::SkillsLoader;
|
use crate::skills::SkillsLoader;
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
|
|
||||||
@ -21,7 +21,8 @@ tokio::task_local! {
|
|||||||
|
|
||||||
/// Read the delegate context from the current task. Returns an error if not set.
|
/// Read the delegate context from the current task. Returns an error if not set.
|
||||||
pub fn get_delegate_context() -> Result<DelegateContext, String> {
|
pub fn get_delegate_context() -> Result<DelegateContext, String> {
|
||||||
DELEGATE_CONTEXT.try_with(|ctx| ctx.clone())
|
DELEGATE_CONTEXT
|
||||||
|
.try_with(|ctx| ctx.clone())
|
||||||
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
|
.map_err(|_| "DELEGATE_CONTEXT not set".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,7 +208,10 @@ impl SubAgentManager {
|
|||||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||||
let timeout_human = format_duration(timeout_secs);
|
let timeout_human = format_duration(timeout_secs);
|
||||||
let http_get_only = config.allowed_tools.is_none()
|
let http_get_only = config.allowed_tools.is_none()
|
||||||
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
|| config
|
||||||
|
.allowed_tools
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||||
let skills_prompt = self.get_skills_prompt(&tools);
|
let skills_prompt = self.get_skills_prompt(&tools);
|
||||||
let system_prompt = build_sub_agent_system_prompt(
|
let system_prompt = build_sub_agent_system_prompt(
|
||||||
&config.prompt,
|
&config.prompt,
|
||||||
@ -219,7 +223,8 @@ impl SubAgentManager {
|
|||||||
http_get_only,
|
http_get_only,
|
||||||
);
|
);
|
||||||
|
|
||||||
let agent = self.build_sub_agent(&config, tools)
|
let agent = self
|
||||||
|
.build_sub_agent(&config, tools)
|
||||||
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
|
.map_err(|e| SubAgentError::ProviderCreation(e.to_string()))?;
|
||||||
|
|
||||||
let history = vec![
|
let history = vec![
|
||||||
@ -241,10 +246,14 @@ impl SubAgentManager {
|
|||||||
Ok(Ok(agent_result)) => {
|
Ok(Ok(agent_result)) => {
|
||||||
let (content, truncated) =
|
let (content, truncated) =
|
||||||
truncate_sub_agent_result(&agent_result.final_response.content);
|
truncate_sub_agent_result(&agent_result.final_response.content);
|
||||||
let tool_calls_count = agent_result.emitted_messages.iter()
|
let tool_calls_count = agent_result
|
||||||
|
.emitted_messages
|
||||||
|
.iter()
|
||||||
.filter(|m| m.tool_calls.is_some())
|
.filter(|m| m.tool_calls.is_some())
|
||||||
.count();
|
.count();
|
||||||
let iterations = agent_result.emitted_messages.iter()
|
let iterations = agent_result
|
||||||
|
.emitted_messages
|
||||||
|
.iter()
|
||||||
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
|
.filter(|m| m.role == "assistant" && m.tool_calls.is_some())
|
||||||
.count();
|
.count();
|
||||||
Ok(SubAgentResult {
|
Ok(SubAgentResult {
|
||||||
@ -343,7 +352,10 @@ impl SubAgentManager {
|
|||||||
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
let timeout_secs = config.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
|
||||||
let timeout_human = format_duration(timeout_secs);
|
let timeout_human = format_duration(timeout_secs);
|
||||||
let http_get_only = config.allowed_tools.is_none()
|
let http_get_only = config.allowed_tools.is_none()
|
||||||
|| config.allowed_tools.as_ref().is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
|| config
|
||||||
|
.allowed_tools
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|v| v.iter().any(|t| t == "http_request"));
|
||||||
let skills_prompt = self.get_skills_prompt(&tools);
|
let skills_prompt = self.get_skills_prompt(&tools);
|
||||||
let system_prompt = build_sub_agent_system_prompt(
|
let system_prompt = build_sub_agent_system_prompt(
|
||||||
&config.prompt,
|
&config.prompt,
|
||||||
@ -372,8 +384,12 @@ impl SubAgentManager {
|
|||||||
if let Some(ref s) = storage {
|
if let Some(ref s) = storage {
|
||||||
let _ = s
|
let _ = s
|
||||||
.update_background_task_status(
|
.update_background_task_status(
|
||||||
&tid, "running", None, None,
|
&tid,
|
||||||
Some(started_at), None,
|
"running",
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
Some(started_at),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
@ -384,8 +400,7 @@ impl SubAgentManager {
|
|||||||
p.set_storage(s.clone());
|
p.set_storage(s.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let provider_result: Option<Arc<dyn LLMProvider>> =
|
let provider_result: Option<Arc<dyn LLMProvider>> = provider.map(|p| Arc::from(p));
|
||||||
provider.map(|p| Arc::from(p));
|
|
||||||
|
|
||||||
let result = match provider_result {
|
let result = match provider_result {
|
||||||
Some(provider) => {
|
Some(provider) => {
|
||||||
@ -474,9 +489,12 @@ impl SubAgentManager {
|
|||||||
if let Some(ref s) = storage {
|
if let Some(ref s) = storage {
|
||||||
let _ = s
|
let _ = s
|
||||||
.update_background_task_status(
|
.update_background_task_status(
|
||||||
&tid, &status_str,
|
&tid,
|
||||||
Some(&result.content), error_val.as_deref(),
|
&status_str,
|
||||||
Some(started_at), Some(finished_at),
|
Some(&result.content),
|
||||||
|
error_val.as_deref(),
|
||||||
|
Some(started_at),
|
||||||
|
Some(finished_at),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
@ -514,15 +532,13 @@ impl SubAgentManager {
|
|||||||
Ok(true)
|
Ok(true)
|
||||||
} else if let Some(ref s) = self.storage {
|
} else if let Some(ref s) = self.storage {
|
||||||
match s.get_background_task(task_id).await {
|
match s.get_background_task(task_id).await {
|
||||||
Ok(task) => {
|
Ok(task) => match task.status.as_str() {
|
||||||
match task.status.as_str() {
|
|
||||||
"pending" | "running" => {
|
"pending" | "running" => {
|
||||||
tracing::warn!(task_id, "task in DB but not in active_tasks");
|
tracing::warn!(task_id, "task in DB but not in active_tasks");
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
_ => Ok(false),
|
_ => Ok(false),
|
||||||
}
|
},
|
||||||
}
|
|
||||||
Err(_) => Ok(false),
|
Err(_) => Ok(false),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -530,10 +546,7 @@ impl SubAgentManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn check_task(
|
pub async fn check_task(&self, task_id: &str) -> Option<crate::storage::BackgroundTask> {
|
||||||
&self,
|
|
||||||
task_id: &str,
|
|
||||||
) -> Option<crate::storage::BackgroundTask> {
|
|
||||||
if let Some(ref s) = self.storage {
|
if let Some(ref s) = self.storage {
|
||||||
s.get_background_task(task_id).await.ok()
|
s.get_background_task(task_id).await.ok()
|
||||||
} else {
|
} else {
|
||||||
@ -541,12 +554,11 @@ impl SubAgentManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_tasks(
|
pub async fn list_tasks(&self, session_id: &str) -> Vec<crate::storage::BackgroundTask> {
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Vec<crate::storage::BackgroundTask> {
|
|
||||||
if let Some(ref s) = self.storage {
|
if let Some(ref s) = self.storage {
|
||||||
s.list_background_tasks(session_id).await.unwrap_or_default()
|
s.list_background_tasks(session_id)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default()
|
||||||
} else {
|
} else {
|
||||||
vec![]
|
vec![]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -465,7 +465,9 @@ impl PromptSection for SubAgentToolsSection {
|
|||||||
let mut s = String::from("## 可用工具\n\n");
|
let mut s = String::from("## 可用工具\n\n");
|
||||||
s.push_str(&ctx.tools.describe_for_prompt());
|
s.push_str(&ctx.tools.describe_for_prompt());
|
||||||
if self.http_get_only {
|
if self.http_get_only {
|
||||||
s.push_str("\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。");
|
s.push_str(
|
||||||
|
"\n\n**注意**:使用 http_request 时只允许 GET 方法,禁止 POST、PUT、DELETE 等。",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
@ -560,12 +562,7 @@ pub fn build_sub_agent_system_prompt(
|
|||||||
memory_context: None,
|
memory_context: None,
|
||||||
has_compressed_history: false,
|
has_compressed_history: false,
|
||||||
};
|
};
|
||||||
SystemPromptBuilder::with_sub_agent_defaults(
|
SystemPromptBuilder::with_sub_agent_defaults(task, timeout_human, skills_prompt, http_get_only)
|
||||||
task,
|
|
||||||
timeout_human,
|
|
||||||
skills_prompt,
|
|
||||||
http_get_only,
|
|
||||||
)
|
|
||||||
.build(&ctx)
|
.build(&ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundMessage};
|
use crate::bus::{MessageBus, OutboundMessage};
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
|
||||||
use crate::channels::ChannelManager;
|
use crate::channels::ChannelManager;
|
||||||
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
|
|
||||||
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
||||||
/// and dispatches them to the appropriate Channel
|
/// and dispatches them to the appropriate Channel
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
@ -23,7 +23,9 @@ pub struct ImageUrlBlock {
|
|||||||
|
|
||||||
impl ContentBlock {
|
impl ContentBlock {
|
||||||
pub fn text(content: impl Into<String>) -> Self {
|
pub fn text(content: impl Into<String>) -> Self {
|
||||||
Self::Text { text: content.into() }
|
Self::Text {
|
||||||
|
text: content.into(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn image_url(url: impl Into<String>) -> Self {
|
pub fn image_url(url: impl Into<String>) -> Self {
|
||||||
@ -161,7 +163,10 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
pub fn assistant_with_tool_calls(
|
||||||
|
content: impl Into<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -206,7 +211,11 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn tool(
|
||||||
|
tool_call_id: impl Into<String>,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
content: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
|
|||||||
@ -2,10 +2,13 @@ pub mod dispatcher;
|
|||||||
pub mod message;
|
pub mod message;
|
||||||
|
|
||||||
pub use dispatcher::OutboundDispatcher;
|
pub use dispatcher::OutboundDispatcher;
|
||||||
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource, OutboundMessage, SourceKind};
|
pub use message::{
|
||||||
|
ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MediaRef, MessageSource,
|
||||||
|
OutboundMessage, SourceKind,
|
||||||
|
};
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||||
@ -49,7 +52,8 @@ impl MessageBus {
|
|||||||
|
|
||||||
/// Consume an inbound message (Agent -> Bus)
|
/// Consume an inbound message (Agent -> Bus)
|
||||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||||
let msg = self.inbound_rx
|
let msg = self
|
||||||
|
.inbound_rx
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.recv()
|
.recv()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -24,7 +24,10 @@ impl ChannelManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
|
pub fn with_bus(
|
||||||
|
cli_chat_channel: Arc<crate::channels::CliChatChannel>,
|
||||||
|
bus: Arc<MessageBus>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||||
cli_chat_channel,
|
cli_chat_channel,
|
||||||
@ -39,7 +42,10 @@ impl ChannelManager {
|
|||||||
|
|
||||||
/// Register a channel with the manager
|
/// Register a channel with the manager
|
||||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||||
self.channels.write().await.insert(name.to_string(), channel);
|
self.channels
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(name.to_string(), channel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get CLI chat channel
|
/// Get CLI chat channel
|
||||||
@ -56,14 +62,19 @@ impl ChannelManager {
|
|||||||
// Initialize Feishu channel if enabled
|
// Initialize Feishu channel if enabled
|
||||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||||
if feishu_config.enabled {
|
if feishu_config.enabled {
|
||||||
let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir)
|
let channel =
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
FeishuChannel::new(feishu_config.clone(), &workspace_dir).map_err(|e| {
|
||||||
|
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
self.channels
|
self.channels
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert("feishu".to_string(), Arc::new(channel));
|
.insert("feishu".to_string(), Arc::new(channel));
|
||||||
tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display());
|
tracing::info!(
|
||||||
|
"Feishu channel registered (media_dir: {}/media/feishu)",
|
||||||
|
workspace_dir.display()
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
tracing::info!("Feishu channel disabled in config");
|
tracing::info!("Feishu channel disabled in config");
|
||||||
}
|
}
|
||||||
@ -118,7 +129,10 @@ impl ChannelManager {
|
|||||||
if let Some(channel) = self.get_channel(channel_name).await {
|
if let Some(channel) = self.get_channel(channel_name).await {
|
||||||
channel.send(msg).await
|
channel.send(msg).await
|
||||||
} else {
|
} else {
|
||||||
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
Err(ChannelError::Other(format!(
|
||||||
|
"Channel not found: {}",
|
||||||
|
channel_name
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
pub mod base;
|
pub mod base;
|
||||||
pub mod feishu;
|
|
||||||
pub mod cli_chat;
|
pub mod cli_chat;
|
||||||
|
pub mod feishu;
|
||||||
pub mod manager;
|
pub mod manager;
|
||||||
pub mod slash_command;
|
pub mod slash_command;
|
||||||
|
|
||||||
pub use base::{Channel, ChannelError};
|
pub use base::{Channel, ChannelError};
|
||||||
pub use manager::ChannelManager;
|
|
||||||
pub use feishu::FeishuChannel;
|
|
||||||
pub use cli_chat::CliChatChannel;
|
pub use cli_chat::CliChatChannel;
|
||||||
pub use slash_command::{parse_slash_command, command_matches};
|
pub use feishu::FeishuChannel;
|
||||||
|
pub use manager::ChannelManager;
|
||||||
|
pub use slash_command::{command_matches, parse_slash_command};
|
||||||
|
|||||||
@ -16,7 +16,9 @@ pub fn parse_slash_command(content: &str) -> Option<(&str, &str)> {
|
|||||||
/// 检查内容是否匹配指定命令
|
/// 检查内容是否匹配指定命令
|
||||||
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
|
pub fn command_matches(content: &str, aliases: &[&str]) -> bool {
|
||||||
let trimmed = content.trim();
|
let trimmed = content.trim();
|
||||||
aliases.iter().any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
aliases
|
||||||
|
.iter()
|
||||||
|
.any(|&alias| trimmed == alias || trimmed.starts_with(&format!("{} ", alias)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -27,7 +29,10 @@ mod tests {
|
|||||||
fn test_parse_slash_command() {
|
fn test_parse_slash_command() {
|
||||||
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
|
assert_eq!(parse_slash_command("/reset"), Some(("reset", "")));
|
||||||
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
|
assert_eq!(parse_slash_command("/reset arg"), Some(("reset", "arg")));
|
||||||
assert_eq!(parse_slash_command("/new hello world"), Some(("new", "hello world")));
|
assert_eq!(
|
||||||
|
parse_slash_command("/new hello world"),
|
||||||
|
Some(("new", "hello world"))
|
||||||
|
);
|
||||||
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
|
assert_eq!(parse_slash_command("/??"), Some(("??", "")));
|
||||||
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
|
assert_eq!(parse_slash_command("/? arg"), Some(("?", "arg")));
|
||||||
assert_eq!(parse_slash_command("/?"), Some(("?", "")));
|
assert_eq!(parse_slash_command("/?"), Some(("?", "")));
|
||||||
|
|||||||
@ -8,10 +8,10 @@ use crate::client::tui::ui::render_ui;
|
|||||||
use crossterm::{
|
use crossterm::{
|
||||||
event::{self, Event},
|
event::{self, Event},
|
||||||
execute,
|
execute,
|
||||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||||
};
|
};
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use ratatui::{prelude::CrosstermBackend, Terminal};
|
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||||
use std::io;
|
use std::io;
|
||||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||||
|
|
||||||
@ -104,7 +104,10 @@ async fn handle_ws_message(app: &mut App, outbound: WsOutbound) {
|
|||||||
WsOutbound::SessionCreated { session_id, .. } => {
|
WsOutbound::SessionCreated { session_id, .. } => {
|
||||||
app.set_current_session(Some(session_id));
|
app.set_current_session(Some(session_id));
|
||||||
}
|
}
|
||||||
WsOutbound::SessionList { sessions, current_session_id } => {
|
WsOutbound::SessionList {
|
||||||
|
sessions,
|
||||||
|
current_session_id,
|
||||||
|
} => {
|
||||||
app.set_sessions(sessions);
|
app.set_sessions(sessions);
|
||||||
if let Some(id) = current_session_id {
|
if let Some(id) = current_session_id {
|
||||||
app.set_current_session(Some(id));
|
app.set_current_session(Some(id));
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use crate::client::tui::app::{App, MessageRole};
|
use crate::client::tui::app::{App, MessageRole};
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
|
Frame,
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
text::Line,
|
text::Line,
|
||||||
widgets::{Block, Borders, List, ListItem},
|
widgets::{Block, Borders, List, ListItem},
|
||||||
Frame,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
|
Frame,
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
text::{Line, Span},
|
text::{Line, Span},
|
||||||
widgets::{Block, Borders, List, ListItem},
|
widgets::{Block, Borders, List, ListItem},
|
||||||
Frame,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
use ratatui::{
|
use ratatui::{
|
||||||
|
Frame,
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
widgets::{Block, Borders, Clear, List, ListItem},
|
widgets::{Block, Borders, Clear, List, ListItem},
|
||||||
Frame,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect) {
|
pub fn render(f: &mut Frame, area: Rect) {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
|
Frame,
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Style},
|
style::{Color, Style},
|
||||||
widgets::{Block, Borders, Paragraph},
|
widgets::{Block, Borders, Paragraph},
|
||||||
Frame,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
|
Frame,
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
widgets::{Block, Borders, List, ListItem},
|
widgets::{Block, Borders, List, ListItem},
|
||||||
Frame,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
@ -11,9 +11,7 @@ pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
|||||||
.sessions
|
.sessions
|
||||||
.iter()
|
.iter()
|
||||||
.map(|session| {
|
.map(|session| {
|
||||||
let is_current = app
|
let is_current = app.current_session_id.as_ref() == Some(&session.session_id);
|
||||||
.current_session_id
|
|
||||||
.as_ref() == Some(&session.session_id);
|
|
||||||
let archived = session.archived_at.is_some();
|
let archived = session.archived_at.is_some();
|
||||||
|
|
||||||
let mut content = if is_current {
|
let mut content = if is_current {
|
||||||
|
|||||||
@ -1,15 +1,18 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
|
Frame,
|
||||||
layout::Rect,
|
layout::Rect,
|
||||||
style::{Color, Modifier, Style},
|
style::{Color, Modifier, Style},
|
||||||
widgets::{Block, Borders, Paragraph},
|
widgets::{Block, Borders, Paragraph},
|
||||||
Frame,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
pub fn render(f: &mut Frame, area: Rect, app: &App) {
|
||||||
let (title, style) = if app.pending_quit {
|
let (title, style) = if app.pending_quit {
|
||||||
let msg = if let Some(session_id) = &app.current_session_id {
|
let msg = if let Some(session_id) = &app.current_session_id {
|
||||||
format!("PicoBot | Session: {} | Press Ctrl+C again to quit", session_id)
|
format!(
|
||||||
|
"PicoBot | Session: {} | Press Ctrl+C again to quit",
|
||||||
|
session_id
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
"PicoBot | Press Ctrl+C again to quit".to_string()
|
"PicoBot | Press Ctrl+C again to quit".to_string()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use crate::client::tui::app::{App, MessageRole};
|
use crate::client::tui::app::{App, MessageRole};
|
||||||
use crate::protocol::serialize_inbound;
|
|
||||||
use crate::protocol::WsInbound;
|
use crate::protocol::WsInbound;
|
||||||
|
use crate::protocol::serialize_inbound;
|
||||||
use crossterm::event::{KeyCode, KeyEvent};
|
use crossterm::event::{KeyCode, KeyEvent};
|
||||||
use futures_util::SinkExt;
|
use futures_util::SinkExt;
|
||||||
|
|
||||||
@ -48,7 +48,10 @@ pub async fn handle_key_event(app: &mut App, key: KeyEvent) {
|
|||||||
|
|
||||||
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
||||||
// Handle Ctrl+C for quit (double press to exit)
|
// Handle Ctrl+C for quit (double press to exit)
|
||||||
let is_ctrl_c = key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL);
|
let is_ctrl_c = key.code == KeyCode::Char('c')
|
||||||
|
&& key
|
||||||
|
.modifiers
|
||||||
|
.contains(crossterm::event::KeyModifiers::CONTROL);
|
||||||
if is_ctrl_c {
|
if is_ctrl_c {
|
||||||
if app.handle_ctrl_c_for_quit() {
|
if app.handle_ctrl_c_for_quit() {
|
||||||
return;
|
return;
|
||||||
@ -65,7 +68,9 @@ async fn handle_normal_input(app: &mut App, key: KeyEvent) {
|
|||||||
app.input_insert_char(c);
|
app.input_insert_char(c);
|
||||||
|
|
||||||
// Show command menu when input starts with /
|
// Show command menu when input starts with /
|
||||||
if !app.show_command_menu && (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/'))) {
|
if !app.show_command_menu
|
||||||
|
&& (app.input == "/" || (app.input.len() > 1 && app.input.starts_with('/')))
|
||||||
|
{
|
||||||
app.show_command_menu = true;
|
app.show_command_menu = true;
|
||||||
app.selected_command_idx = 0;
|
app.selected_command_idx = 0;
|
||||||
} else if app.show_command_menu && !app.input.starts_with('/') {
|
} else if app.show_command_menu && !app.input.starts_with('/') {
|
||||||
@ -121,7 +126,9 @@ async fn process_input(app: &mut App, input: String) {
|
|||||||
sender_id: None,
|
sender_id: None,
|
||||||
};
|
};
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
let _ = sender.send(tokio_tungstenite::tungstenite::Message::Text(text.into())).await;
|
let _ = sender
|
||||||
|
.send(tokio_tungstenite::tungstenite::Message::Text(text.into()))
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
use crate::client::tui::app::App;
|
use crate::client::tui::app::App;
|
||||||
use crate::client::tui::components::*;
|
use crate::client::tui::components::*;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
layout::{Constraint, Direction, Layout, Rect},
|
|
||||||
Frame,
|
Frame,
|
||||||
|
layout::{Constraint, Direction, Layout, Rect},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn render_ui(f: &mut Frame, app: &App) {
|
pub fn render_ui(f: &mut Frame, app: &App) {
|
||||||
|
|||||||
@ -273,12 +273,16 @@ impl Default for MemoryConfig {
|
|||||||
impl MemoryConfig {
|
impl MemoryConfig {
|
||||||
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
/// Resolve consolidation provider name, falling back to the main agent's provider.
|
||||||
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
pub fn resolve_consolidation_provider(&self, default: &str) -> String {
|
||||||
self.consolidation_provider.clone().unwrap_or_else(|| default.to_string())
|
self.consolidation_provider
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| default.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve consolidation model name, falling back to the main agent's model.
|
/// Resolve consolidation model name, falling back to the main agent's model.
|
||||||
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
pub fn resolve_consolidation_model(&self, default: &str) -> String {
|
||||||
self.consolidation_model.clone().unwrap_or_else(|| default.to_string())
|
self.consolidation_model
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| default.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,10 +370,18 @@ impl Default for BrowserConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_recall_limit() -> usize { 5 }
|
fn default_recall_limit() -> usize {
|
||||||
fn default_idle_consolidation_minutes() -> u64 { 10 }
|
5
|
||||||
fn default_timeline_retention_days() -> u64 { 90 }
|
}
|
||||||
fn default_max_failures_before_degrade() -> usize { 3 }
|
fn default_idle_consolidation_minutes() -> u64 {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
fn default_timeline_retention_days() -> u64 {
|
||||||
|
90
|
||||||
|
}
|
||||||
|
fn default_max_failures_before_degrade() -> usize {
|
||||||
|
3
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LLMProviderConfig {
|
pub struct LLMProviderConfig {
|
||||||
@ -469,7 +481,11 @@ pub enum ConfigError {
|
|||||||
impl std::fmt::Display for ConfigError {
|
impl std::fmt::Display for ConfigError {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
ConfigError::ConfigNotFound(path) => write!(
|
||||||
|
f,
|
||||||
|
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
|
||||||
|
path
|
||||||
|
),
|
||||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
use std::sync::Arc;
|
use super::GatewayState;
|
||||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
use crate::protocol::WsOutbound;
|
||||||
|
use crate::protocol::serialize_outbound;
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
|
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use std::sync::Arc;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use crate::protocol::serialize_outbound;
|
|
||||||
use crate::protocol::WsOutbound;
|
|
||||||
use super::GatewayState;
|
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||||
ws.on_upgrade(|socket| async move {
|
ws.on_upgrade(|socket| async move {
|
||||||
@ -25,9 +25,11 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
|
let (session_id, client) = cli_chat_channel.register_client(sender.clone()).await;
|
||||||
|
|
||||||
// Send session established message
|
// Send session established message
|
||||||
let _ = sender.send(WsOutbound::SessionEstablished {
|
let _ = sender
|
||||||
|
.send(WsOutbound::SessionEstablished {
|
||||||
session_id: session_id.clone(),
|
session_id: session_id.clone(),
|
||||||
}).await;
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
tracing::info!(session_id = %session_id, "CLI session established");
|
tracing::info!(session_id = %session_id, "CLI session established");
|
||||||
|
|
||||||
@ -37,7 +39,8 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
if let Ok(text) = serialize_outbound(&msg)
|
if let Ok(text) = serialize_outbound(&msg)
|
||||||
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err() {
|
&& ws_sender.send(WsMessage::Text(text.into())).await.is_err()
|
||||||
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
14
src/lib.rs
14
src/lib.rs
@ -1,17 +1,17 @@
|
|||||||
pub mod config;
|
|
||||||
pub mod providers;
|
|
||||||
pub mod bus;
|
|
||||||
pub mod agent;
|
pub mod agent;
|
||||||
pub mod gateway;
|
pub mod bus;
|
||||||
pub mod session;
|
|
||||||
pub mod client;
|
|
||||||
pub mod protocol;
|
|
||||||
pub mod channels;
|
pub mod channels;
|
||||||
|
pub mod client;
|
||||||
|
pub mod config;
|
||||||
|
pub mod gateway;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod mcp;
|
pub mod mcp;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
|
pub mod protocol;
|
||||||
|
pub mod providers;
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
|
pub mod session;
|
||||||
pub mod skills;
|
pub mod skills;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|||||||
@ -1,11 +1,7 @@
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||||
use tracing_subscriber::{
|
use tracing_subscriber::{
|
||||||
fmt,
|
EnvFilter, fmt, fmt::time::LocalTime, layer::SubscriberExt, util::SubscriberInitExt,
|
||||||
layer::SubscriberExt,
|
|
||||||
util::SubscriberInitExt,
|
|
||||||
fmt::time::LocalTime,
|
|
||||||
EnvFilter,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Get the default log directory path: ~/.picobot/logs
|
/// Get the default log directory path: ~/.picobot/logs
|
||||||
@ -27,20 +23,20 @@ pub fn init_logging() {
|
|||||||
|
|
||||||
// Create log directory if it doesn't exist
|
// Create log directory if it doesn't exist
|
||||||
if !log_dir.exists()
|
if !log_dir.exists()
|
||||||
&& let Err(e) = std::fs::create_dir_all(&log_dir) {
|
&& let Err(e) = std::fs::create_dir_all(&log_dir)
|
||||||
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
{
|
||||||
|
eprintln!(
|
||||||
|
"Warning: Failed to create log directory {}: {}",
|
||||||
|
log_dir.display(),
|
||||||
|
e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create file appender with daily rotation
|
// Create file appender with daily rotation
|
||||||
let file_appender = RollingFileAppender::new(
|
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
|
||||||
Rotation::DAILY,
|
|
||||||
&log_dir,
|
|
||||||
"picobot.log",
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build subscriber with both console and file output
|
// Build subscriber with both console and file output
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let file_layer = fmt::layer()
|
let file_layer = fmt::layer()
|
||||||
.with_writer(file_appender)
|
.with_writer(file_appender)
|
||||||
@ -66,8 +62,7 @@ pub fn init_logging() {
|
|||||||
|
|
||||||
/// Initialize logging without file output (console only)
|
/// Initialize logging without file output (console only)
|
||||||
pub fn init_logging_console_only() {
|
pub fn init_logging_console_only() {
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let console_layer = fmt::layer()
|
let console_layer = fmt::layer()
|
||||||
.with_timer(LocalTime::rfc_3339())
|
.with_timer(LocalTime::rfc_3339())
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use clap::{Parser, CommandFactory};
|
use clap::{CommandFactory, Parser};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "picobot")]
|
#[command(name = "picobot")]
|
||||||
|
|||||||
@ -92,13 +92,9 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
|||||||
parts.push(text.text.clone());
|
parts.push(text.text.clone());
|
||||||
}
|
}
|
||||||
RawContent::Image(image) => {
|
RawContent::Image(image) => {
|
||||||
parts.push(format!(
|
parts.push(format!("[image: {}]", image.mime_type,));
|
||||||
"[image: {}]",
|
|
||||||
image.mime_type,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
RawContent::Resource(resource) => {
|
RawContent::Resource(resource) => match &resource.resource {
|
||||||
match &resource.resource {
|
|
||||||
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
|
||||||
parts.push(format!(
|
parts.push(format!(
|
||||||
"[resource text: {}]",
|
"[resource text: {}]",
|
||||||
@ -108,8 +104,7 @@ fn extract_text(result: &rmcp::model::CallToolResult) -> String {
|
|||||||
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
rmcp::model::ResourceContents::BlobResourceContents { uri, .. } => {
|
||||||
parts.push(format!("[resource blob: {}]", uri));
|
parts.push(format!("[resource blob: {}]", uri));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
parts.push("[unsupported content]".to_string());
|
parts.push("[unsupported content]".to_string());
|
||||||
}
|
}
|
||||||
@ -225,8 +220,8 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
|||||||
cmd.env(k, v);
|
cmd.env(k, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
let service = ()
|
let service =
|
||||||
.serve(
|
().serve(
|
||||||
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
|
TokioChildProcess::new(cmd).context("failed to create stdio MCP transport")?,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -261,12 +256,12 @@ async fn connect_server(config: &McpServerConfig) -> anyhow::Result<McpConnectio
|
|||||||
} else {
|
} else {
|
||||||
StreamableHttpClientTransport::from_config(
|
StreamableHttpClientTransport::from_config(
|
||||||
StreamableHttpClientTransportConfig::with_uri(url.to_string())
|
StreamableHttpClientTransportConfig::with_uri(url.to_string())
|
||||||
.custom_headers(headers_map)
|
.custom_headers(headers_map),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let service = ()
|
let service =
|
||||||
.serve(transport)
|
().serve(transport)
|
||||||
.await
|
.await
|
||||||
.context("failed to connect to HTTP/SSE MCP server")?;
|
.context("failed to connect to HTTP/SSE MCP server")?;
|
||||||
|
|
||||||
|
|||||||
@ -102,7 +102,11 @@ mod tests {
|
|||||||
let dir = tempdir().unwrap();
|
let dir = tempdir().unwrap();
|
||||||
let db_path = dir.path().join("test.db");
|
let db_path = dir.path().join("test.db");
|
||||||
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
let storage = Arc::new(Storage::new(&db_path).await.unwrap());
|
||||||
let mm = Arc::new(MemoryManager::new(storage, "default".into(), "default".into()));
|
let mm = Arc::new(MemoryManager::new(
|
||||||
|
storage,
|
||||||
|
"default".into(),
|
||||||
|
"default".into(),
|
||||||
|
));
|
||||||
(mm, dir)
|
(mm, dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,13 +135,7 @@ mod tests {
|
|||||||
async fn test_upsert_overwrites() {
|
async fn test_upsert_overwrites() {
|
||||||
let (mm, _dir) = setup_memory_manager().await;
|
let (mm, _dir) = setup_memory_manager().await;
|
||||||
|
|
||||||
mm.store(
|
mm.store("dup_key", "original", MemoryCategory::Knowledge, None, None)
|
||||||
"dup_key",
|
|
||||||
"original",
|
|
||||||
MemoryCategory::Knowledge,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mm.store(
|
mm.store(
|
||||||
@ -247,7 +245,12 @@ mod tests {
|
|||||||
|
|
||||||
// Recall scoped to session A — should get only tl_a
|
// Recall scoped to session A — should get only tl_a
|
||||||
let scoped = mm
|
let scoped = mm
|
||||||
.recall("summary", 10, Some(MemoryCategory::Timeline), Some("chan:chat:dialog_a"))
|
.recall(
|
||||||
|
"summary",
|
||||||
|
10,
|
||||||
|
Some(MemoryCategory::Timeline),
|
||||||
|
Some("chan:chat:dialog_a"),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(scoped.len(), 1);
|
assert_eq!(scoped.len(), 1);
|
||||||
|
|||||||
@ -20,10 +20,7 @@ pub enum ObserverEvent {
|
|||||||
success: bool,
|
success: bool,
|
||||||
},
|
},
|
||||||
/// Emitted when the agent starts processing.
|
/// Emitted when the agent starts processing.
|
||||||
AgentStart {
|
AgentStart { provider: String, model: String },
|
||||||
provider: String,
|
|
||||||
model: String,
|
|
||||||
},
|
|
||||||
/// Emitted when the agent finishes processing.
|
/// Emitted when the agent finishes processing.
|
||||||
AgentEnd {
|
AgentEnd {
|
||||||
provider: String,
|
provider: String,
|
||||||
@ -94,7 +91,11 @@ impl ToolExecutionOutcome {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a failed outcome with duration.
|
/// Create a failed outcome with duration.
|
||||||
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
pub fn failure_with_duration(
|
||||||
|
output: String,
|
||||||
|
error_reason: Option<String>,
|
||||||
|
duration: Duration,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
output,
|
output,
|
||||||
success: false,
|
success: false,
|
||||||
|
|||||||
@ -4,23 +4,24 @@ use serde::{Deserialize, Serialize};
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
use std::sync::Arc;
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||||
blocks.iter().map(|b| match b {
|
blocks
|
||||||
|
.iter()
|
||||||
|
.map(|b| match b {
|
||||||
ContentBlock::Text { text } => {
|
ContentBlock::Text { text } => {
|
||||||
serde_json::json!({ "type": "text", "text": text })
|
serde_json::json!({ "type": "text", "text": text })
|
||||||
}
|
}
|
||||||
ContentBlock::ImageUrl { image_url } => {
|
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
||||||
convert_image_url_to_anthropic(&image_url.url)
|
})
|
||||||
}
|
.collect()
|
||||||
}).collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||||
@ -197,8 +198,13 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
};
|
};
|
||||||
let content = if let Some(ref tc_id) = m.tool_call_id {
|
let content = if let Some(ref tc_id) = m.tool_call_id {
|
||||||
// Tool result: wrap as tool_result content block
|
// Tool result: wrap as tool_result content block
|
||||||
let output = m.content.iter()
|
let output = m
|
||||||
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
|
.content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|b| match b {
|
||||||
|
ContentBlock::Text { text } => Some(text.as_str()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("");
|
.join("");
|
||||||
vec![serde_json::json!({
|
vec![serde_json::json!({
|
||||||
@ -244,8 +250,7 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await
|
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
||||||
.inspect_err(|e| {
|
|
||||||
let is_timeout = e.is_timeout();
|
let is_timeout = e.is_timeout();
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
provider = %self.name,
|
provider = %self.name,
|
||||||
@ -281,17 +286,21 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
"LLM API returned error"
|
"LLM API returned error"
|
||||||
);
|
);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let _ = storage.append_llm_call(
|
let _ = storage
|
||||||
&self.name, &self.model_id, &req_body_str,
|
.append_llm_call(
|
||||||
Some(&body_text), Some(&error_msg),
|
&self.name,
|
||||||
|
&self.model_id,
|
||||||
|
&req_body_str,
|
||||||
|
Some(&body_text),
|
||||||
|
Some(&error_msg),
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
).await;
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text).map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let name = self.name.clone();
|
let name = self.name.clone();
|
||||||
@ -302,7 +311,9 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
let err = err_msg.clone();
|
let err = err_msg.clone();
|
||||||
let s = storage.clone();
|
let s = storage.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
let _ = s
|
||||||
|
.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur)
|
||||||
|
.await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
err_msg
|
err_msg
|
||||||
@ -343,21 +354,35 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
reasoning_content: reasoning,
|
reasoning_content: reasoning,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
usage: Usage {
|
usage: Usage {
|
||||||
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
prompt_tokens: anthropic_resp
|
||||||
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
.usage
|
||||||
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
.as_ref()
|
||||||
|
.map(|u| u.input_tokens)
|
||||||
|
.unwrap_or(0),
|
||||||
|
completion_tokens: anthropic_resp
|
||||||
|
.usage
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.output_tokens)
|
||||||
|
.unwrap_or(0),
|
||||||
|
total_tokens: anthropic_resp
|
||||||
|
.usage
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.input_tokens + u.output_tokens)
|
||||||
|
.unwrap_or(0),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let _ = storage.append_llm_call(
|
let _ = storage
|
||||||
|
.append_llm_call(
|
||||||
&self.name,
|
&self.name,
|
||||||
&self.model_id,
|
&self.model_id,
|
||||||
&req_body_str,
|
&req_body_str,
|
||||||
Some(&body_text),
|
Some(&body_text),
|
||||||
None,
|
None,
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
).await;
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
pub mod traits;
|
|
||||||
pub mod openai;
|
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
pub mod openai;
|
||||||
|
pub mod traits;
|
||||||
|
|
||||||
pub use self::openai::OpenAIProvider;
|
|
||||||
pub use self::anthropic::AnthropicProvider;
|
pub use self::anthropic::AnthropicProvider;
|
||||||
|
pub use self::openai::OpenAIProvider;
|
||||||
|
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
pub use traits::{
|
||||||
|
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
|
||||||
|
ToolFunction, Usage,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||||
match config.provider_type.as_str() {
|
match config.provider_type.as_str() {
|
||||||
|
|||||||
@ -1,29 +1,35 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
use std::sync::Arc;
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
const LLM_REQUEST_TIMEOUT_SECS: u64 = 300;
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||||
if blocks.len() == 1
|
if blocks.len() == 1
|
||||||
&& let ContentBlock::Text { text } = &blocks[0] {
|
&& let ContentBlock::Text { text } = &blocks[0]
|
||||||
|
{
|
||||||
return Value::String(text.clone());
|
return Value::String(text.clone());
|
||||||
}
|
}
|
||||||
Value::Array(blocks.iter().map(|b| match b {
|
Value::Array(
|
||||||
|
blocks
|
||||||
|
.iter()
|
||||||
|
.map(|b| match b {
|
||||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||||
ContentBlock::ImageUrl { image_url } => {
|
ContentBlock::ImageUrl { image_url } => {
|
||||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||||
}
|
}
|
||||||
}).collect())
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
pub struct OpenAIProvider {
|
||||||
@ -201,7 +207,11 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||||
for (j, item) in content.iter().enumerate() {
|
for (j, item) in content.iter().enumerate() {
|
||||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
|
if item.get("type").and_then(|t| t.as_str()) == Some("image_url")
|
||||||
&& let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
&& let Some(url_str) = item
|
||||||
|
.get("image_url")
|
||||||
|
.and_then(|u| u.get("url"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
let prefix: String = url_str.chars().take(20).collect();
|
let prefix: String = url_str.chars().take(20).collect();
|
||||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||||
}
|
}
|
||||||
@ -224,8 +234,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await
|
let resp = req_builder.json(&body).send().await.inspect_err(|e| {
|
||||||
.inspect_err(|e| {
|
|
||||||
let is_timeout = e.is_timeout();
|
let is_timeout = e.is_timeout();
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
provider = %self.name,
|
provider = %self.name,
|
||||||
@ -253,18 +262,23 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
"LLM API returned error"
|
"LLM API returned error"
|
||||||
);
|
);
|
||||||
if let Some(ref storage) = self.storage
|
if let Some(ref storage) = self.storage
|
||||||
&& let Err(e) = storage.append_llm_call(
|
&& let Err(e) = storage
|
||||||
&self.name, &self.model_id, &req_body_str,
|
.append_llm_call(
|
||||||
Some(&text), Some(&error),
|
&self.name,
|
||||||
|
&self.model_id,
|
||||||
|
&req_body_str,
|
||||||
|
Some(&text),
|
||||||
|
Some(&error),
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::warn!("failed to persist LLM call: {}", e);
|
tracing::warn!("failed to persist LLM call: {}", e);
|
||||||
}
|
}
|
||||||
return Err(error.into());
|
return Err(error.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
let openai_resp: OpenAIResponse = serde_json::from_str(&text).map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||||
if let Some(ref storage) = self.storage {
|
if let Some(ref storage) = self.storage {
|
||||||
let name = self.name.clone();
|
let name = self.name.clone();
|
||||||
@ -275,7 +289,10 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
let err = err_msg.clone();
|
let err = err_msg.clone();
|
||||||
let s = storage.clone();
|
let s = storage.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await {
|
if let Err(e) = s
|
||||||
|
.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
tracing::warn!("failed to persist LLM call (decode error): {}", e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -283,7 +300,10 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
err_msg
|
err_msg
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let first_choice = openai_resp.choices.into_iter().next()
|
let first_choice = openai_resp
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
.ok_or("no choices in response")?;
|
.ok_or("no choices in response")?;
|
||||||
|
|
||||||
let content = first_choice
|
let content = first_choice
|
||||||
@ -300,7 +320,8 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
.map(|tc| ToolCall {
|
.map(|tc| ToolCall {
|
||||||
id: tc.id.clone(),
|
id: tc.id.clone(),
|
||||||
name: tc.function.name.clone(),
|
name: tc.function.name.clone(),
|
||||||
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
arguments: serde_json::from_str(&tc.function.arguments)
|
||||||
|
.unwrap_or(serde_json::Value::Null),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@ -318,11 +339,17 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ref storage) = self.storage
|
if let Some(ref storage) = self.storage
|
||||||
&& let Err(e) = storage.append_llm_call(
|
&& let Err(e) = storage
|
||||||
&self.name, &self.model_id, &req_body_str,
|
.append_llm_call(
|
||||||
Some(&text), None,
|
&self.name,
|
||||||
|
&self.model_id,
|
||||||
|
&req_body_str,
|
||||||
|
Some(&text),
|
||||||
|
None,
|
||||||
start.elapsed().as_millis() as u64,
|
start.elapsed().as_millis() as u64,
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::warn!("failed to persist LLM call: {}", e);
|
tracing::warn!("failed to persist LLM call: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -386,6 +413,9 @@ mod tests {
|
|||||||
assert_eq!(tool_calls[0]["id"], "call_1");
|
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||||
assert_eq!(tool_calls[0]["type"], "function");
|
assert_eq!(tool_calls[0]["type"], "function");
|
||||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
assert_eq!(
|
||||||
|
tool_calls[0]["function"]["arguments"],
|
||||||
|
"{\"expression\":\"1+1\"}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
@ -61,7 +61,11 @@ impl Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn tool(
|
||||||
|
tool_call_id: impl Into<String>,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
content: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
|
|||||||
@ -5,11 +5,11 @@ use std::time::Instant;
|
|||||||
use tokio::time;
|
use tokio::time;
|
||||||
|
|
||||||
use crate::config::SchedulerConfig;
|
use crate::config::SchedulerConfig;
|
||||||
use crate::session::session::HandleResult;
|
|
||||||
use crate::session::SessionManager;
|
use crate::session::SessionManager;
|
||||||
|
use crate::session::session::HandleResult;
|
||||||
|
use crate::storage::JobRun;
|
||||||
use crate::storage::ScheduledJob;
|
use crate::storage::ScheduledJob;
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
use crate::storage::JobRun;
|
|
||||||
|
|
||||||
pub use types::Schedule;
|
pub use types::Schedule;
|
||||||
|
|
||||||
@ -89,7 +89,11 @@ impl Scheduler {
|
|||||||
|
|
||||||
let now = now_ms();
|
let now = now_ms();
|
||||||
|
|
||||||
let due = match self.storage.due_scheduled_jobs(now, self.config.max_concurrent).await {
|
let due = match self
|
||||||
|
.storage
|
||||||
|
.due_scheduled_jobs(now, self.config.max_concurrent)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(jobs) => jobs,
|
Ok(jobs) => jobs,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
tracing::error!("scheduler: failed to query due jobs: {}", e);
|
||||||
@ -107,7 +111,11 @@ impl Scheduler {
|
|||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let started_at = now_ms();
|
let started_at = now_ms();
|
||||||
|
|
||||||
if let Err(e) = self.storage.touch_scheduled_job_last_run(&job.id, started_at).await {
|
if let Err(e) = self
|
||||||
|
.storage
|
||||||
|
.touch_scheduled_job_last_run(&job.id, started_at)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
tracing::error!(job_id = %job.id, "scheduler: failed to touch last_run_at: {}", e);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -135,7 +143,10 @@ impl Scheduler {
|
|||||||
match result {
|
match result {
|
||||||
Ok(HandleResult::AgentResponse(output)) => {
|
Ok(HandleResult::AgentResponse(output)) => {
|
||||||
let output_truncated = if output.len() > 8000 {
|
let output_truncated = if output.len() > 8000 {
|
||||||
format!("{}...[truncated]", &output[..output.ceil_char_boundary(8000)])
|
format!(
|
||||||
|
"{}...[truncated]",
|
||||||
|
&output[..output.ceil_char_boundary(8000)]
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
output.clone()
|
output.clone()
|
||||||
};
|
};
|
||||||
@ -155,7 +166,11 @@ impl Scheduler {
|
|||||||
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
tracing::error!(job_id = %job.id, "scheduler: failed to record run: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = self.storage.set_scheduled_job_last_status(&job.id, "ok", None).await {
|
if let Err(e) = self
|
||||||
|
.storage
|
||||||
|
.set_scheduled_job_last_status(&job.id, "ok", None)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
tracing::error!(job_id = %job.id, "scheduler: failed to set last_status: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,9 +214,11 @@ impl Scheduler {
|
|||||||
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
tracing::error!(job_id = %job.id, "scheduler: failed to record error run: {}", e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e2) = self.storage.set_scheduled_job_last_status(
|
if let Err(e2) = self
|
||||||
&job.id, "error", Some(&error_str),
|
.storage
|
||||||
).await {
|
.set_scheduled_job_last_status(&job.id, "error", Some(&error_str))
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
tracing::error!(job_id = %job.id, "scheduler: failed to set error status: {}", e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,17 +248,23 @@ impl Scheduler {
|
|||||||
self.storage.remove_scheduled_job(&job.id).await?;
|
self.storage.remove_scheduled_job(&job.id).await?;
|
||||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
tracing::info!(job_id = %job.id, "scheduler: one-shot job deleted after run");
|
||||||
} else {
|
} else {
|
||||||
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
self.storage
|
||||||
|
.set_scheduled_job_enabled(&job.id, false)
|
||||||
|
.await?;
|
||||||
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
tracing::info!(job_id = %job.id, "scheduler: one-shot job disabled after run");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
Schedule::Every { .. } | Schedule::Cron { .. } => {
|
||||||
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
if let Some(next) = next_run_for_schedule(&job.schedule, now) {
|
||||||
self.storage.set_scheduled_job_next_run(&job.id, next).await?;
|
self.storage
|
||||||
|
.set_scheduled_job_next_run(&job.id, next)
|
||||||
|
.await?;
|
||||||
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
tracing::info!(job_id = %job.id, next_run_at = %next, "scheduler: job rescheduled");
|
||||||
} else {
|
} else {
|
||||||
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
tracing::error!(job_id = %job.id, "scheduler: could not compute next run -- disabling job");
|
||||||
self.storage.set_scheduled_job_enabled(&job.id, false).await?;
|
self.storage
|
||||||
|
.set_scheduled_job_enabled(&job.id, false)
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,32 +22,20 @@ pub enum SessionCommand {
|
|||||||
dialog_id: String,
|
dialog_id: String,
|
||||||
},
|
},
|
||||||
/// Get the current dialog for a chat
|
/// Get the current dialog for a chat
|
||||||
GetCurrentDialog {
|
GetCurrentDialog { channel: String, chat_id: String },
|
||||||
channel: String,
|
|
||||||
chat_id: String,
|
|
||||||
},
|
|
||||||
/// Rename a dialog
|
/// Rename a dialog
|
||||||
RenameDialog {
|
RenameDialog {
|
||||||
session_id: UnifiedSessionId,
|
session_id: UnifiedSessionId,
|
||||||
title: String,
|
title: String,
|
||||||
},
|
},
|
||||||
/// Archive a dialog
|
/// Archive a dialog
|
||||||
ArchiveDialog {
|
ArchiveDialog { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// Delete a dialog
|
/// Delete a dialog
|
||||||
DeleteDialog {
|
DeleteDialog { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// Clear dialog history
|
/// Clear dialog history
|
||||||
ClearHistory {
|
ClearHistory { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// Get list of available slash commands
|
/// Get list of available slash commands
|
||||||
GetSlashCommands {
|
GetSlashCommands { channel: String, chat_id: String },
|
||||||
channel: String,
|
|
||||||
chat_id: String,
|
|
||||||
},
|
|
||||||
/// Execute a slash command
|
/// Execute a slash command
|
||||||
ExecuteSlashCommand {
|
ExecuteSlashCommand {
|
||||||
command: String,
|
command: String,
|
||||||
@ -60,7 +48,11 @@ pub enum SessionCommand {
|
|||||||
|
|
||||||
impl SessionCommand {
|
impl SessionCommand {
|
||||||
/// Create a CreateDialog command
|
/// Create a CreateDialog command
|
||||||
pub fn create_dialog(channel: impl Into<String>, chat_id: impl Into<String>, title: Option<String>) -> Self {
|
pub fn create_dialog(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
title: Option<String>,
|
||||||
|
) -> Self {
|
||||||
Self::CreateDialog {
|
Self::CreateDialog {
|
||||||
channel: channel.into(),
|
channel: channel.into(),
|
||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
@ -69,7 +61,11 @@ impl SessionCommand {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a ListDialogs command
|
/// Create a ListDialogs command
|
||||||
pub fn list_dialogs(channel: impl Into<String>, chat_id: impl Into<String>, include_archived: bool) -> Self {
|
pub fn list_dialogs(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
include_archived: bool,
|
||||||
|
) -> Self {
|
||||||
Self::ListDialogs {
|
Self::ListDialogs {
|
||||||
channel: channel.into(),
|
channel: channel.into(),
|
||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use super::session_id::UnifiedSessionId;
|
|
||||||
use super::session::SlashCommand;
|
use super::session::SlashCommand;
|
||||||
|
use super::session_id::UnifiedSessionId;
|
||||||
|
|
||||||
/// Dialog information returned by SessionManager
|
/// Dialog information returned by SessionManager
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -30,30 +30,20 @@ pub enum SessionEvent {
|
|||||||
session_id: Option<UnifiedSessionId>,
|
session_id: Option<UnifiedSessionId>,
|
||||||
},
|
},
|
||||||
/// Dialog switched successfully
|
/// Dialog switched successfully
|
||||||
DialogSwitched {
|
DialogSwitched { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// Dialog renamed
|
/// Dialog renamed
|
||||||
DialogRenamed {
|
DialogRenamed {
|
||||||
session_id: UnifiedSessionId,
|
session_id: UnifiedSessionId,
|
||||||
title: String,
|
title: String,
|
||||||
},
|
},
|
||||||
/// Dialog archived
|
/// Dialog archived
|
||||||
DialogArchived {
|
DialogArchived { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// Dialog deleted
|
/// Dialog deleted
|
||||||
DialogDeleted {
|
DialogDeleted { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// Dialog history cleared
|
/// Dialog history cleared
|
||||||
HistoryCleared {
|
HistoryCleared { session_id: UnifiedSessionId },
|
||||||
session_id: UnifiedSessionId,
|
|
||||||
},
|
|
||||||
/// List of available slash commands
|
/// List of available slash commands
|
||||||
SlashCommandsList {
|
SlashCommandsList { commands: Vec<SlashCommand> },
|
||||||
commands: Vec<SlashCommand>,
|
|
||||||
},
|
|
||||||
/// Slash command executed successfully
|
/// Slash command executed successfully
|
||||||
SlashCommandExecuted {
|
SlashCommandExecuted {
|
||||||
new_session_id: Option<UnifiedSessionId>,
|
new_session_id: Option<UnifiedSessionId>,
|
||||||
@ -70,8 +60,5 @@ pub enum SessionEvent {
|
|||||||
message_count: usize,
|
message_count: usize,
|
||||||
},
|
},
|
||||||
/// Error occurred
|
/// Error occurred
|
||||||
Error {
|
Error { code: String, message: String },
|
||||||
code: String,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
pub mod error;
|
|
||||||
pub mod commands;
|
pub mod commands;
|
||||||
|
pub mod error;
|
||||||
pub mod events;
|
pub mod events;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod session_id;
|
pub mod session_id;
|
||||||
|
|
||||||
pub use error::SessionError;
|
|
||||||
pub use commands::SessionCommand;
|
pub use commands::SessionCommand;
|
||||||
pub use events::{SessionEvent, DialogInfo};
|
pub use error::SessionError;
|
||||||
pub use session::{Session, SessionManager, SlashCommand, SLASH_COMMANDS};
|
pub use events::{DialogInfo, SessionEvent};
|
||||||
|
pub use session::{SLASH_COMMANDS, Session, SessionManager, SlashCommand};
|
||||||
pub use session_id::UnifiedSessionId;
|
pub use session_id::UnifiedSessionId;
|
||||||
|
|||||||
@ -8,7 +8,6 @@
|
|||||||
///
|
///
|
||||||
/// For simple cases where only one dialog exists per chat:
|
/// For simple cases where only one dialog exists per chat:
|
||||||
/// - `dialog_id` defaults to `"default"`
|
/// - `dialog_id` defaults to `"default"`
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub const DEFAULT_DIALOG_ID: &str = "default";
|
pub const DEFAULT_DIALOG_ID: &str = "default";
|
||||||
@ -22,7 +21,11 @@ pub struct UnifiedSessionId {
|
|||||||
|
|
||||||
impl UnifiedSessionId {
|
impl UnifiedSessionId {
|
||||||
/// Create a new UnifiedSessionId
|
/// Create a new UnifiedSessionId
|
||||||
pub fn new(channel: impl Into<String>, chat_id: impl Into<String>, dialog_id: impl Into<String>) -> Self {
|
pub fn new(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
dialog_id: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
channel: channel.into(),
|
channel: channel.into(),
|
||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use super::embedded::{EmbeddedSkill, EMBEDDED_SKILLS};
|
use super::embedded::{EMBEDDED_SKILLS, EmbeddedSkill};
|
||||||
|
|
||||||
pub fn install_builtin_skills(target_dir: &Path) {
|
pub fn install_builtin_skills(target_dir: &Path) {
|
||||||
for skill in EMBEDDED_SKILLS {
|
for skill in EMBEDDED_SKILLS {
|
||||||
@ -22,8 +22,7 @@ pub fn install_builtin_skills(target_dir: &Path) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
|
fn install_one(skill: &EmbeddedSkill, target_dir: &Path) -> Result<(), String> {
|
||||||
let decompressed = zstd::decode_all(skill.data)
|
let decompressed = zstd::decode_all(skill.data).map_err(|e| format!("zstd decode: {}", e))?;
|
||||||
.map_err(|e| format!("zstd decode: {}", e))?;
|
|
||||||
|
|
||||||
let mut archive = tar::Archive::new(decompressed.as_slice());
|
let mut archive = tar::Archive::new(decompressed.as_slice());
|
||||||
archive
|
archive
|
||||||
|
|||||||
@ -120,7 +120,11 @@ impl SkillsLoader {
|
|||||||
let count = loaded.len();
|
let count = loaded.len();
|
||||||
let mut replaced = 0usize;
|
let mut replaced = 0usize;
|
||||||
for skill in loaded {
|
for skill in loaded {
|
||||||
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
if let Some(existing) = state
|
||||||
|
.loaded_skills
|
||||||
|
.iter_mut()
|
||||||
|
.find(|s| s.name == skill.name)
|
||||||
|
{
|
||||||
*existing = skill;
|
*existing = skill;
|
||||||
replaced += 1;
|
replaced += 1;
|
||||||
} else {
|
} else {
|
||||||
@ -138,12 +142,17 @@ impl SkillsLoader {
|
|||||||
|
|
||||||
// Load from workspace skills dir (highest priority) — replace same-name skills
|
// Load from workspace skills dir (highest priority) — replace same-name skills
|
||||||
if let Some(ref ws_dir) = self.workspace_skills_dir
|
if let Some(ref ws_dir) = self.workspace_skills_dir
|
||||||
&& ws_dir.exists() {
|
&& ws_dir.exists()
|
||||||
|
{
|
||||||
let loaded = self.load_skills_from_dir(ws_dir);
|
let loaded = self.load_skills_from_dir(ws_dir);
|
||||||
let count = loaded.len();
|
let count = loaded.len();
|
||||||
let mut replaced = 0usize;
|
let mut replaced = 0usize;
|
||||||
for skill in loaded {
|
for skill in loaded {
|
||||||
if let Some(existing) = state.loaded_skills.iter_mut().find(|s| s.name == skill.name) {
|
if let Some(existing) = state
|
||||||
|
.loaded_skills
|
||||||
|
.iter_mut()
|
||||||
|
.find(|s| s.name == skill.name)
|
||||||
|
{
|
||||||
*existing = skill;
|
*existing = skill;
|
||||||
replaced += 1;
|
replaced += 1;
|
||||||
} else {
|
} else {
|
||||||
@ -164,7 +173,11 @@ impl SkillsLoader {
|
|||||||
if state.loaded_skills.is_empty() {
|
if state.loaded_skills.is_empty() {
|
||||||
tracing::debug!("No skills found in any skills directory");
|
tracing::debug!("No skills found in any skills directory");
|
||||||
} else {
|
} else {
|
||||||
tracing::info!(count = state.loaded_skills.len(), "Loaded {} skills total", state.loaded_skills.len());
|
tracing::info!(
|
||||||
|
count = state.loaded_skills.len(),
|
||||||
|
"Loaded {} skills total",
|
||||||
|
state.loaded_skills.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +228,8 @@ impl SkillsLoader {
|
|||||||
let mut max_mtime = None;
|
let mut max_mtime = None;
|
||||||
|
|
||||||
if let Ok(metadata) = std::fs::metadata(dir)
|
if let Ok(metadata) = std::fs::metadata(dir)
|
||||||
&& let Ok(mtime) = metadata.modified() {
|
&& let Ok(mtime) = metadata.modified()
|
||||||
|
{
|
||||||
max_mtime = Some(mtime);
|
max_mtime = Some(mtime);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,7 +238,8 @@ impl SkillsLoader {
|
|||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
if let Ok(metadata) = std::fs::metadata(&path)
|
if let Ok(metadata) = std::fs::metadata(&path)
|
||||||
&& let Ok(mtime) = metadata.modified()
|
&& let Ok(mtime) = metadata.modified()
|
||||||
&& max_mtime.is_none_or(|current| mtime > current) {
|
&& max_mtime.is_none_or(|current| mtime > current)
|
||||||
|
{
|
||||||
max_mtime = Some(mtime);
|
max_mtime = Some(mtime);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -244,7 +259,12 @@ impl SkillsLoader {
|
|||||||
pub fn get_always_skills(&self) -> Vec<Skill> {
|
pub fn get_always_skills(&self) -> Vec<Skill> {
|
||||||
self.reload_if_changed();
|
self.reload_if_changed();
|
||||||
let state = self.state.lock().unwrap();
|
let state = self.state.lock().unwrap();
|
||||||
state.loaded_skills.iter().filter(|s| s.always).cloned().collect()
|
state
|
||||||
|
.loaded_skills
|
||||||
|
.iter()
|
||||||
|
.filter(|s| s.always)
|
||||||
|
.cloned()
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a specific skill by name (checks for changes first)
|
/// Get a specific skill by name (checks for changes first)
|
||||||
@ -258,7 +278,8 @@ impl SkillsLoader {
|
|||||||
pub fn list_skills(&self) -> Vec<(String, String)> {
|
pub fn list_skills(&self) -> Vec<(String, String)> {
|
||||||
self.reload_if_changed();
|
self.reload_if_changed();
|
||||||
let state = self.state.lock().unwrap();
|
let state = self.state.lock().unwrap();
|
||||||
state.loaded_skills
|
state
|
||||||
|
.loaded_skills
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| (s.name.clone(), s.description.clone()))
|
.map(|s| (s.name.clone(), s.description.clone()))
|
||||||
.collect()
|
.collect()
|
||||||
@ -279,15 +300,21 @@ impl SkillsLoader {
|
|||||||
prompt.push_str("### 目录说明\n\n");
|
prompt.push_str("### 目录说明\n\n");
|
||||||
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n");
|
prompt.push_str("- `~/.agents/skills/` — 外部共享 skill 目录(第三方、系统级 skill)\n");
|
||||||
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
|
prompt.push_str("- `~/.picobot/skills/` — 安装 skill 的默认目录\n");
|
||||||
prompt.push_str("- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n");
|
prompt.push_str(
|
||||||
prompt.push_str("安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n");
|
"- `{workspace}/skills/` — 工作目录下的 skill,picobot 自行创建的 skill 存放于此\n\n",
|
||||||
|
);
|
||||||
|
prompt.push_str(
|
||||||
|
"安装或创建 skill 时请按上述目录规范存放,创建skill时不要和已有skill同名。\n\n",
|
||||||
|
);
|
||||||
|
|
||||||
// Always skills summary
|
// Always skills summary
|
||||||
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
let always_skills: Vec<_> = state.loaded_skills.iter().filter(|s| s.always).collect();
|
||||||
if !always_skills.is_empty() {
|
if !always_skills.is_empty() {
|
||||||
prompt.push_str("### 常用技能\n\n");
|
prompt.push_str("### 常用技能\n\n");
|
||||||
for skill in &always_skills {
|
for skill in &always_skills {
|
||||||
let path_str = skill.path.as_ref()
|
let path_str = skill
|
||||||
|
.path
|
||||||
|
.as_ref()
|
||||||
.map(|p| p.to_string_lossy().to_string())
|
.map(|p| p.to_string_lossy().to_string())
|
||||||
.unwrap_or_else(|| "—".to_string());
|
.unwrap_or_else(|| "—".to_string());
|
||||||
prompt.push_str(&format!(
|
prompt.push_str(&format!(
|
||||||
@ -300,8 +327,12 @@ impl SkillsLoader {
|
|||||||
|
|
||||||
// Usage instructions
|
// Usage instructions
|
||||||
prompt.push_str("### 使用方法\n\n");
|
prompt.push_str("### 使用方法\n\n");
|
||||||
prompt.push_str("- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n");
|
prompt.push_str(
|
||||||
prompt.push_str("- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n");
|
"- 使用 `get_skill` 工具 action=\"list\" 列出所有可用 skill 及其名称、简介、路径\n",
|
||||||
|
);
|
||||||
|
prompt.push_str(
|
||||||
|
"- 使用 `get_skill` 工具 action=\"get\" 并提供 `skill_name` 获取指定 skill 完整内容\n",
|
||||||
|
);
|
||||||
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
|
prompt.push_str("- 当用户询问已安装的 skill 相关信息时,需重新调用 get_skill 工具查询最新内容,避免 skill 已变更导致信息过时\n");
|
||||||
|
|
||||||
// Always skills full content
|
// Always skills full content
|
||||||
@ -338,8 +369,7 @@ impl SkillsLoader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
match std::fs::read_to_string(&skill_file) {
|
match std::fs::read_to_string(&skill_file) {
|
||||||
Ok(content) => {
|
Ok(content) => match self.parse_skill(&path, &content) {
|
||||||
match self.parse_skill(&path, &content) {
|
|
||||||
Some(skill) => {
|
Some(skill) => {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
skill = %skill.name,
|
skill = %skill.name,
|
||||||
@ -355,8 +385,7 @@ impl SkillsLoader {
|
|||||||
"Failed to parse skill"
|
"Failed to parse skill"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
path = %skill_file.display(),
|
path = %skill_file.display(),
|
||||||
@ -447,7 +476,6 @@ impl Default for SkillsLoader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Extract first non-empty, non-heading line as description
|
/// Extract first non-empty, non-heading line as description
|
||||||
fn extract_description(content: &str) -> String {
|
fn extract_description(content: &str) -> String {
|
||||||
content
|
content
|
||||||
|
|||||||
@ -241,9 +241,8 @@ impl super::Storage {
|
|||||||
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
let cutoff = chrono::Utc::now() - chrono::Duration::days(retention_days as i64);
|
||||||
let cutoff_str = cutoff.to_rfc3339();
|
let cutoff_str = cutoff.to_rfc3339();
|
||||||
|
|
||||||
let result = sqlx::query(
|
let result =
|
||||||
"DELETE FROM memories WHERE category = 'timeline' AND created_at < ?",
|
sqlx::query("DELETE FROM memories WHERE category = 'timeline' AND created_at < ?")
|
||||||
)
|
|
||||||
.bind(&cutoff_str)
|
.bind(&cutoff_str)
|
||||||
.execute(self.pool())
|
.execute(self.pool())
|
||||||
.await?;
|
.await?;
|
||||||
@ -276,9 +275,7 @@ impl super::Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_memory_rows(
|
fn parse_memory_rows(rows: &[sqlx::sqlite::SqliteRow]) -> Result<Vec<MemoryEntry>, StorageError> {
|
||||||
rows: &[sqlx::sqlite::SqliteRow],
|
|
||||||
) -> Result<Vec<MemoryEntry>, StorageError> {
|
|
||||||
rows.iter()
|
rows.iter()
|
||||||
.map(|row| {
|
.map(|row| {
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
|
|||||||
@ -165,7 +165,11 @@ impl crate::storage::Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Update next_run_at and last_run_at for a job.
|
/// Update next_run_at and last_run_at for a job.
|
||||||
pub async fn set_scheduled_job_next_run(&self, id: &str, next_run_at: i64) -> anyhow::Result<()> {
|
pub async fn set_scheduled_job_next_run(
|
||||||
|
&self,
|
||||||
|
id: &str,
|
||||||
|
next_run_at: i64,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
let now = now_ms();
|
let now = now_ms();
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
"UPDATE scheduled_jobs SET next_run_at = ?, last_run_at = ?, updated_at = ? WHERE id = ?",
|
||||||
@ -331,7 +335,9 @@ mod tests {
|
|||||||
async fn setup_storage() -> Storage {
|
async fn setup_storage() -> Storage {
|
||||||
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||||
let storage = Storage { pool };
|
let storage = Storage { pool };
|
||||||
Storage::init_scheduler_schema(storage.pool()).await.unwrap();
|
Storage::init_scheduler_schema(storage.pool())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
storage
|
storage
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -450,7 +456,10 @@ mod tests {
|
|||||||
updated_at: t,
|
updated_at: t,
|
||||||
};
|
};
|
||||||
storage.add_scheduled_job(&job).await.unwrap();
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
storage.set_scheduled_job_enabled("job-toggle", false).await.unwrap();
|
storage
|
||||||
|
.set_scheduled_job_enabled("job-toggle", false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
let got = storage.get_scheduled_job("job-toggle").await.unwrap();
|
||||||
assert!(!got.enabled);
|
assert!(!got.enabled);
|
||||||
}
|
}
|
||||||
@ -461,31 +470,55 @@ mod tests {
|
|||||||
let t = now();
|
let t = now();
|
||||||
let jobs = vec![
|
let jobs = vec![
|
||||||
ScheduledJob {
|
ScheduledJob {
|
||||||
id: "due".into(), name: "due".into(),
|
id: "due".into(),
|
||||||
schedule: Schedule::At { at: t }, prompt: "1".into(),
|
name: "due".into(),
|
||||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
schedule: Schedule::At { at: t },
|
||||||
model: None, enabled: true, delete_after_run: false,
|
prompt: "1".into(),
|
||||||
next_run_at: t - 1000, last_run_at: None,
|
channel: "cli_chat".into(),
|
||||||
last_status: None, last_error: None,
|
chat_id: "c".into(),
|
||||||
created_at: t, updated_at: t,
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t - 1000,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
},
|
},
|
||||||
ScheduledJob {
|
ScheduledJob {
|
||||||
id: "future".into(), name: "future".into(),
|
id: "future".into(),
|
||||||
schedule: Schedule::At { at: t + 99999999 }, prompt: "2".into(),
|
name: "future".into(),
|
||||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
schedule: Schedule::At { at: t + 99999999 },
|
||||||
model: None, enabled: true, delete_after_run: false,
|
prompt: "2".into(),
|
||||||
next_run_at: t + 99999999, last_run_at: None,
|
channel: "cli_chat".into(),
|
||||||
last_status: None, last_error: None,
|
chat_id: "c".into(),
|
||||||
created_at: t, updated_at: t,
|
model: None,
|
||||||
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t + 99999999,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
},
|
},
|
||||||
ScheduledJob {
|
ScheduledJob {
|
||||||
id: "disabled-due".into(), name: "disabled due".into(),
|
id: "disabled-due".into(),
|
||||||
schedule: Schedule::At { at: t }, prompt: "3".into(),
|
name: "disabled due".into(),
|
||||||
channel: "cli_chat".into(), chat_id: "c".into(),
|
schedule: Schedule::At { at: t },
|
||||||
model: None, enabled: false, delete_after_run: false,
|
prompt: "3".into(),
|
||||||
next_run_at: t - 1000, last_run_at: None,
|
channel: "cli_chat".into(),
|
||||||
last_status: None, last_error: None,
|
chat_id: "c".into(),
|
||||||
created_at: t, updated_at: t,
|
model: None,
|
||||||
|
enabled: false,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t - 1000,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
for j in &jobs {
|
for j in &jobs {
|
||||||
@ -501,24 +534,39 @@ mod tests {
|
|||||||
let storage = setup_storage().await;
|
let storage = setup_storage().await;
|
||||||
let t = now();
|
let t = now();
|
||||||
let job = ScheduledJob {
|
let job = ScheduledJob {
|
||||||
id: "job-run".into(), name: "run test".into(),
|
id: "job-run".into(),
|
||||||
|
name: "run test".into(),
|
||||||
schedule: Schedule::Every { every_ms: 1000 },
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
prompt: "hi".into(), channel: "cli_chat".into(), chat_id: "c".into(),
|
prompt: "hi".into(),
|
||||||
model: None, enabled: true, delete_after_run: false,
|
channel: "cli_chat".into(),
|
||||||
next_run_at: t, last_run_at: None,
|
chat_id: "c".into(),
|
||||||
last_status: None, last_error: None,
|
model: None,
|
||||||
created_at: t, updated_at: t,
|
enabled: true,
|
||||||
|
delete_after_run: false,
|
||||||
|
next_run_at: t,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
};
|
};
|
||||||
storage.add_scheduled_job(&job).await.unwrap();
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
|
|
||||||
let run = super::JobRun {
|
let run = super::JobRun {
|
||||||
id: 0, job_id: "job-run".into(),
|
id: 0,
|
||||||
started_at: t, finished_at: t + 500,
|
job_id: "job-run".into(),
|
||||||
status: "ok".into(), output: Some("hello".into()),
|
started_at: t,
|
||||||
error: None, duration_ms: 500,
|
finished_at: t + 500,
|
||||||
|
status: "ok".into(),
|
||||||
|
output: Some("hello".into()),
|
||||||
|
error: None,
|
||||||
|
duration_ms: 500,
|
||||||
};
|
};
|
||||||
storage.record_scheduled_job_run(&run).await.unwrap();
|
storage.record_scheduled_job_run(&run).await.unwrap();
|
||||||
let runs = storage.list_scheduled_job_runs("job-run", 10).await.unwrap();
|
let runs = storage
|
||||||
|
.list_scheduled_job_runs("job-run", 10)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert_eq!(runs.len(), 1);
|
assert_eq!(runs.len(), 1);
|
||||||
assert_eq!(runs[0].status, "ok");
|
assert_eq!(runs[0].status, "ok");
|
||||||
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
assert_eq!(runs[0].output.as_deref(), Some("hello"));
|
||||||
@ -529,22 +577,34 @@ mod tests {
|
|||||||
let storage = setup_storage().await;
|
let storage = setup_storage().await;
|
||||||
let t = now();
|
let t = now();
|
||||||
let job = ScheduledJob {
|
let job = ScheduledJob {
|
||||||
id: "job-update".into(), name: "old name".into(),
|
id: "job-update".into(),
|
||||||
|
name: "old name".into(),
|
||||||
schedule: Schedule::Every { every_ms: 1000 },
|
schedule: Schedule::Every { every_ms: 1000 },
|
||||||
prompt: "old prompt".into(), channel: "feishu".into(),
|
prompt: "old prompt".into(),
|
||||||
chat_id: "oc_1".into(), model: None,
|
channel: "feishu".into(),
|
||||||
enabled: true, delete_after_run: false,
|
chat_id: "oc_1".into(),
|
||||||
next_run_at: t, last_run_at: None,
|
model: None,
|
||||||
last_status: None, last_error: None,
|
enabled: true,
|
||||||
created_at: t, updated_at: t,
|
delete_after_run: false,
|
||||||
|
next_run_at: t,
|
||||||
|
last_run_at: None,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
created_at: t,
|
||||||
|
updated_at: t,
|
||||||
};
|
};
|
||||||
storage.add_scheduled_job(&job).await.unwrap();
|
storage.add_scheduled_job(&job).await.unwrap();
|
||||||
storage.update_scheduled_job(
|
storage
|
||||||
|
.update_scheduled_job(
|
||||||
"job-update",
|
"job-update",
|
||||||
Some("new prompt".into()),
|
Some("new prompt".into()),
|
||||||
Some(Schedule::Every { every_ms: 60000 }),
|
Some(Schedule::Every { every_ms: 60000 }),
|
||||||
None, None, None,
|
None,
|
||||||
).await.unwrap();
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
let got = storage.get_scheduled_job("job-update").await.unwrap();
|
||||||
assert_eq!(got.prompt, "new prompt");
|
assert_eq!(got.prompt, "new prompt");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -167,10 +167,7 @@ impl Tool for BashTool {
|
|||||||
Err(_) => Ok(ToolResult {
|
Err(_) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(format!(
|
error: Some(format!("Command timed out after {} seconds", timeout_secs)),
|
||||||
"Command timed out after {} seconds",
|
|
||||||
timeout_secs
|
|
||||||
)),
|
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -249,10 +246,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pwd_command() {
|
async fn test_pwd_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
||||||
.execute(json!({ "command": "pwd" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
@ -260,7 +254,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_ls_command() {
|
async fn test_ls_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({ "command": "ls -la /tmp" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,7 +5,7 @@ use std::time::Duration;
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use fantoccini::actions::{InputSource, MouseActions, PointerAction, MOUSE_BUTTON_LEFT};
|
use fantoccini::actions::{InputSource, MOUSE_BUTTON_LEFT, MouseActions, PointerAction};
|
||||||
use fantoccini::key::Key;
|
use fantoccini::key::Key;
|
||||||
use fantoccini::{Client, ClientBuilder, Locator};
|
use fantoccini::{Client, ClientBuilder, Locator};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@ -63,7 +63,9 @@ impl BrowserTool {
|
|||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum BrowserAction {
|
pub enum BrowserAction {
|
||||||
Open { url: String },
|
Open {
|
||||||
|
url: String,
|
||||||
|
},
|
||||||
Snapshot {
|
Snapshot {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
interactive_only: bool,
|
interactive_only: bool,
|
||||||
@ -72,10 +74,20 @@ pub enum BrowserAction {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
depth: Option<i64>,
|
depth: Option<i64>,
|
||||||
},
|
},
|
||||||
Click { selector: String },
|
Click {
|
||||||
Fill { selector: String, value: String },
|
selector: String,
|
||||||
Type { selector: Option<String>, text: String },
|
},
|
||||||
GetText { selector: String },
|
Fill {
|
||||||
|
selector: String,
|
||||||
|
value: String,
|
||||||
|
},
|
||||||
|
Type {
|
||||||
|
selector: Option<String>,
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
GetText {
|
||||||
|
selector: String,
|
||||||
|
},
|
||||||
GetTitle,
|
GetTitle,
|
||||||
GetUrl,
|
GetUrl,
|
||||||
Screenshot {
|
Screenshot {
|
||||||
@ -84,7 +96,9 @@ pub enum BrowserAction {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
return_base64: bool,
|
return_base64: bool,
|
||||||
},
|
},
|
||||||
Focus { selector: String },
|
Focus {
|
||||||
|
selector: String,
|
||||||
|
},
|
||||||
Wait {
|
Wait {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
selector: Option<String>,
|
selector: Option<String>,
|
||||||
@ -93,9 +107,16 @@ pub enum BrowserAction {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
},
|
},
|
||||||
Press { key: String },
|
Press {
|
||||||
Hover { selector: String },
|
key: String,
|
||||||
ClickAt { x: u32, y: u32 },
|
},
|
||||||
|
Hover {
|
||||||
|
selector: String,
|
||||||
|
},
|
||||||
|
ClickAt {
|
||||||
|
x: u32,
|
||||||
|
y: u32,
|
||||||
|
},
|
||||||
Scroll {
|
Scroll {
|
||||||
direction: String,
|
direction: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@ -120,13 +141,8 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
|||||||
.get("interactive_only")
|
.get("interactive_only")
|
||||||
.and_then(Value::as_bool)
|
.and_then(Value::as_bool)
|
||||||
.unwrap_or(true),
|
.unwrap_or(true),
|
||||||
compact: args
|
compact: args.get("compact").and_then(Value::as_bool).unwrap_or(true),
|
||||||
.get("compact")
|
depth: args.get("depth").and_then(|v| v.as_i64()),
|
||||||
.and_then(Value::as_bool)
|
|
||||||
.unwrap_or(true),
|
|
||||||
depth: args
|
|
||||||
.get("depth")
|
|
||||||
.and_then(|v| v.as_i64()),
|
|
||||||
}),
|
}),
|
||||||
"click" => {
|
"click" => {
|
||||||
let selector = args
|
let selector = args
|
||||||
@ -198,10 +214,7 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
|||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.map(String::from),
|
.map(String::from),
|
||||||
ms: args.get("ms").and_then(|v| v.as_u64()),
|
ms: args.get("ms").and_then(|v| v.as_u64()),
|
||||||
text: args
|
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||||
.get("text")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from),
|
|
||||||
}),
|
}),
|
||||||
"press" => {
|
"press" => {
|
||||||
let key = args
|
let key = args
|
||||||
@ -239,11 +252,13 @@ fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<Browse
|
|||||||
let x = args
|
let x = args
|
||||||
.get("x")
|
.get("x")
|
||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))? as u32;
|
.ok_or_else(|| anyhow::anyhow!("Missing 'x' for click_at"))?
|
||||||
|
as u32;
|
||||||
let y = args
|
let y = args
|
||||||
.get("y")
|
.get("y")
|
||||||
.and_then(|v| v.as_u64())
|
.and_then(|v| v.as_u64())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))? as u32;
|
.ok_or_else(|| anyhow::anyhow!("Missing 'y' for click_at"))?
|
||||||
|
as u32;
|
||||||
Ok(BrowserAction::ClickAt { x, y })
|
Ok(BrowserAction::ClickAt { x, y })
|
||||||
}
|
}
|
||||||
other => anyhow::bail!("Unsupported browser action: {}", other),
|
other => anyhow::bail!("Unsupported browser action: {}", other),
|
||||||
@ -488,7 +503,11 @@ impl BrowserState {
|
|||||||
}
|
}
|
||||||
Err(e) => return Err(e.into()),
|
Err(e) => return Err(e.into()),
|
||||||
}
|
}
|
||||||
tracing::debug!(action = "fill", output_len = value.len(), "Browser action completed");
|
tracing::debug!(
|
||||||
|
action = "fill",
|
||||||
|
output_len = value.len(),
|
||||||
|
"Browser action completed"
|
||||||
|
);
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Filled {} with {}", selector, value),
|
output: format!("Filled {} with {}", selector, value),
|
||||||
@ -573,7 +592,10 @@ impl BrowserState {
|
|||||||
error: None,
|
error: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
BrowserAction::Screenshot { path, return_base64 } => {
|
BrowserAction::Screenshot {
|
||||||
|
path,
|
||||||
|
return_base64,
|
||||||
|
} => {
|
||||||
let client = self.active_client()?;
|
let client = self.active_client()?;
|
||||||
let png = client.screenshot().await?;
|
let png = client.screenshot().await?;
|
||||||
let save_path = path.unwrap_or_else(|| {
|
let save_path = path.unwrap_or_else(|| {
|
||||||
@ -588,14 +610,25 @@ impl BrowserState {
|
|||||||
tokio::fs::write(&save_path, &png).await?;
|
tokio::fs::write(&save_path, &png).await?;
|
||||||
if return_base64 {
|
if return_base64 {
|
||||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
|
let b64 = base64::engine::general_purpose::STANDARD.encode(&png);
|
||||||
tracing::debug!(action = "screenshot", output_len = b64.len(), "Browser action completed");
|
tracing::debug!(
|
||||||
|
action = "screenshot",
|
||||||
|
output_len = b64.len(),
|
||||||
|
"Browser action completed"
|
||||||
|
);
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Screenshot saved to {}. Base64: data:image/png;base64,{}", save_path, b64),
|
output: format!(
|
||||||
|
"Screenshot saved to {}. Base64: data:image/png;base64,{}",
|
||||||
|
save_path, b64
|
||||||
|
),
|
||||||
error: None,
|
error: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
tracing::debug!(action = "screenshot", output_len = save_path.len(), "Browser action completed");
|
tracing::debug!(
|
||||||
|
action = "screenshot",
|
||||||
|
output_len = save_path.len(),
|
||||||
|
"Browser action completed"
|
||||||
|
);
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Screenshot saved to {}", save_path),
|
output: format!("Screenshot saved to {}", save_path),
|
||||||
@ -611,18 +644,18 @@ impl BrowserState {
|
|||||||
vec![serde_json::to_value(el)?],
|
vec![serde_json::to_value(el)?],
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
tracing::debug!(action = "focus", output_len = selector.len(), "Browser action completed");
|
tracing::debug!(
|
||||||
|
action = "focus",
|
||||||
|
output_len = selector.len(),
|
||||||
|
"Browser action completed"
|
||||||
|
);
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Focused {}", selector),
|
output: format!("Focused {}", selector),
|
||||||
error: None,
|
error: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
BrowserAction::Wait {
|
BrowserAction::Wait { selector, ms, text } => {
|
||||||
selector,
|
|
||||||
ms,
|
|
||||||
text,
|
|
||||||
} => {
|
|
||||||
if let Some(sel) = selector {
|
if let Some(sel) = selector {
|
||||||
let client = self.active_client()?;
|
let client = self.active_client()?;
|
||||||
wait_for_selector(client, &sel).await?;
|
wait_for_selector(client, &sel).await?;
|
||||||
@ -719,9 +752,21 @@ impl BrowserState {
|
|||||||
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
|
let id = info.get("id").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
let el_type = info.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
|
let text = info.get("text").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
let id_str = if id.is_empty() { String::new() } else { format!("#{id}") };
|
let id_str = if id.is_empty() {
|
||||||
let type_str = if el_type.is_empty() { String::new() } else { format!("[type={el_type}]") };
|
String::new()
|
||||||
let text_str = if text.is_empty() { String::new() } else { format!(" ({text})") };
|
} else {
|
||||||
|
format!("#{id}")
|
||||||
|
};
|
||||||
|
let type_str = if el_type.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
format!("[type={el_type}]")
|
||||||
|
};
|
||||||
|
let text_str = if text.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
format!(" ({text})")
|
||||||
|
};
|
||||||
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
|
format!("Clicked at ({x},{y}) on <{tag}{id_str}{type_str}>{text_str}")
|
||||||
}
|
}
|
||||||
None => format!("Clicked at ({}, {})", x, y),
|
None => format!("Clicked at ({}, {})", x, y),
|
||||||
@ -1090,10 +1135,7 @@ fn css_attr_escape(input: &str) -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn xpath_contains_text(text: &str) -> String {
|
fn xpath_contains_text(text: &str) -> String {
|
||||||
format!(
|
format!("//*[contains(normalize-space(.), {})]", xpath_literal(text))
|
||||||
"//*[contains(normalize-space(.), {})]",
|
|
||||||
xpath_literal(text)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn xpath_literal(input: &str) -> String {
|
fn xpath_literal(input: &str) -> String {
|
||||||
@ -1140,7 +1182,10 @@ fn webdriver_key(key: &str) -> String {
|
|||||||
"pagedown" => Key::PageDown.to_string(),
|
"pagedown" => Key::PageDown.to_string(),
|
||||||
"space" => " ".to_string(),
|
"space" => " ".to_string(),
|
||||||
other => {
|
other => {
|
||||||
tracing::warn!("Unrecognized key '{}', this will have no effect (press only supports single named keys)", other);
|
tracing::warn!(
|
||||||
|
"Unrecognized key '{}', this will have no effect (press only supports single named keys)",
|
||||||
|
other
|
||||||
|
);
|
||||||
other.to_string()
|
other.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -659,10 +659,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_evaluate_missing_expression() {
|
async fn test_evaluate_missing_expression() {
|
||||||
let tool = CalculatorTool::new();
|
let tool = CalculatorTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
|
||||||
.execute(json!({"function": "evaluate"}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -31,10 +31,7 @@ impl ContentSearchTool {
|
|||||||
for (i, line) in lines.iter().enumerate() {
|
for (i, line) in lines.iter().enumerate() {
|
||||||
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
if output.len() + line.len() + 1 > MAX_OUTPUT_CHARS {
|
||||||
let omitted = lines.len() - i;
|
let omitted = lines.len() - i;
|
||||||
output.push_str(&format!(
|
output.push_str(&format!("\n... ({} matches omitted) ...", omitted));
|
||||||
"\n... ({} matches omitted) ...",
|
|
||||||
omitted
|
|
||||||
));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if !output.is_empty() {
|
if !output.is_empty() {
|
||||||
@ -113,18 +110,40 @@ impl Tool for ContentSearchTool {
|
|||||||
|
|
||||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||||
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
|
let file_pattern = args.get("file_pattern").and_then(|v| v.as_str());
|
||||||
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(false);
|
let case_sensitive = args
|
||||||
let context_lines = args.get("context_lines").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
|
.get("case_sensitive")
|
||||||
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let context_lines = args
|
||||||
|
.get("context_lines")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(0) as usize;
|
||||||
|
let max_results = args
|
||||||
|
.get("max_results")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(MAX_RESULTS as u64) as usize;
|
||||||
|
|
||||||
let result = self.run_search(pattern, &dir, file_pattern, case_sensitive, context_lines, max_results).await;
|
let result = self
|
||||||
|
.run_search(
|
||||||
|
pattern,
|
||||||
|
&dir,
|
||||||
|
file_pattern,
|
||||||
|
case_sensitive,
|
||||||
|
context_lines,
|
||||||
|
max_results,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(lines) => {
|
Ok(lines) => {
|
||||||
let count = lines.len();
|
let count = lines.len();
|
||||||
let mut output = self.truncate_output(&lines);
|
let mut output = self.truncate_output(&lines);
|
||||||
output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
|
output.push_str(&format!("\n\n---\n共 {} 条匹配", count));
|
||||||
Ok(ToolResult { success: true, output, error: None })
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output,
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Err(e) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
@ -146,22 +165,52 @@ impl ContentSearchTool {
|
|||||||
max_results: usize,
|
max_results: usize,
|
||||||
) -> anyhow::Result<Vec<String>> {
|
) -> anyhow::Result<Vec<String>> {
|
||||||
if which::which("rg").is_ok() {
|
if which::which("rg").is_ok() {
|
||||||
match self.search_with_rg(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
match self
|
||||||
|
.search_with_rg(
|
||||||
|
pattern,
|
||||||
|
dir,
|
||||||
|
file_pattern,
|
||||||
|
case_sensitive,
|
||||||
|
context_lines,
|
||||||
|
max_results,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(lines) => return Ok(lines),
|
Ok(lines) => return Ok(lines),
|
||||||
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
|
Err(e) => tracing::warn!("rg failed: {}, falling back", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if which::which("grep").is_ok() {
|
if which::which("grep").is_ok() {
|
||||||
match self.search_with_grep(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await {
|
match self
|
||||||
|
.search_with_grep(
|
||||||
|
pattern,
|
||||||
|
dir,
|
||||||
|
file_pattern,
|
||||||
|
case_sensitive,
|
||||||
|
context_lines,
|
||||||
|
max_results,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||||
Ok(_) => {},
|
Ok(_) => {}
|
||||||
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
|
Err(e) => tracing::warn!("grep failed: {}, falling back", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::warn!("No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance.");
|
tracing::warn!(
|
||||||
self.search_with_rust(pattern, dir, file_pattern, case_sensitive, context_lines, max_results).await
|
"No rg/grep available, using built-in content search (much slower). Install ripgrep for better performance."
|
||||||
|
);
|
||||||
|
self.search_with_rust(
|
||||||
|
pattern,
|
||||||
|
dir,
|
||||||
|
file_pattern,
|
||||||
|
case_sensitive,
|
||||||
|
context_lines,
|
||||||
|
max_results,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search_with_rg(
|
async fn search_with_rg(
|
||||||
@ -176,8 +225,10 @@ impl ContentSearchTool {
|
|||||||
let mut cmd = Command::new("rg");
|
let mut cmd = Command::new("rg");
|
||||||
cmd.arg("-n")
|
cmd.arg("-n")
|
||||||
.arg("--no-heading")
|
.arg("--no-heading")
|
||||||
.arg("--color").arg("never")
|
.arg("--color")
|
||||||
.arg("--max-count").arg(max_results.to_string())
|
.arg("never")
|
||||||
|
.arg("--max-count")
|
||||||
|
.arg(max_results.to_string())
|
||||||
.arg(pattern)
|
.arg(pattern)
|
||||||
.arg(dir)
|
.arg(dir)
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
@ -193,10 +244,7 @@ impl ContentSearchTool {
|
|||||||
cmd.arg("--glob").arg(fp);
|
cmd.arg("--glob").arg(fp);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = timeout(
|
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
|
||||||
cmd.output(),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("rg timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
@ -206,7 +254,8 @@ impl ContentSearchTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text.lines()
|
let lines: Vec<String> = text
|
||||||
|
.lines()
|
||||||
.take(max_results)
|
.take(max_results)
|
||||||
.map(|l| l.to_string())
|
.map(|l| l.to_string())
|
||||||
.collect();
|
.collect();
|
||||||
@ -242,15 +291,13 @@ impl ContentSearchTool {
|
|||||||
cmd.arg("--include").arg(fp);
|
cmd.arg("--include").arg(fp);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = timeout(
|
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
|
||||||
cmd.output(),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("grep timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text.lines()
|
let lines: Vec<String> = text
|
||||||
|
.lines()
|
||||||
.take(max_results)
|
.take(max_results)
|
||||||
.map(|l| l.to_string())
|
.map(|l| l.to_string())
|
||||||
.collect();
|
.collect();
|
||||||
@ -280,7 +327,9 @@ impl ContentSearchTool {
|
|||||||
if case_sensitive {
|
if case_sensitive {
|
||||||
regex::Regex::new(&re_str)
|
regex::Regex::new(&re_str)
|
||||||
} else {
|
} else {
|
||||||
regex::RegexBuilder::new(&re_str).case_insensitive(true).build()
|
regex::RegexBuilder::new(&re_str)
|
||||||
|
.case_insensitive(true)
|
||||||
|
.build()
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -291,7 +340,14 @@ impl ContentSearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
grep_dir(Path::new(dir), Path::new(dir), &re, file_re.as_ref(), &mut results, max_results)?;
|
grep_dir(
|
||||||
|
Path::new(dir),
|
||||||
|
Path::new(dir),
|
||||||
|
&re,
|
||||||
|
file_re.as_ref(),
|
||||||
|
&mut results,
|
||||||
|
max_results,
|
||||||
|
)?;
|
||||||
|
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
@ -350,14 +406,17 @@ fn grep_dir(
|
|||||||
|
|
||||||
if path.is_dir() {
|
if path.is_dir() {
|
||||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& name.starts_with('.') && name.len() > 1 {
|
&& name.starts_with('.')
|
||||||
|
&& name.len() > 1
|
||||||
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
grep_dir(base, &path, re, file_re, results, max)?;
|
grep_dir(base, &path, re, file_re, results, max)?;
|
||||||
} else if path.is_file() {
|
} else if path.is_file() {
|
||||||
if let Some(file_re) = file_re
|
if let Some(file_re) = file_re
|
||||||
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
&& let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& !file_re.is_match(name) {
|
&& !file_re.is_match(name)
|
||||||
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -391,8 +450,16 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_content_search_rust_fallback() {
|
async fn test_content_search_rust_fallback() {
|
||||||
let dir = TempDir::new().unwrap();
|
let dir = TempDir::new().unwrap();
|
||||||
fs::write(dir.path().join("main.rs"), "fn main() {\n let x = 42;\n println!(\"hello\");\n}").unwrap();
|
fs::write(
|
||||||
fs::write(dir.path().join("lib.rs"), "pub fn foo() -> u32 {\n let y = 42;\n y\n}").unwrap();
|
dir.path().join("main.rs"),
|
||||||
|
"fn main() {\n let x = 42;\n println!(\"hello\");\n}",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
fs::write(
|
||||||
|
dir.path().join("lib.rs"),
|
||||||
|
"pub fn foo() -> u32 {\n let y = 42;\n y\n}",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
|
fs::write(dir.path().join("README.md"), "# Project\nHello world").unwrap();
|
||||||
|
|
||||||
let tool = ContentSearchTool::new();
|
let tool = ContentSearchTool::new();
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::scheduler::{next_run_for_schedule, Schedule};
|
use crate::scheduler::{Schedule, next_run_for_schedule};
|
||||||
use crate::storage::{ScheduledJob, Storage};
|
use crate::storage::{ScheduledJob, Storage};
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
@ -229,10 +229,7 @@ impl Tool for CronListTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
async fn execute(&self, args: Value) -> anyhow::Result<ToolResult> {
|
||||||
let filter = args
|
let filter = args.get("status").and_then(|v| v.as_str()).unwrap_or("all");
|
||||||
.get("status")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.unwrap_or("all");
|
|
||||||
let jobs = self.storage.list_scheduled_jobs().await?;
|
let jobs = self.storage.list_scheduled_jobs().await?;
|
||||||
|
|
||||||
let filtered: Vec<&ScheduledJob> = match filter {
|
let filtered: Vec<&ScheduledJob> = match filter {
|
||||||
@ -397,7 +394,9 @@ impl Tool for CronEnableTool {
|
|||||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
|
|
||||||
let next = next_run_for_schedule(&job.schedule, now_ms());
|
let next = next_run_for_schedule(&job.schedule, now_ms());
|
||||||
self.storage.set_scheduled_job_enabled(&job_id, true).await?;
|
self.storage
|
||||||
|
.set_scheduled_job_enabled(&job_id, true)
|
||||||
|
.await?;
|
||||||
if let Some(n) = next {
|
if let Some(n) = next {
|
||||||
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
self.storage.set_scheduled_job_next_run(&job_id, n).await?;
|
||||||
}
|
}
|
||||||
@ -464,7 +463,9 @@ impl Tool for CronDisableTool {
|
|||||||
.get_scheduled_job(&job_id)
|
.get_scheduled_job(&job_id)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
.map_err(|e| anyhow::anyhow!("Job {} not found: {}", job_id, e))?;
|
||||||
self.storage.set_scheduled_job_enabled(&job_id, false).await?;
|
self.storage
|
||||||
|
.set_scheduled_job_enabled(&job_id, false)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
@ -580,7 +581,9 @@ impl Tool for CronUpdateTool {
|
|||||||
if args.get("schedule").is_some() {
|
if args.get("schedule").is_some() {
|
||||||
let job = self.storage.get_scheduled_job(&job_id).await?;
|
let job = self.storage.get_scheduled_job(&job_id).await?;
|
||||||
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
if let Some(next) = next_run_for_schedule(&job.schedule, now_ms()) {
|
||||||
self.storage.set_scheduled_job_next_run(&job_id, next).await?;
|
self.storage
|
||||||
|
.set_scheduled_job_next_run(&job_id, next)
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -765,9 +768,7 @@ mod tests {
|
|||||||
let job = ScheduledJob {
|
let job = ScheduledJob {
|
||||||
id: "job-update-tool".into(),
|
id: "job-update-tool".into(),
|
||||||
name: "old".into(),
|
name: "old".into(),
|
||||||
schedule: Schedule::Every {
|
schedule: Schedule::Every { every_ms: 3600000 },
|
||||||
every_ms: 3600000,
|
|
||||||
},
|
|
||||||
prompt: "old prompt".into(),
|
prompt: "old prompt".into(),
|
||||||
channel: "feishu".into(),
|
channel: "feishu".into(),
|
||||||
chat_id: "oc_1".into(),
|
chat_id: "oc_1".into(),
|
||||||
|
|||||||
@ -102,7 +102,10 @@ impl Tool for DelegateTool {
|
|||||||
_ => Ok(ToolResult {
|
_ => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(format!("Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks", action)),
|
error: Some(format!(
|
||||||
|
"Unknown action: {}. Supported: run, check_task, cancel_task, list_tasks",
|
||||||
|
action
|
||||||
|
)),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,9 +118,11 @@ impl DelegateTool {
|
|||||||
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
|
.ok_or_else(|| anyhow::anyhow!("missing required parameter: prompt"))?
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
let allowed_tools: Option<Vec<String>> = args["allowed_tools"]
|
let allowed_tools: Option<Vec<String>> = args["allowed_tools"].as_array().map(|arr| {
|
||||||
.as_array()
|
arr.iter()
|
||||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
|
||||||
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
|
let max_iterations = args["max_iterations"].as_u64().map(|v| v as usize);
|
||||||
let timeout_secs = args["timeout_secs"].as_u64();
|
let timeout_secs = args["timeout_secs"].as_u64();
|
||||||
@ -141,15 +146,21 @@ impl DelegateTool {
|
|||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(format!("unknown mode: {}. Supported: inline, background, parallel", mode_str)),
|
error: Some(format!(
|
||||||
})
|
"unknown mode: {}. Supported: inline, background, parallel",
|
||||||
|
mode_str
|
||||||
|
)),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match mode {
|
match mode {
|
||||||
ExecutionMode::Inline => {
|
ExecutionMode::Inline => {
|
||||||
let config = self.parse_config_from_args(args)?;
|
let config = self.parse_config_from_args(args)?;
|
||||||
let result = self.sub_agent_manager.run_inline(config).await
|
let result = self
|
||||||
|
.sub_agent_manager
|
||||||
|
.run_inline(config)
|
||||||
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||||
|
|
||||||
match result.status {
|
match result.status {
|
||||||
@ -177,10 +188,14 @@ impl DelegateTool {
|
|||||||
}
|
}
|
||||||
ExecutionMode::Background => {
|
ExecutionMode::Background => {
|
||||||
let config = self.parse_config_from_args(args)?;
|
let config = self.parse_config_from_args(args)?;
|
||||||
let ctx = crate::agent::sub_agent::get_delegate_context()
|
let ctx = crate::agent::sub_agent::get_delegate_context().map_err(|_| {
|
||||||
.map_err(|_| anyhow::anyhow!("delegate context not available: not in an agent worker"))?;
|
anyhow::anyhow!("delegate context not available: not in an agent worker")
|
||||||
|
})?;
|
||||||
|
|
||||||
let task_id = self.sub_agent_manager.run_background(config, ctx).await
|
let task_id = self
|
||||||
|
.sub_agent_manager
|
||||||
|
.run_background(config, ctx)
|
||||||
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
@ -200,9 +215,12 @@ impl DelegateTool {
|
|||||||
.as_str()
|
.as_str()
|
||||||
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
|
.ok_or_else(|| anyhow::anyhow!("each parallel task requires 'prompt'"))?
|
||||||
.to_string();
|
.to_string();
|
||||||
let allowed_tools: Option<Vec<String>> = task["allowed_tools"]
|
let allowed_tools: Option<Vec<String>> =
|
||||||
.as_array()
|
task["allowed_tools"].as_array().map(|arr| {
|
||||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
arr.iter()
|
||||||
|
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
|
||||||
configs.push(SubAgentConfig {
|
configs.push(SubAgentConfig {
|
||||||
prompt,
|
prompt,
|
||||||
@ -216,13 +234,18 @@ impl DelegateTool {
|
|||||||
let has_args_allowed = args["allowed_tools"].as_array().is_some();
|
let has_args_allowed = args["allowed_tools"].as_array().is_some();
|
||||||
for c in &mut configs {
|
for c in &mut configs {
|
||||||
if c.allowed_tools.is_none() && has_args_allowed {
|
if c.allowed_tools.is_none() && has_args_allowed {
|
||||||
c.allowed_tools = args["allowed_tools"]
|
c.allowed_tools = args["allowed_tools"].as_array().map(|arr| {
|
||||||
.as_array()
|
arr.iter()
|
||||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect());
|
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let results = self.sub_agent_manager.run_parallel(configs).await
|
let results = self
|
||||||
|
.sub_agent_manager
|
||||||
|
.run_parallel(configs)
|
||||||
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||||
|
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
@ -243,7 +266,9 @@ impl DelegateTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let all_success = results.iter().all(|r| matches!(r.status, TaskStatus::Completed));
|
let all_success = results
|
||||||
|
.iter()
|
||||||
|
.all(|r| matches!(r.status, TaskStatus::Completed));
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: all_success,
|
success: all_success,
|
||||||
output: output.trim().to_string(),
|
output: output.trim().to_string(),
|
||||||
|
|||||||
@ -243,8 +243,8 @@ impl Tool for FileEditTool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::NamedTempFile;
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_edit_simple() {
|
async fn test_edit_simple() {
|
||||||
|
|||||||
@ -181,10 +181,7 @@ impl Tool for FileReadTool {
|
|||||||
}
|
}
|
||||||
result = lines[..end_idx].join("\n");
|
result = lines[..end_idx].join("\n");
|
||||||
let truncated = original_len - result.len();
|
let truncated = original_len - result.len();
|
||||||
result.push_str(&format!(
|
result.push_str(&format!("\n\n... ({} chars truncated) ...", truncated));
|
||||||
"\n\n... ({} chars truncated) ...",
|
|
||||||
truncated
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if end < total {
|
if end < total {
|
||||||
@ -196,10 +193,7 @@ impl Tool for FileReadTool {
|
|||||||
end + 1
|
end + 1
|
||||||
));
|
));
|
||||||
} else {
|
} else {
|
||||||
result.push_str(&format!(
|
result.push_str(&format!("\n\n(End of file — {} lines total)", total));
|
||||||
"\n\n(End of file — {} lines total)",
|
|
||||||
total
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(label) = encoding_label {
|
if let Some(label) = encoding_label {
|
||||||
@ -214,7 +208,7 @@ impl Tool for FileReadTool {
|
|||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// Truly binary file — base64 encode
|
// Truly binary file — base64 encode
|
||||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||||
let encoded = STANDARD.encode(&bytes);
|
let encoded = STANDARD.encode(&bytes);
|
||||||
let mime = mime_guess::from_path(&resolved)
|
let mime = mime_guess::from_path(&resolved)
|
||||||
.first_or_octet_stream()
|
.first_or_octet_stream()
|
||||||
@ -278,8 +272,8 @@ fn decode_text(bytes: &[u8]) -> (Option<String>, Option<&'static str>) {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::NamedTempFile;
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_read_simple_file() {
|
async fn test_read_simple_file() {
|
||||||
@ -338,10 +332,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_is_directory() {
|
async fn test_is_directory() {
|
||||||
let tool = FileReadTool::new();
|
let tool = FileReadTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "path": "." })).await.unwrap();
|
||||||
.execute(json!({ "path": "." }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Not a file"));
|
assert!(result.error.unwrap().contains("Not a file"));
|
||||||
|
|||||||
@ -101,17 +101,29 @@ impl Tool for FileSearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
let dir = self.resolve_dir(args.get("dir").and_then(|v| v.as_str()));
|
||||||
let case_sensitive = args.get("case_sensitive").and_then(|v| v.as_bool()).unwrap_or(true);
|
let case_sensitive = args
|
||||||
let max_results = args.get("max_results").and_then(|v| v.as_u64()).unwrap_or(MAX_RESULTS as u64) as usize;
|
.get("case_sensitive")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(true);
|
||||||
|
let max_results = args
|
||||||
|
.get("max_results")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.unwrap_or(MAX_RESULTS as u64) as usize;
|
||||||
|
|
||||||
let result = self.run_search(pattern, &dir, case_sensitive, max_results).await;
|
let result = self
|
||||||
|
.run_search(pattern, &dir, case_sensitive, max_results)
|
||||||
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(lines) => {
|
Ok(lines) => {
|
||||||
let count = lines.len();
|
let count = lines.len();
|
||||||
let mut output = self.truncate_output(&lines);
|
let mut output = self.truncate_output(&lines);
|
||||||
output.push_str(&format!("\n\n---\n共 {} 个文件", count));
|
output.push_str(&format!("\n\n---\n共 {} 个文件", count));
|
||||||
Ok(ToolResult { success: true, output, error: None })
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output,
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Err(e) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
@ -139,9 +151,12 @@ impl FileSearchTool {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !fd_cmd.is_empty() {
|
if !fd_cmd.is_empty() {
|
||||||
match self.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd).await {
|
match self
|
||||||
|
.search_with_fd(pattern, dir, case_sensitive, max_results, fd_cmd)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||||
Ok(_) => {},
|
Ok(_) => {}
|
||||||
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
|
Err(e) => tracing::warn!("{} failed: {}, falling back", fd_cmd, e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -149,13 +164,14 @@ impl FileSearchTool {
|
|||||||
if which::which("find").is_ok() {
|
if which::which("find").is_ok() {
|
||||||
match self.search_with_find(pattern, dir, max_results).await {
|
match self.search_with_find(pattern, dir, max_results).await {
|
||||||
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
Ok(lines) if !lines.is_empty() => return Ok(lines),
|
||||||
Ok(_) => {},
|
Ok(_) => {}
|
||||||
Err(e) => tracing::warn!("find failed: {}, falling back", e),
|
Err(e) => tracing::warn!("find failed: {}, falling back", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::warn!("No fd/find available, using built-in file search (slower)");
|
tracing::warn!("No fd/find available, using built-in file search (slower)");
|
||||||
self.search_with_rust(pattern, dir, case_sensitive, max_results).await
|
self.search_with_rust(pattern, dir, case_sensitive, max_results)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search_with_fd(
|
async fn search_with_fd(
|
||||||
@ -167,11 +183,15 @@ impl FileSearchTool {
|
|||||||
fd_cmd: &str,
|
fd_cmd: &str,
|
||||||
) -> anyhow::Result<Vec<String>> {
|
) -> anyhow::Result<Vec<String>> {
|
||||||
let mut cmd = Command::new(fd_cmd);
|
let mut cmd = Command::new(fd_cmd);
|
||||||
cmd.arg("--search-path").arg(dir)
|
cmd.arg("--search-path")
|
||||||
.arg("--glob").arg(pattern)
|
.arg(dir)
|
||||||
.arg("--color").arg("never")
|
.arg("--glob")
|
||||||
|
.arg(pattern)
|
||||||
|
.arg("--color")
|
||||||
|
.arg("never")
|
||||||
.arg("--strip-cwd-prefix")
|
.arg("--strip-cwd-prefix")
|
||||||
.arg("--max-results").arg(max_results.to_string())
|
.arg("--max-results")
|
||||||
|
.arg(max_results.to_string())
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped());
|
.stderr(Stdio::piped());
|
||||||
|
|
||||||
@ -179,10 +199,7 @@ impl FileSearchTool {
|
|||||||
cmd.arg("--ignore-case");
|
cmd.arg("--ignore-case");
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = timeout(
|
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
|
||||||
cmd.output(),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("fd timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
@ -192,7 +209,8 @@ impl FileSearchTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let lines: Vec<String> = text.lines()
|
let lines: Vec<String> = text
|
||||||
|
.lines()
|
||||||
.filter(|l| !l.is_empty())
|
.filter(|l| !l.is_empty())
|
||||||
.map(|l| l.to_string())
|
.map(|l| l.to_string())
|
||||||
.collect();
|
.collect();
|
||||||
@ -215,15 +233,13 @@ impl FileSearchTool {
|
|||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::null());
|
.stderr(Stdio::null());
|
||||||
|
|
||||||
let output = timeout(
|
let output = timeout(std::time::Duration::from_secs(TIMEOUT_SECS), cmd.output())
|
||||||
std::time::Duration::from_secs(TIMEOUT_SECS),
|
|
||||||
cmd.output(),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
.map_err(|_| anyhow::anyhow!("find timed out after {}s", TIMEOUT_SECS))??;
|
||||||
|
|
||||||
let text = String::from_utf8_lossy(&output.stdout);
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
let mut lines: Vec<String> = text.lines()
|
let mut lines: Vec<String> = text
|
||||||
|
.lines()
|
||||||
.filter(|l| !l.is_empty())
|
.filter(|l| !l.is_empty())
|
||||||
.map(|l| {
|
.map(|l| {
|
||||||
let p = Path::new(l);
|
let p = Path::new(l);
|
||||||
@ -254,7 +270,13 @@ impl FileSearchTool {
|
|||||||
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
|
.map_err(|e| anyhow::anyhow!("Invalid glob pattern '{}': {}", pattern, e))?;
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
walk_dir(Path::new(dir), Path::new(dir), &re, &mut results, max_results)?;
|
walk_dir(
|
||||||
|
Path::new(dir),
|
||||||
|
Path::new(dir),
|
||||||
|
&re,
|
||||||
|
&mut results,
|
||||||
|
max_results,
|
||||||
|
)?;
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -311,13 +333,16 @@ fn walk_dir(
|
|||||||
|
|
||||||
if path.is_dir() {
|
if path.is_dir() {
|
||||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& name.starts_with('.') && name.len() > 1 {
|
&& name.starts_with('.')
|
||||||
|
&& name.len() > 1
|
||||||
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
walk_dir(base, &path, re, results, max)?;
|
walk_dir(base, &path, re, results, max)?;
|
||||||
} else if path.is_file() {
|
} else if path.is_file() {
|
||||||
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
if let Some(name) = rel.file_name().and_then(|n| n.to_str())
|
||||||
&& re.is_match(name) {
|
&& re.is_match(name)
|
||||||
|
{
|
||||||
results.push(rel.to_string_lossy().to_string());
|
results.push(rel.to_string_lossy().to_string());
|
||||||
}
|
}
|
||||||
if results.len() >= max {
|
if results.len() >= max {
|
||||||
|
|||||||
@ -90,7 +90,8 @@ impl Tool for FileWriteTool {
|
|||||||
// Create parent directories if needed
|
// Create parent directories if needed
|
||||||
if let Some(parent) = resolved.parent()
|
if let Some(parent) = resolved.parent()
|
||||||
&& !parent.exists()
|
&& !parent.exists()
|
||||||
&& let Err(e) = std::fs::create_dir_all(parent) {
|
&& let Err(e) = std::fs::create_dir_all(parent)
|
||||||
|
{
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
@ -168,10 +169,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_write_missing_path() {
|
async fn test_write_missing_path() {
|
||||||
let tool = FileWriteTool::new();
|
let tool = FileWriteTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
|
||||||
.execute(json!({ "content": "Hello" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("path"));
|
assert!(result.error.unwrap().contains("path"));
|
||||||
|
|||||||
@ -129,7 +129,9 @@ impl GetSkillTool {
|
|||||||
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
|
let mut output = format!("可用 skill (共 {} 个):\n", skills.len());
|
||||||
for s in &skills {
|
for s in &skills {
|
||||||
let always_mark = if s.always { " [常驻]" } else { "" };
|
let always_mark = if s.always { " [常驻]" } else { "" };
|
||||||
let path_str = s.path.as_ref()
|
let path_str = s
|
||||||
|
.path
|
||||||
|
.as_ref()
|
||||||
.map(|p| p.to_string_lossy().to_string())
|
.map(|p| p.to_string_lossy().to_string())
|
||||||
.unwrap_or_else(|| "—".to_string());
|
.unwrap_or_else(|| "—".to_string());
|
||||||
output.push_str(&format!(
|
output.push_str(&format!(
|
||||||
@ -148,10 +150,10 @@ impl GetSkillTool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::tempdir;
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_existing_skill() {
|
async fn test_get_existing_skill() {
|
||||||
|
|||||||
@ -50,10 +50,7 @@ impl HttpRequestTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||||
return Err(format!(
|
return Err(format!("Host '{}' is not in allowed_domains", host));
|
||||||
"Host '{}' is not in allowed_domains",
|
|
||||||
host
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(url.to_string())
|
Ok(url.to_string())
|
||||||
@ -80,8 +77,7 @@ impl HttpRequestTool {
|
|||||||
for (key, value) in obj {
|
for (key, value) in obj {
|
||||||
if let Some(str_val) = value.as_str()
|
if let Some(str_val) = value.as_str()
|
||||||
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
|
&& let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes())
|
||||||
&& let Ok(val) =
|
&& let Ok(val) = reqwest::header::HeaderValue::from_str(str_val)
|
||||||
reqwest::header::HeaderValue::from_str(str_val)
|
|
||||||
{
|
{
|
||||||
header_map.insert(name, val);
|
header_map.insert(name, val);
|
||||||
}
|
}
|
||||||
@ -191,7 +187,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
|||||||
|
|
||||||
allowed_domains.iter().any(|domain| {
|
allowed_domains.iter().any(|domain| {
|
||||||
host == domain
|
host == domain
|
||||||
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
|| host
|
||||||
|
.strip_suffix(domain)
|
||||||
|
.is_some_and(|prefix| prefix.ends_with('.'))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,7 +200,11 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check .local TLD
|
// Check .local TLD
|
||||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
if host
|
||||||
|
.rsplit('.')
|
||||||
|
.next()
|
||||||
|
.is_some_and(|label| label == "local")
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,9 +226,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
|||||||
|| v4.is_broadcast()
|
|| v4.is_broadcast()
|
||||||
|| v4.is_multicast()
|
|| v4.is_multicast()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => {
|
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,10 +278,7 @@ impl Tool for HttpRequestTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let method_str = args
|
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
||||||
.get("method")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.unwrap_or("GET");
|
|
||||||
|
|
||||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||||
let body = args.get("body").and_then(|v| v.as_str());
|
let body = args.get("body").and_then(|v| v.as_str());
|
||||||
|
|||||||
@ -151,10 +151,19 @@ impl Tool for MemoryRecallTool {
|
|||||||
.and_then(|v| v.as_i64())
|
.and_then(|v| v.as_i64())
|
||||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
self.memory
|
self.memory
|
||||||
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Knowledge), None)
|
.recall_by_time(
|
||||||
|
since,
|
||||||
|
until,
|
||||||
|
Some(query),
|
||||||
|
limit,
|
||||||
|
Some(MemoryCategory::Knowledge),
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await?
|
.await?
|
||||||
} else {
|
} else {
|
||||||
self.memory.recall(query, limit, Some(MemoryCategory::Knowledge), None).await?
|
self.memory
|
||||||
|
.recall(query, limit, Some(MemoryCategory::Knowledge), None)
|
||||||
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
@ -168,7 +177,11 @@ impl Tool for MemoryRecallTool {
|
|||||||
let formatted = entries
|
let formatted = entries
|
||||||
.iter()
|
.iter()
|
||||||
.map(|e| {
|
.map(|e| {
|
||||||
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
let session = e
|
||||||
|
.session_id
|
||||||
|
.as_deref()
|
||||||
|
.map(|s| format!(" [session: {}]", s))
|
||||||
|
.unwrap_or_default();
|
||||||
format!(
|
format!(
|
||||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||||
e.key,
|
e.key,
|
||||||
@ -264,10 +277,19 @@ impl Tool for TimelineRecallTool {
|
|||||||
.and_then(|v| v.as_i64())
|
.and_then(|v| v.as_i64())
|
||||||
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
.unwrap_or(chrono::Utc::now().timestamp_millis());
|
||||||
self.memory
|
self.memory
|
||||||
.recall_by_time(since, until, Some(query), limit, Some(MemoryCategory::Timeline), session_id)
|
.recall_by_time(
|
||||||
|
since,
|
||||||
|
until,
|
||||||
|
Some(query),
|
||||||
|
limit,
|
||||||
|
Some(MemoryCategory::Timeline),
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
.await?
|
.await?
|
||||||
} else {
|
} else {
|
||||||
self.memory.recall(query, limit, Some(MemoryCategory::Timeline), session_id).await?
|
self.memory
|
||||||
|
.recall(query, limit, Some(MemoryCategory::Timeline), session_id)
|
||||||
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
@ -281,7 +303,11 @@ impl Tool for TimelineRecallTool {
|
|||||||
let formatted = entries
|
let formatted = entries
|
||||||
.iter()
|
.iter()
|
||||||
.map(|e| {
|
.map(|e| {
|
||||||
let session = e.session_id.as_deref().map(|s| format!(" [session: {}]", s)).unwrap_or_default();
|
let session = e
|
||||||
|
.session_id
|
||||||
|
.as_deref()
|
||||||
|
.map(|s| format!(" [session: {}]", s))
|
||||||
|
.unwrap_or_default();
|
||||||
format!(
|
format!(
|
||||||
"- {} [{}]{} [importance: {:.1}]: {}",
|
"- {} [{}]{} [importance: {:.1}]: {}",
|
||||||
e.key,
|
e.key,
|
||||||
|
|||||||
@ -37,11 +37,11 @@ pub use send_message::SendMessageTool;
|
|||||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use crate::agent::SubAgentManager;
|
use crate::agent::SubAgentManager;
|
||||||
use crate::config::BrowserConfig;
|
use crate::config::BrowserConfig;
|
||||||
use crate::memory::MemoryManager;
|
use crate::memory::MemoryManager;
|
||||||
use crate::skills::SkillsLoader;
|
use crate::skills::SkillsLoader;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Create the base tool registry (without send_message).
|
/// Create the base tool registry (without send_message).
|
||||||
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
||||||
|
|||||||
@ -17,7 +17,10 @@ impl ToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
||||||
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
|
self.tools
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.insert(tool.name().to_string(), Arc::new(tool));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register an existing Arc-wrapped tool by name
|
/// Register an existing Arc-wrapped tool by name
|
||||||
|
|||||||
@ -115,7 +115,9 @@ impl SchemaCleanr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(Value::String(t)) = obj.get("type")
|
if let Some(Value::String(t)) = obj.get("type")
|
||||||
&& t == "object" && !obj.contains_key("properties") {
|
&& t == "object"
|
||||||
|
&& !obj.contains_key("properties")
|
||||||
|
{
|
||||||
tracing::warn!("Object schema without 'properties' field may cause issues");
|
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,7 +175,8 @@ impl SchemaCleanr {
|
|||||||
|
|
||||||
// Handle anyOf/oneOf simplification
|
// Handle anyOf/oneOf simplification
|
||||||
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
|
if (obj.contains_key("anyOf") || obj.contains_key("oneOf"))
|
||||||
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
&& let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack)
|
||||||
|
{
|
||||||
return simplified;
|
return simplified;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,7 +246,8 @@ impl SchemaCleanr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(def_name) = Self::parse_local_ref(ref_value)
|
if let Some(def_name) = Self::parse_local_ref(ref_value)
|
||||||
&& let Some(definition) = defs.get(def_name.as_str()) {
|
&& let Some(definition) = defs.get(def_name.as_str())
|
||||||
|
{
|
||||||
ref_stack.insert(ref_value.to_string());
|
ref_stack.insert(ref_value.to_string());
|
||||||
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||||
ref_stack.remove(ref_value);
|
ref_stack.remove(ref_value);
|
||||||
@ -340,11 +344,14 @@ impl SchemaCleanr {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if let Some(Value::Array(arr)) = obj.get("enum")
|
if let Some(Value::Array(arr)) = obj.get("enum")
|
||||||
&& arr.len() == 1 && matches!(arr[0], Value::Null) {
|
&& arr.len() == 1
|
||||||
|
&& matches!(arr[0], Value::Null)
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if let Some(Value::String(t)) = obj.get("type")
|
if let Some(Value::String(t)) = obj.get("type")
|
||||||
&& t == "null" {
|
&& t == "null"
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -403,7 +410,10 @@ impl SchemaCleanr {
|
|||||||
|
|
||||||
match non_null.len() {
|
match non_null.len() {
|
||||||
0 => Value::String("null".to_string()),
|
0 => Value::String("null".to_string()),
|
||||||
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
1 => non_null
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.unwrap_or(Value::String("null".to_string())),
|
||||||
_ => Value::Array(non_null),
|
_ => Value::Array(non_null),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mime_guess::mime;
|
use mime_guess::mime;
|
||||||
@ -31,14 +31,20 @@ fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String>
|
|||||||
match parts.len() {
|
match parts.len() {
|
||||||
2 => {
|
2 => {
|
||||||
if parts[0].is_empty() || parts[1].is_empty() {
|
if parts[0].is_empty() || parts[1].is_empty() {
|
||||||
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
|
Err(format!(
|
||||||
|
"Invalid target_chat_id format '{}': channel and chat_id must not be empty",
|
||||||
|
raw
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok((parts[0], parts[1], None))
|
Ok((parts[0], parts[1], None))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
3 => {
|
3 => {
|
||||||
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
||||||
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
|
Err(format!(
|
||||||
|
"Invalid target_chat_id format '{}': all three parts must not be empty",
|
||||||
|
raw
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok((parts[0], parts[1], Some(parts[2])))
|
Ok((parts[0], parts[1], Some(parts[2])))
|
||||||
}
|
}
|
||||||
@ -98,8 +104,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
|||||||
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
||||||
|
|
||||||
// 1. Parse target_chat_id
|
// 1. Parse target_chat_id
|
||||||
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
|
let (channel, chat_id, dialog_id) =
|
||||||
.map_err(|e| anyhow::anyhow!(e))?;
|
parse_target_chat_id(raw_id).map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
|
||||||
// 2. Validate channel
|
// 2. Validate channel
|
||||||
if !self.available_channels.contains(channel) {
|
if !self.available_channels.contains(channel) {
|
||||||
@ -109,7 +115,11 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
|||||||
error: Some(format!(
|
error: Some(format!(
|
||||||
"Channel '{}' is not available. Available channels: {}",
|
"Channel '{}' is not available. Available channels: {}",
|
||||||
channel,
|
channel,
|
||||||
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
|
self.available_channels
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
)),
|
)),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -129,7 +139,8 @@ target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下
|
|||||||
let media = parse_files_arg(&args);
|
let media = parse_files_arg(&args);
|
||||||
|
|
||||||
// 4. Send via messenger
|
// 4. Send via messenger
|
||||||
match self.messenger
|
match self
|
||||||
|
.messenger
|
||||||
.send_message(channel, chat_id, dialog_id, content, source, media)
|
.send_message(channel, chat_id, dialog_id, content, source, media)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use crate::bus::{MediaItem, MessageSource};
|
use crate::bus::{MediaItem, MessageSource};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
|
|||||||
@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
if host
|
||||||
|
.rsplit('.')
|
||||||
|
.next()
|
||||||
|
.is_some_and(|label| label == "local")
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
std::net::IpAddr::V4(v4) => {
|
std::net::IpAddr::V4(v4) => {
|
||||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
|
||||||
use picobot::config::{Config, LLMProviderConfig};
|
use picobot::config::{Config, LLMProviderConfig};
|
||||||
|
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
fn load_config() -> Option<LLMProviderConfig> {
|
fn load_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -42,8 +42,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_simple_completion() {
|
async fn test_openai_simple_completion() {
|
||||||
let config = load_config()
|
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||||
@ -57,8 +56,7 @@ async fn test_openai_simple_completion() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_conversation() {
|
async fn test_openai_conversation() {
|
||||||
let config = load_config()
|
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -82,7 +80,9 @@ async fn test_openai_conversation() {
|
|||||||
async fn test_config_load() {
|
async fn test_config_load() {
|
||||||
// Test that config.json can be loaded and provider config created
|
// Test that config.json can be loaded and provider config created
|
||||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
let provider_config = config
|
||||||
|
.get_provider_config("default")
|
||||||
|
.expect("Failed to get provider config");
|
||||||
|
|
||||||
assert_eq!(provider_config.provider_type, "openai");
|
assert_eq!(provider_config.provider_type, "openai");
|
||||||
assert_eq!(provider_config.name, "aliyun");
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
|
|||||||
@ -41,7 +41,7 @@ async fn test_scheduler_types_roundtrip() {
|
|||||||
/// Verify that next_run_for_schedule produces valid future timestamps.
|
/// Verify that next_run_for_schedule produces valid future timestamps.
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_run_always_future() {
|
fn test_next_run_always_future() {
|
||||||
use picobot::scheduler::{next_run_for_schedule, Schedule};
|
use picobot::scheduler::{Schedule, next_run_for_schedule};
|
||||||
|
|
||||||
let now = 1700000000000_i64;
|
let now = 1700000000000_i64;
|
||||||
|
|
||||||
@ -56,6 +56,10 @@ fn test_next_run_always_future() {
|
|||||||
for s in &schedules {
|
for s in &schedules {
|
||||||
let next = next_run_for_schedule(s, now);
|
let next = next_run_for_schedule(s, now);
|
||||||
assert!(next.is_some(), "expected next run for {:?}", s);
|
assert!(next.is_some(), "expected next run for {:?}", s);
|
||||||
assert!(next.unwrap() > now, "next run should be after now for {:?}", s);
|
assert!(
|
||||||
|
next.unwrap() > now,
|
||||||
|
"next run should be after now for {:?}",
|
||||||
|
s
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
|
||||||
use picobot::config::LLMProviderConfig;
|
use picobot::config::LLMProviderConfig;
|
||||||
|
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -53,8 +53,7 @@ fn make_weather_tool() -> Tool {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_tool_call() {
|
async fn test_openai_tool_call() {
|
||||||
let config = load_openai_config()
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -68,7 +67,11 @@ async fn test_openai_tool_call() {
|
|||||||
let response = provider.chat(request).await.unwrap();
|
let response = provider.chat(request).await.unwrap();
|
||||||
|
|
||||||
// Should have tool calls
|
// Should have tool calls
|
||||||
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
assert!(
|
||||||
|
!response.tool_calls.is_empty(),
|
||||||
|
"Expected tool call, got: {}",
|
||||||
|
response.content
|
||||||
|
);
|
||||||
|
|
||||||
let tool_call = &response.tool_calls[0];
|
let tool_call = &response.tool_calls[0];
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
@ -78,8 +81,7 @@ async fn test_openai_tool_call() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_tool_call_with_manual_execution() {
|
async fn test_openai_tool_call_with_manual_execution() {
|
||||||
let config = load_openai_config()
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -92,8 +94,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let response1 = provider.chat(request1).await.unwrap();
|
let response1 = provider.chat(request1).await.unwrap();
|
||||||
let tool_call = response1.tool_calls.first()
|
let tool_call = response1.tool_calls.first().expect("Expected tool call");
|
||||||
.expect("Expected tool call");
|
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
|
|
||||||
// Second request with tool result
|
// Second request with tool result
|
||||||
@ -116,8 +117,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_no_tool_when_not_provided() {
|
async fn test_openai_no_tool_when_not_provided() {
|
||||||
let config = load_openai_config()
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user