Compare commits
No commits in common. "fa3354db9ccb9844bbf155410911a4f9e00e653c" and "bca86abe6768b3f614cbdea67b7f7fe3e4475cab" have entirely different histories.
fa3354db9c
...
bca86abe67
@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
|
||||
1. 系统先对当前活动历史做一个近似 token 估算。
|
||||
估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。
|
||||
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
|
||||
这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens;未配置时按 128000 估算。
|
||||
这里的上下文窗口来自 agent 对应模型配置里的 token_limit。
|
||||
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
|
||||
当前默认会完整保留最近 3 个 user turn。
|
||||
4. 一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
use crate::bus::ChatMessage;
|
||||
use async_trait::async_trait;
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::message::ToolMessageState;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::observability::{
|
||||
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState,
|
||||
};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
|
||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
use async_trait::async_trait;
|
||||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||||
use std::collections::VecDeque;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::io::Read;
|
||||
@ -19,13 +19,18 @@ use std::time::Instant;
|
||||
|
||||
/// Minimum characters to keep when truncating
|
||||
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = include_str!("memory_tool_usage_system_prompt.md");
|
||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str =
|
||||
include_str!("memory_tool_usage_system_prompt.md");
|
||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
|
||||
"工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||
|
||||
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
];
|
||||
|
||||
/// Build content blocks from text and media paths
|
||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||
@ -110,15 +115,14 @@ fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
||||
let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len));
|
||||
format!(
|
||||
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
||||
truncated_start_len, tail
|
||||
truncated_start_len,
|
||||
tail
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_pending_tool_output(output: &str) -> Option<String> {
|
||||
output
|
||||
.strip_prefix(PENDING_USER_ACTION_MARKER)
|
||||
.map(|rest| rest.trim().to_string())
|
||||
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
||||
}
|
||||
|
||||
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
|
||||
@ -337,10 +341,7 @@ impl AgentLoop {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_tools(
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
) -> Result<Self, AgentError> {
|
||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
let provider = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||
@ -415,16 +416,9 @@ impl AgentLoop {
|
||||
/// it loops back to the LLM with the tool results until either:
|
||||
/// - The LLM returns no more tool calls (final response)
|
||||
/// - Maximum iterations are reached
|
||||
pub async fn process(
|
||||
&self,
|
||||
mut messages: Vec<ChatMessage>,
|
||||
) -> Result<AgentProcessResult, AgentError> {
|
||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
history_len = messages.len(),
|
||||
max_iterations = self.max_iterations,
|
||||
"Starting agent process"
|
||||
);
|
||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||
|
||||
// Track tool calls for loop detection
|
||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||
@ -447,11 +441,7 @@ impl AgentLoop {
|
||||
if let Some(skill_tool) = self.skills.skill_tool_definition() {
|
||||
tool_defs.push(skill_tool);
|
||||
}
|
||||
let tools = if tool_defs.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_defs)
|
||||
};
|
||||
let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) };
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: messages_for_llm,
|
||||
@ -471,8 +461,7 @@ impl AgentLoop {
|
||||
error_details = %format_error_chain(e.as_ref()),
|
||||
"LLM request failed"
|
||||
);
|
||||
let assistant_message =
|
||||
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
return Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
@ -491,8 +480,7 @@ impl AgentLoop {
|
||||
|
||||
// If no tool calls, this is the final response
|
||||
if response.tool_calls.is_empty() {
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||
{
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||
} else {
|
||||
ChatMessage::assistant(response.content)
|
||||
@ -505,15 +493,10 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
// Execute tool calls
|
||||
tracing::info!(
|
||||
iteration,
|
||||
count = response.tool_calls.len(),
|
||||
"Tool calls detected, executing tools"
|
||||
);
|
||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
||||
|
||||
// Add assistant message with tool calls
|
||||
let assistant_message =
|
||||
if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||||
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
||||
response.content.clone(),
|
||||
response.tool_calls.clone(),
|
||||
@ -527,13 +510,7 @@ impl AgentLoop {
|
||||
};
|
||||
messages.push(assistant_message.clone());
|
||||
emitted_messages.push(assistant_message);
|
||||
self.emit_live_tool_call_message(
|
||||
emitted_messages
|
||||
.last()
|
||||
.expect("assistant message just pushed")
|
||||
.clone(),
|
||||
)
|
||||
.await;
|
||||
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
|
||||
|
||||
// Execute tools and add results to messages
|
||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||
@ -542,9 +519,7 @@ impl AgentLoop {
|
||||
// Log function call with name and arguments
|
||||
let args_str = match &tool_call.arguments {
|
||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||
other => {
|
||||
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
||||
}
|
||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||
};
|
||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||
|
||||
@ -620,11 +595,7 @@ impl AgentLoop {
|
||||
|
||||
// Loop continues to next iteration with updated messages
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
iteration,
|
||||
message_count = messages.len(),
|
||||
"Tool execution complete, continuing to next iteration"
|
||||
);
|
||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
||||
}
|
||||
|
||||
// Max iterations reached - ask LLM for a summary based on completed work
|
||||
@ -633,7 +604,7 @@ impl AgentLoop {
|
||||
// Add a message asking for summary
|
||||
let summary_request = ChatMessage::user(
|
||||
"You have reached the maximum number of tool call iterations. \
|
||||
Please provide your best answer based on the work completed so far.",
|
||||
Please provide your best answer based on the work completed so far."
|
||||
);
|
||||
messages.push(summary_request);
|
||||
|
||||
@ -653,8 +624,7 @@ impl AgentLoop {
|
||||
|
||||
match (*self.provider).chat(request).await {
|
||||
Ok(response) => {
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||
{
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||
} else {
|
||||
ChatMessage::assistant(response.content)
|
||||
@ -775,7 +745,10 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
// Apply duration
|
||||
ToolExecutionOutcome { duration, ..result }
|
||||
ToolExecutionOutcome {
|
||||
duration,
|
||||
..result
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal tool execution without event tracking.
|
||||
@ -817,7 +790,10 @@ impl AgentLoop {
|
||||
"arguments": normalized_arguments,
|
||||
}),
|
||||
);
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", err), Some(err))
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", err),
|
||||
Some(err),
|
||||
)
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -833,10 +809,7 @@ impl AgentLoop {
|
||||
}
|
||||
};
|
||||
|
||||
match tool
|
||||
.execute_with_context(&self.tool_context, normalized_arguments.clone())
|
||||
.await
|
||||
{
|
||||
match tool.execute_with_context(&self.tool_context, normalized_arguments.clone()).await {
|
||||
Ok(result) => {
|
||||
if result.success {
|
||||
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
|
||||
@ -854,7 +827,10 @@ impl AgentLoop {
|
||||
output = %result.output,
|
||||
"Tool returned an error result"
|
||||
);
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", error),
|
||||
Some(error),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@ -866,7 +842,10 @@ impl AgentLoop {
|
||||
error_details = %format!("{:#}", e),
|
||||
"Tool execution failed"
|
||||
);
|
||||
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
||||
ToolExecutionOutcome::failure(
|
||||
format!("Error: {}", e),
|
||||
Some(e.to_string()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -884,9 +863,7 @@ impl AgentLoop {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Err(err) =
|
||||
store.append_skill_event(Some(session_id), event_type, skill_name, &payload)
|
||||
{
|
||||
if let Err(err) = store.append_skill_event(Some(session_id), event_type, skill_name, &payload) {
|
||||
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
|
||||
}
|
||||
}
|
||||
@ -965,37 +942,28 @@ mod tests {
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(
|
||||
provider_message.tool_calls.as_ref().unwrap()[0].id,
|
||||
"call_1"
|
||||
);
|
||||
assert_eq!(
|
||||
provider_message.tool_calls.as_ref().unwrap()[0].name,
|
||||
"calculator"
|
||||
);
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
|
||||
let chat_message =
|
||||
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
|
||||
let chat_message = ChatMessage::assistant_with_reasoning(
|
||||
"final answer",
|
||||
"hidden chain of thought",
|
||||
);
|
||||
|
||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(
|
||||
provider_message.reasoning_content.as_deref(),
|
||||
Some("hidden chain of thought")
|
||||
);
|
||||
assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_prompt_requires_proactive_memory_search() {
|
||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));
|
||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search"));
|
||||
assert!(
|
||||
MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索")
|
||||
);
|
||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1033,13 +1001,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_normalize_tool_arguments_keeps_plain_string() {
|
||||
let normalized =
|
||||
normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
|
||||
let normalized = normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
|
||||
|
||||
assert_eq!(
|
||||
normalized,
|
||||
serde_json::Value::String("plain text".to_string())
|
||||
);
|
||||
assert_eq!(normalized, serde_json::Value::String("plain text".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1064,9 +1028,7 @@ mod tests {
|
||||
|
||||
assert_eq!(blocks.len(), 2);
|
||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||
assert!(
|
||||
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
||||
);
|
||||
assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
use crate::bus::{
|
||||
ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||
ChatMessage,
|
||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
||||
SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||
};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
|
||||
use crate::providers::{create_provider, ChatCompletionRequest, LLMProvider, Message};
|
||||
use crate::text::{char_count, take_prefix_chars};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
@ -60,7 +62,6 @@ pub struct ContextCompressor {
|
||||
}
|
||||
|
||||
impl ContextCompressor {
|
||||
#[cfg(test)]
|
||||
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
|
||||
const SUMMARY_RATIO: f64 = 0.1;
|
||||
const CHARS_PER_TOKEN: f64 = 2.5;
|
||||
@ -220,9 +221,7 @@ Be concise, aim for {} characters or less.
|
||||
.await;
|
||||
}
|
||||
|
||||
let per_chunk_target = (target_chars / layer.len().max(1))
|
||||
.max(500)
|
||||
.min(target_chars);
|
||||
let per_chunk_target = (target_chars / layer.len().max(1)).max(500).min(target_chars);
|
||||
let mut summaries = Vec::with_capacity(layer.len());
|
||||
for chunk in &layer {
|
||||
summaries.push(
|
||||
@ -242,9 +241,7 @@ Be concise, aim for {} characters or less.
|
||||
|
||||
let merged = summaries.join("\n\n");
|
||||
if char_count(&merged) <= target_chars {
|
||||
return self
|
||||
.summarize_transcript(provider, &merged, target_chars)
|
||||
.await;
|
||||
return self.summarize_transcript(provider, &merged, target_chars).await;
|
||||
}
|
||||
|
||||
layer = Self::split_text_chunks(&merged, target_chars);
|
||||
@ -317,10 +314,7 @@ Be concise, aim for {} characters or less.
|
||||
|| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
|
||||
}
|
||||
|
||||
fn split_prefix_messages(
|
||||
&self,
|
||||
history: &[ChatMessage],
|
||||
) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
|
||||
fn split_prefix_messages(&self, history: &[ChatMessage]) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
|
||||
let preserved_system_messages = history
|
||||
.iter()
|
||||
.filter(|message| self.should_preserve_system_message(message))
|
||||
@ -349,8 +343,7 @@ Be concise, aim for {} characters or less.
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let preserved_turn_start =
|
||||
turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
|
||||
let preserved_turn_start = turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
|
||||
if preserved_turn_start == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
@ -364,10 +357,10 @@ Be concise, aim for {} characters or less.
|
||||
|
||||
Ok(Some(HistoryCompactionPlan {
|
||||
preserved_system_messages,
|
||||
summary_message: ChatMessage::system_with_context(
|
||||
format!("[Compressed History]\n\n{}", summary),
|
||||
Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string()),
|
||||
),
|
||||
summary_message: ChatMessage::system_with_context(format!(
|
||||
"[Compressed History]\n\n{}",
|
||||
summary
|
||||
), Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string())),
|
||||
preserved_messages: history[preserved_turn_start..].to_vec(),
|
||||
compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns,
|
||||
preserved_turns: self.config.retain_last_user_turns,
|
||||
@ -399,10 +392,7 @@ Be concise, aim for {} characters or less.
|
||||
"Starting context compression"
|
||||
);
|
||||
|
||||
let current_history = match self
|
||||
.build_compaction_plan(&history, provider_config)
|
||||
.await?
|
||||
{
|
||||
let current_history = match self.build_compaction_plan(&history, provider_config).await? {
|
||||
Some(plan) => {
|
||||
let mut compressed = Vec::with_capacity(
|
||||
plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1,
|
||||
@ -439,11 +429,7 @@ Be concise, aim for {} characters or less.
|
||||
let transcript = Self::build_transcript(messages);
|
||||
|
||||
let result = if char_count(&transcript) <= self.config.summary_max_chars {
|
||||
self.summarize_transcript(
|
||||
provider.as_ref(),
|
||||
&transcript,
|
||||
self.config.summary_max_chars,
|
||||
)
|
||||
self.summarize_transcript(provider.as_ref(), &transcript, self.config.summary_max_chars)
|
||||
.await
|
||||
} else {
|
||||
self.summarize_chunked_transcript(provider.as_ref(), messages, &transcript)
|
||||
@ -454,10 +440,7 @@ Be concise, aim for {} characters or less.
|
||||
Ok(summary) => Ok(summary),
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
||||
Ok(take_prefix_chars(
|
||||
&transcript,
|
||||
self.config.summary_max_chars,
|
||||
))
|
||||
Ok(take_prefix_chars(&transcript, self.config.summary_max_chars))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -480,11 +463,7 @@ mod tests {
|
||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||
// raw = 19, with 1.2x = ~23
|
||||
assert!(
|
||||
tokens > 18 && tokens < 30,
|
||||
"Expected ~23 tokens, got {}",
|
||||
tokens
|
||||
);
|
||||
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -508,39 +487,21 @@ mod tests {
|
||||
];
|
||||
|
||||
let turns = compressor.user_turn_ranges(&history);
|
||||
assert_eq!(
|
||||
turns,
|
||||
vec![
|
||||
UserTurnRange {
|
||||
start: 1,
|
||||
end_exclusive: 4
|
||||
},
|
||||
UserTurnRange {
|
||||
start: 4,
|
||||
end_exclusive: 6
|
||||
},
|
||||
UserTurnRange {
|
||||
start: 6,
|
||||
end_exclusive: 7
|
||||
},
|
||||
]
|
||||
);
|
||||
assert_eq!(turns, vec![
|
||||
UserTurnRange { start: 1, end_exclusive: 4 },
|
||||
UserTurnRange { start: 4, end_exclusive: 6 },
|
||||
UserTurnRange { start: 6, end_exclusive: 7 },
|
||||
]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_prefix_messages_preserves_key_system_messages() {
|
||||
let compressor = ContextCompressor::new(50);
|
||||
let prefix = vec![
|
||||
ChatMessage::system_with_context(
|
||||
"agent prompt",
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||
),
|
||||
ChatMessage::system_with_context("agent prompt", Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string())),
|
||||
ChatMessage::user("u1"),
|
||||
ChatMessage::assistant("a1"),
|
||||
ChatMessage::system_with_context(
|
||||
"scheduled prompt",
|
||||
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
|
||||
),
|
||||
ChatMessage::system_with_context("scheduled prompt", Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string())),
|
||||
];
|
||||
|
||||
let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix);
|
||||
@ -558,22 +519,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_summary_char_budget_for_context_window_scales_and_clamps() {
|
||||
assert_eq!(
|
||||
ContextCompressor::summary_char_budget_for_context_window(4_096),
|
||||
1_500
|
||||
);
|
||||
assert_eq!(
|
||||
ContextCompressor::summary_char_budget_for_context_window(65_536),
|
||||
16_384
|
||||
);
|
||||
assert_eq!(
|
||||
ContextCompressor::summary_char_budget_for_context_window(128_000),
|
||||
32_000
|
||||
);
|
||||
assert_eq!(
|
||||
ContextCompressor::summary_char_budget_for_context_window(400_000),
|
||||
50_000
|
||||
);
|
||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(4_096), 1_500);
|
||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(65_536), 16_384);
|
||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(128_000), 32_000);
|
||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(400_000), 50_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
pub mod agent_loop;
|
||||
pub mod context_compressor;
|
||||
|
||||
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler};
|
||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
|
||||
pub use context_compressor::ContextCompressor;
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
@ -22,10 +22,7 @@ impl OutboundDispatcher {
|
||||
|
||||
/// Register a channel with the dispatcher
|
||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||
self.channels
|
||||
.write()
|
||||
.await
|
||||
.insert(name.to_string(), channel);
|
||||
self.channels.write().await.insert(name.to_string(), channel);
|
||||
}
|
||||
|
||||
/// Run the dispatcher loop - consumes from bus and dispatches to channels
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::providers::ToolCall;
|
||||
|
||||
@ -34,9 +34,7 @@ pub struct ImageUrlBlock {
|
||||
|
||||
impl ContentBlock {
|
||||
pub fn text(content: impl Into<String>) -> Self {
|
||||
Self::Text {
|
||||
text: content.into(),
|
||||
}
|
||||
Self::Text { text: content.into() }
|
||||
}
|
||||
|
||||
pub fn image_url(url: impl Into<String>) -> Self {
|
||||
@ -152,10 +150,7 @@ impl ChatMessage {
|
||||
message
|
||||
}
|
||||
|
||||
pub fn assistant_with_tool_calls(
|
||||
content: impl Into<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Self {
|
||||
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
@ -204,17 +199,8 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
) -> Self {
|
||||
Self::tool_with_state(
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
content,
|
||||
ToolMessageState::Completed,
|
||||
)
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed)
|
||||
}
|
||||
|
||||
pub fn tool_with_state(
|
||||
@ -301,8 +287,6 @@ pub enum OutboundEventKind {
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
ToolPending,
|
||||
SchedulerNotification,
|
||||
ErrorNotification,
|
||||
}
|
||||
|
||||
impl OutboundMessage {
|
||||
@ -332,30 +316,6 @@ impl OutboundMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scheduler_notification(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
reply_to: Option<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata);
|
||||
message.event_kind = OutboundEventKind::SchedulerNotification;
|
||||
message
|
||||
}
|
||||
|
||||
pub fn error_notification(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
reply_to: Option<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata);
|
||||
message.event_kind = OutboundEventKind::ErrorNotification;
|
||||
message
|
||||
}
|
||||
|
||||
pub fn tool_call(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
@ -457,7 +417,9 @@ impl OutboundMessage {
|
||||
));
|
||||
}
|
||||
|
||||
outbound.extend(tool_calls.iter().map(|tool_call| {
|
||||
outbound.extend(tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| {
|
||||
Self::tool_call(
|
||||
channel.to_string(),
|
||||
chat_id.to_string(),
|
||||
@ -467,7 +429,8 @@ impl OutboundMessage {
|
||||
reply_to.clone(),
|
||||
metadata.clone(),
|
||||
)
|
||||
}));
|
||||
})
|
||||
);
|
||||
outbound
|
||||
} else {
|
||||
vec![Self::assistant(
|
||||
@ -479,11 +442,7 @@ impl OutboundMessage {
|
||||
)]
|
||||
}
|
||||
}
|
||||
"tool" => match message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed)
|
||||
{
|
||||
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
|
||||
ToolMessageState::Completed => vec![Self::tool_result(
|
||||
channel.to_string(),
|
||||
chat_id.to_string(),
|
||||
@ -508,10 +467,7 @@ impl OutboundMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn format_tool_call_content(
|
||||
tool_name: &str,
|
||||
tool_arguments: &serde_json::Value,
|
||||
) -> String {
|
||||
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
||||
match tool_arguments {
|
||||
serde_json::Value::Object(map) if map.is_empty() => tool_name.to_string(),
|
||||
other => format!("{}\nargs: {}", tool_name, format_tool_arguments_json(other)),
|
||||
@ -588,25 +544,21 @@ mod tests {
|
||||
],
|
||||
);
|
||||
|
||||
let outbound =
|
||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||
let outbound = OutboundMessage::from_chat_message(
|
||||
"feishu",
|
||||
"chat-1",
|
||||
None,
|
||||
&HashMap::new(),
|
||||
&message,
|
||||
);
|
||||
|
||||
assert_eq!(outbound.len(), 2);
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
||||
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
||||
assert_eq!(
|
||||
outbound[0].tool_arguments.as_ref().unwrap()["expression"],
|
||||
"1 + 1"
|
||||
);
|
||||
assert_eq!(
|
||||
outbound[0].content,
|
||||
"calculator\nargs: {\"expression\":\"1 + 1\"}"
|
||||
);
|
||||
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
|
||||
assert_eq!(outbound[0].content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
||||
assert_eq!(
|
||||
outbound[1].content,
|
||||
"file_read\nargs: {\"path\":\"README.md\"}"
|
||||
);
|
||||
assert_eq!(outbound[1].content, "file_read\nargs: {\"path\":\"README.md\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -620,8 +572,13 @@ mod tests {
|
||||
}],
|
||||
);
|
||||
|
||||
let outbound =
|
||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||
let outbound = OutboundMessage::from_chat_message(
|
||||
"feishu",
|
||||
"chat-1",
|
||||
None,
|
||||
&HashMap::new(),
|
||||
&message,
|
||||
);
|
||||
|
||||
assert_eq!(outbound.len(), 2);
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
|
||||
@ -634,8 +591,13 @@ mod tests {
|
||||
fn test_from_chat_message_includes_tool_result() {
|
||||
let message = ChatMessage::tool("call-9", "calculator", "2");
|
||||
|
||||
let outbound =
|
||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||
let outbound = OutboundMessage::from_chat_message(
|
||||
"feishu",
|
||||
"chat-1",
|
||||
None,
|
||||
&HashMap::new(),
|
||||
&message,
|
||||
);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
||||
@ -650,8 +612,13 @@ mod tests {
|
||||
ToolMessageState::PendingUserAction,
|
||||
);
|
||||
|
||||
let outbound =
|
||||
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||
let outbound = OutboundMessage::from_chat_message(
|
||||
"feishu",
|
||||
"chat-1",
|
||||
None,
|
||||
&HashMap::new(),
|
||||
&message,
|
||||
);
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
|
||||
|
||||
@ -3,13 +3,18 @@ pub mod message;
|
||||
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{
|
||||
ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage,
|
||||
SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||
ChatMessage,
|
||||
ContentBlock,
|
||||
InboundMessage,
|
||||
MediaItem,
|
||||
OutboundMessage,
|
||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
||||
SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||
};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
// ============================================================================
|
||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||
@ -47,8 +52,7 @@ impl MessageBus {
|
||||
|
||||
/// Consume an inbound message (Agent -> Bus)
|
||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||
let msg = self
|
||||
.inbound_rx
|
||||
let msg = self.inbound_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::bus::MessageBus;
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::feishu::FeishuChannel;
|
||||
use crate::config::Config;
|
||||
@ -28,18 +28,12 @@ impl ChannelManager {
|
||||
}
|
||||
|
||||
/// Initialize all Channel instances from config
|
||||
pub async fn init(
|
||||
&self,
|
||||
config: &Config,
|
||||
_provider_config: crate::config::LLMProviderConfig,
|
||||
) -> Result<(), ChannelError> {
|
||||
pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
|
||||
// Initialize Feishu channel if enabled
|
||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||
if feishu_config.enabled {
|
||||
let channel =
|
||||
FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| {
|
||||
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
||||
})?;
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), _provider_config)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||
|
||||
self.channels
|
||||
.write()
|
||||
@ -81,12 +75,13 @@ impl ChannelManager {
|
||||
self.channels.read().await.get(name).cloned()
|
||||
}
|
||||
|
||||
pub async fn channels(&self) -> Vec<(String, Arc<dyn Channel + Send + Sync>)> {
|
||||
self.channels
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, channel)| (name.clone(), channel.clone()))
|
||||
.collect()
|
||||
/// Dispatch an outbound message to the appropriate channel
|
||||
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let channel_name = &msg.channel;
|
||||
if let Some(channel) = self.get_channel(channel_name).await {
|
||||
channel.send(msg).await
|
||||
} else {
|
||||
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,5 +3,5 @@ pub mod feishu;
|
||||
pub mod manager;
|
||||
|
||||
pub use base::{Channel, ChannelError};
|
||||
pub use feishu::FeishuChannel;
|
||||
pub use manager::ChannelManager;
|
||||
pub use feishu::FeishuChannel;
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt};
|
||||
|
||||
pub struct CliChannel {
|
||||
read: BufReader<tokio::io::Stdin>,
|
||||
|
||||
@ -49,27 +49,18 @@ impl InputHandler {
|
||||
}
|
||||
|
||||
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
|
||||
self.channel
|
||||
.write_line(content)
|
||||
.await
|
||||
.map_err(InputError::IoError)
|
||||
self.channel.write_line(content).await.map_err(InputError::IoError)
|
||||
}
|
||||
|
||||
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
|
||||
self.channel
|
||||
.write_response(content)
|
||||
.await
|
||||
.map_err(InputError::IoError)
|
||||
self.channel.write_response(content).await.map_err(InputError::IoError)
|
||||
}
|
||||
|
||||
fn handle_special_commands(&self, line: &str) -> Option<InputCommand> {
|
||||
let trimmed = line.trim();
|
||||
let mut parts = trimmed.splitn(2, char::is_whitespace);
|
||||
let command = parts.next()?;
|
||||
let arg = parts
|
||||
.next()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty());
|
||||
let arg = parts.next().map(str::trim).filter(|value| !value.is_empty());
|
||||
|
||||
match command {
|
||||
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
||||
@ -114,26 +105,14 @@ mod tests {
|
||||
fn test_special_command_parsing() {
|
||||
let handler = InputHandler::new();
|
||||
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/quit"),
|
||||
Some(InputCommand::Exit)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/clear"),
|
||||
Some(InputCommand::Clear)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/new"),
|
||||
Some(InputCommand::New(None))
|
||||
);
|
||||
assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit));
|
||||
assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear));
|
||||
assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None)));
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/new planning"),
|
||||
Some(InputCommand::New(Some("planning".to_string())))
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/sessions"),
|
||||
Some(InputCommand::Sessions)
|
||||
);
|
||||
assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions));
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/use abc123"),
|
||||
Some(InputCommand::Use("abc123".to_string()))
|
||||
@ -142,14 +121,8 @@ mod tests {
|
||||
handler.handle_special_commands("/rename project alpha"),
|
||||
Some(InputCommand::Rename("project alpha".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/archive"),
|
||||
Some(InputCommand::Archive)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/delete"),
|
||||
Some(InputCommand::Delete)
|
||||
);
|
||||
assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive));
|
||||
assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete));
|
||||
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
||||
assert_eq!(handler.handle_special_commands("/use"), None);
|
||||
}
|
||||
|
||||
@ -5,10 +5,7 @@ use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
||||
|
||||
fn format_session_list(
|
||||
sessions: &[crate::protocol::SessionSummary],
|
||||
current_session_id: Option<&str>,
|
||||
) -> String {
|
||||
fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String {
|
||||
if sessions.is_empty() {
|
||||
return "No sessions found.".to_string();
|
||||
}
|
||||
@ -28,7 +25,11 @@ fn format_session_list(
|
||||
};
|
||||
lines.push(format!(
|
||||
"{} {} | {} | {} messages{}",
|
||||
marker, session.session_id, session.title, session.message_count, archived,
|
||||
marker,
|
||||
session.session_id,
|
||||
session.title,
|
||||
session.message_count,
|
||||
archived,
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
@ -123,9 +123,7 @@ fn default_allow_from() -> Vec<String> {
|
||||
|
||||
fn default_media_dir() -> String {
|
||||
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||
home.join(".picobot/media/feishu")
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
home.join(".picobot/media/feishu").to_string_lossy().to_string()
|
||||
}
|
||||
|
||||
fn default_reaction_emoji() -> String {
|
||||
@ -159,8 +157,6 @@ pub struct ModelConfig {
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub context_window_tokens: Option<u32>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
@ -203,10 +199,7 @@ pub struct GatewayConfig {
|
||||
pub show_tool_results: bool,
|
||||
#[serde(default, rename = "session_ttl_hours")]
|
||||
pub session_ttl_hours: Option<u64>,
|
||||
#[serde(
|
||||
default = "default_agent_prompt_reinject_every",
|
||||
rename = "agent_prompt_reinject_every"
|
||||
)]
|
||||
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
||||
pub agent_prompt_reinject_every: u64,
|
||||
}
|
||||
|
||||
@ -395,10 +388,7 @@ impl SchedulerSchedule {
|
||||
}
|
||||
|
||||
pub fn is_one_shot(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. }
|
||||
)
|
||||
matches!(self, SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. })
|
||||
}
|
||||
|
||||
pub fn normalized_for_storage(&self) -> Self {
|
||||
@ -528,7 +518,6 @@ pub struct LLMProviderConfig {
|
||||
pub model_id: String,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub context_window_tokens: Option<u32>,
|
||||
pub model_extra: HashMap<String, serde_json::Value>,
|
||||
pub max_tool_iterations: usize,
|
||||
pub tool_result_max_chars: usize,
|
||||
@ -537,7 +526,7 @@ pub struct LLMProviderConfig {
|
||||
|
||||
impl LLMProviderConfig {
|
||||
pub fn context_window_tokens(&self) -> usize {
|
||||
self.context_window_tokens
|
||||
self.max_tokens
|
||||
.map(|value| value as usize)
|
||||
.unwrap_or(128_000)
|
||||
}
|
||||
@ -592,19 +581,13 @@ impl Config {
|
||||
}
|
||||
|
||||
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
||||
let agent = self
|
||||
.agents
|
||||
.get(agent_name)
|
||||
let agent = self.agents.get(agent_name)
|
||||
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
||||
|
||||
let provider = self
|
||||
.providers
|
||||
.get(&agent.provider)
|
||||
let provider = self.providers.get(&agent.provider)
|
||||
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
||||
|
||||
let model = self
|
||||
.models
|
||||
.get(&agent.model)
|
||||
let model = self.models.get(&agent.model)
|
||||
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
||||
|
||||
Ok(LLMProviderConfig {
|
||||
@ -617,7 +600,6 @@ impl Config {
|
||||
model_id: model.model_id.clone(),
|
||||
temperature: model.temperature,
|
||||
max_tokens: model.max_tokens,
|
||||
context_window_tokens: model.context_window_tokens,
|
||||
model_extra: model.extra.clone(),
|
||||
max_tool_iterations: agent.max_tool_iterations,
|
||||
tool_result_max_chars: agent.tool_result_max_chars,
|
||||
@ -639,17 +621,11 @@ pub enum ConfigError {
|
||||
impl std::fmt::Display for ConfigError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ConfigError::ConfigNotFound(path) => write!(
|
||||
f,
|
||||
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
|
||||
path
|
||||
),
|
||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||
ConfigError::InvalidSchedulerJob(message) => {
|
||||
write!(f, "Invalid scheduler job: {}", message)
|
||||
}
|
||||
ConfigError::InvalidSchedulerJob(message) => write!(f, "Invalid scheduler job: {}", message),
|
||||
ConfigError::InvalidTimezone(message) => write!(f, "Invalid timezone: {}", message),
|
||||
}
|
||||
}
|
||||
@ -685,8 +661,7 @@ fn resolve_env_placeholders(content: &str) -> String {
|
||||
re.replace_all(content, |caps: ®ex::Captures| {
|
||||
let var_name = &caps[1];
|
||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
||||
})
|
||||
.to_string()
|
||||
}).to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -901,10 +876,7 @@ mod tests {
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
assert_eq!(config.time.timezone, "Asia/Shanghai");
|
||||
assert_eq!(
|
||||
config.time.parse_timezone().unwrap(),
|
||||
chrono_tz::Asia::Shanghai
|
||||
);
|
||||
assert_eq!(config.time.parse_timezone().unwrap(), chrono_tz::Asia::Shanghai);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1011,10 +983,7 @@ mod tests {
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
assert_eq!(config.agents["default"].max_tool_iterations, 100);
|
||||
assert_eq!(config.agents["default"].tool_result_max_chars, 20_000);
|
||||
assert_eq!(
|
||||
config.agents["default"].context_tool_result_trim_chars,
|
||||
2_000
|
||||
);
|
||||
assert_eq!(config.agents["default"].context_tool_result_trim_chars, 2_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1060,44 +1029,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
"providers": {
|
||||
"aliyun": {
|
||||
"type": "openai",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"api_key": "test-key",
|
||||
"extra_headers": {}
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"qwen-plus": {
|
||||
"model_id": "qwen-plus",
|
||||
"context_window_tokens": 4096
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "aliyun",
|
||||
"model": "qwen-plus"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
let provider_config = config.get_provider_config("default").unwrap();
|
||||
|
||||
assert_eq!(provider_config.context_window_tokens(), 4096);
|
||||
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_max_tokens_does_not_change_context_window() {
|
||||
fn test_provider_config_summary_budget_scales_with_model_max_tokens() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
@ -1129,9 +1061,8 @@ mod tests {
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
let provider_config = config.get_provider_config("default").unwrap();
|
||||
|
||||
assert_eq!(provider_config.max_tokens, Some(4096));
|
||||
assert_eq!(provider_config.context_window_tokens(), 128_000);
|
||||
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
|
||||
assert_eq!(provider_config.context_window_tokens(), 4096);
|
||||
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1228,10 +1159,7 @@ mod tests {
|
||||
assert!(config.scheduler.enabled);
|
||||
assert_eq!(config.scheduler.tick_resolution_ms, 1_000);
|
||||
assert_eq!(config.scheduler.worker_queue_capacity, 64);
|
||||
assert_eq!(
|
||||
config.scheduler.misfire_policy,
|
||||
SchedulerMisfirePolicy::Skip
|
||||
);
|
||||
assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::Skip);
|
||||
assert!(config.scheduler.jobs.is_empty());
|
||||
|
||||
let effective_jobs = config.scheduler.effective_jobs(&config.time);
|
||||
@ -1345,10 +1273,7 @@ mod tests {
|
||||
assert!(config.scheduler.enabled);
|
||||
assert_eq!(config.scheduler.tick_resolution_ms, 500);
|
||||
assert_eq!(config.scheduler.worker_queue_capacity, 8);
|
||||
assert_eq!(
|
||||
config.scheduler.misfire_policy,
|
||||
SchedulerMisfirePolicy::CatchUp
|
||||
);
|
||||
assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::CatchUp);
|
||||
assert_eq!(config.scheduler.jobs.len(), 1);
|
||||
|
||||
let job = &config.scheduler.jobs[0];
|
||||
@ -1359,17 +1284,11 @@ mod tests {
|
||||
assert_eq!(job.startup_delay_secs, 5);
|
||||
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
|
||||
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
|
||||
assert_eq!(
|
||||
job.payload.get("content").and_then(|value| value.as_str()),
|
||||
Some("heartbeat")
|
||||
);
|
||||
assert_eq!(
|
||||
job.resolved_schedule().unwrap(),
|
||||
SchedulerSchedule::Interval {
|
||||
assert_eq!(job.payload.get("content").and_then(|value| value.as_str()), Some("heartbeat"));
|
||||
assert_eq!(job.resolved_schedule().unwrap(), SchedulerSchedule::Interval {
|
||||
seconds: 60,
|
||||
startup_delay_secs: 5,
|
||||
}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1443,30 +1362,21 @@ mod tests {
|
||||
config.scheduler.jobs[0].resolved_schedule().unwrap(),
|
||||
SchedulerSchedule::Delay { seconds: 30 }
|
||||
);
|
||||
assert_eq!(
|
||||
config.scheduler.jobs[0].kind,
|
||||
SchedulerJobKind::InternalEvent
|
||||
);
|
||||
assert_eq!(config.scheduler.jobs[0].kind, SchedulerJobKind::InternalEvent);
|
||||
assert_eq!(
|
||||
config.scheduler.jobs[1].resolved_schedule().unwrap(),
|
||||
SchedulerSchedule::At {
|
||||
timestamp: "2026-04-23T09:00:00+00:00".to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
config.scheduler.jobs[1].kind,
|
||||
SchedulerJobKind::OutboundMessage
|
||||
);
|
||||
assert_eq!(config.scheduler.jobs[1].kind, SchedulerJobKind::OutboundMessage);
|
||||
assert_eq!(
|
||||
config.scheduler.jobs[2].resolved_schedule().unwrap(),
|
||||
SchedulerSchedule::Cron {
|
||||
expression: "0 9 * * *".to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
config.scheduler.jobs[2].kind,
|
||||
SchedulerJobKind::InternalEvent
|
||||
);
|
||||
assert_eq!(config.scheduler.jobs[2].kind, SchedulerJobKind::InternalEvent);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1523,10 +1433,7 @@ mod tests {
|
||||
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
|
||||
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
|
||||
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
|
||||
assert_eq!(
|
||||
job.payload.get("prompt").and_then(|value| value.as_str()),
|
||||
Some("请总结今天待办")
|
||||
);
|
||||
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1588,40 +1495,29 @@ mod tests {
|
||||
job.target.session_chat_id.as_deref(),
|
||||
Some("scheduler/agent.daily_summary.background")
|
||||
);
|
||||
assert_eq!(
|
||||
job.payload.get("prompt").and_then(|value| value.as_str()),
|
||||
Some("请后台总结今天待办")
|
||||
);
|
||||
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请后台总结今天待办"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scheduler_schedule_validation_rejects_invalid_values() {
|
||||
assert!(
|
||||
SchedulerSchedule::Delay { seconds: 0 }
|
||||
assert!(SchedulerSchedule::Delay { seconds: 0 }
|
||||
.validate("delay.job")
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
SchedulerSchedule::Interval {
|
||||
.is_err());
|
||||
assert!(SchedulerSchedule::Interval {
|
||||
seconds: 0,
|
||||
startup_delay_secs: 0,
|
||||
}
|
||||
.validate("interval.job")
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
SchedulerSchedule::At {
|
||||
.is_err());
|
||||
assert!(SchedulerSchedule::At {
|
||||
timestamp: "bad timestamp".to_string(),
|
||||
}
|
||||
.validate("at.job")
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
SchedulerSchedule::Cron {
|
||||
.is_err());
|
||||
assert!(SchedulerSchedule::Cron {
|
||||
expression: "bad cron".to_string(),
|
||||
}
|
||||
.validate("cron.job")
|
||||
.is_err()
|
||||
);
|
||||
.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,252 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::{AgentError, AgentProcessResult};
|
||||
use crate::bus::message::ToolMessageState;
|
||||
use crate::bus::{ChatMessage, OutboundMessage};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::session::{Session, schedule_background_history_compaction};
|
||||
|
||||
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||
|
||||
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
|
||||
match system_prompt
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
Some(system_prompt) => format!(
|
||||
"{}\n\n任务专属要求:{}",
|
||||
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
|
||||
),
|
||||
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn select_provider_config(
|
||||
default_provider_config: &LLMProviderConfig,
|
||||
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||
agent_name: Option<&str>,
|
||||
) -> Result<LLMProviderConfig, AgentError> {
|
||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
None | Some("default") => Ok(default_provider_config.clone()),
|
||||
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
||||
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AgentExecutionService {
|
||||
show_tool_results: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizeAgentResultRequest<'a> {
|
||||
pub(crate) channel_name: &'a str,
|
||||
pub(crate) chat_id: &'a str,
|
||||
pub(crate) user_message: &'a ChatMessage,
|
||||
pub(crate) result: AgentProcessResult,
|
||||
pub(crate) metadata: &'a HashMap<String, String>,
|
||||
pub(crate) suppress_live_tool_calls: bool,
|
||||
pub(crate) execution_kind: &'a str,
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizedAgentResult {
|
||||
pub(crate) outbound_messages: Vec<OutboundMessage>,
|
||||
pub(crate) should_schedule_compaction: bool,
|
||||
}
|
||||
|
||||
impl AgentExecutionService {
|
||||
pub(crate) fn new(show_tool_results: bool) -> Self {
|
||||
Self { show_tool_results }
|
||||
}
|
||||
|
||||
pub(crate) fn finalize_result(
|
||||
&self,
|
||||
session: &mut Session,
|
||||
request: FinalizeAgentResultRequest<'_>,
|
||||
) -> Result<FinalizedAgentResult, AgentError> {
|
||||
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
|
||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||
session.stale_result_diagnostics(request.chat_id);
|
||||
tracing::warn!(
|
||||
channel = %request.channel_name,
|
||||
chat_id = %request.chat_id,
|
||||
user_message_id = %request.user_message.id,
|
||||
latest_user_id,
|
||||
latest_user_preview,
|
||||
compression_in_flight,
|
||||
history_len,
|
||||
execution_kind = %request.execution_kind,
|
||||
"Skipping stale agent result because a newer user message is already present"
|
||||
);
|
||||
|
||||
return Ok(FinalizedAgentResult {
|
||||
outbound_messages: Vec::new(),
|
||||
should_schedule_compaction: false,
|
||||
});
|
||||
}
|
||||
|
||||
session
|
||||
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
|
||||
|
||||
let outbound_messages = request
|
||||
.result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| {
|
||||
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
|
||||
&& should_display_message_to_user(self.show_tool_results, message)
|
||||
})
|
||||
.flat_map(|message| {
|
||||
OutboundMessage::from_chat_message(
|
||||
request.channel_name,
|
||||
request.chat_id,
|
||||
None,
|
||||
request.metadata,
|
||||
message,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(FinalizedAgentResult {
|
||||
outbound_messages,
|
||||
should_schedule_compaction: true,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn finalize_result_and_schedule_compaction(
|
||||
&self,
|
||||
session: Arc<Mutex<Session>>,
|
||||
request: FinalizeAgentResultRequest<'_>,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
let channel_name = request.channel_name.to_string();
|
||||
let chat_id = request.chat_id.to_string();
|
||||
let execution_kind = request.execution_kind.to_string();
|
||||
|
||||
let finalized_result = {
|
||||
let mut session_guard = session.lock().await;
|
||||
self.finalize_result(&mut session_guard, request)?
|
||||
};
|
||||
|
||||
if finalized_result.should_schedule_compaction {
|
||||
if let Err(error) =
|
||||
schedule_background_history_compaction(session.clone(), chat_id.clone()).await
|
||||
{
|
||||
tracing::warn!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
execution_kind = %execution_kind,
|
||||
error = %error,
|
||||
"Failed to schedule background history compaction"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(finalized_result.outbound_messages)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_display_message_to_user(
|
||||
show_tool_results: bool,
|
||||
message: &ChatMessage,
|
||||
) -> bool {
|
||||
if message.role != "tool" {
|
||||
return true;
|
||||
}
|
||||
|
||||
show_tool_results
|
||||
|| matches!(
|
||||
message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed),
|
||||
ToolMessageState::PendingUserAction
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: name.to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_uses_named_agent_override() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::from([(
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]);
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||
assert_eq!(selected.name, "planner-provider");
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::new();
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
assert_eq!(selected.model_id, "default-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
|
||||
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
|
||||
|
||||
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||
assert!(prompt.contains("任务专属要求:只汇报异常"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
|
||||
let prompt = compose_scheduled_task_system_prompt(Some(" "));
|
||||
|
||||
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||
assert!(!prompt.contains("任务专属要求"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
|
||||
let message = ChatMessage::tool_with_state(
|
||||
"call-1",
|
||||
"approval",
|
||||
"需要用户确认",
|
||||
ToolMessageState::PendingUserAction,
|
||||
);
|
||||
|
||||
assert!(should_display_message_to_user(false, &message));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
|
||||
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||
|
||||
assert!(!should_display_message_to_user(false, &message));
|
||||
assert!(should_display_message_to_user(true, &message));
|
||||
}
|
||||
}
|
||||
@ -1,517 +0,0 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use crate::storage::{MemoryRecord, SessionStore};
|
||||
|
||||
use super::prompt::upsert_managed_agent_memory_summary;
|
||||
|
||||
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum MemoryMaintenanceCategory {
|
||||
UserFacts,
|
||||
Preferences,
|
||||
BehaviorPatterns,
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceCandidate {
|
||||
pub(crate) id: String,
|
||||
pub(crate) namespace: String,
|
||||
pub(crate) key: String,
|
||||
pub(crate) content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenancePlan {
|
||||
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
||||
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceMerge {
|
||||
pub(crate) source_ids: Vec<String>,
|
||||
pub(crate) namespace: String,
|
||||
pub(crate) memory_key: String,
|
||||
pub(crate) content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceConflict {
|
||||
pub(crate) source_ids: Vec<String>,
|
||||
pub(crate) note: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceModelOutput {
|
||||
pub(crate) user_facts: Vec<String>,
|
||||
pub(crate) preferences: Vec<String>,
|
||||
pub(crate) behavior_patterns: Vec<String>,
|
||||
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||||
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||||
pub(crate) low_value_ids: Vec<String>,
|
||||
pub(crate) managed_markdown: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct MemoryMaintenanceScopeResult {
|
||||
pub(crate) scope_key: String,
|
||||
pub(crate) output: MemoryMaintenanceModelOutput,
|
||||
}
|
||||
|
||||
pub(crate) struct MemoryMaintenanceService {
|
||||
store: Arc<SessionStore>,
|
||||
provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
impl MemoryMaintenanceService {
|
||||
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||
Self {
|
||||
store,
|
||||
provider_config,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_plan_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories_for_scope("user", scope_key)
|
||||
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||||
|
||||
if memories.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||||
}
|
||||
|
||||
pub(crate) async fn summarize_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
self.summarize_plan(scope_key, &plan).await.map(Some)
|
||||
}
|
||||
|
||||
async fn summarize_plan(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
|
||||
let provider = create_provider(self.provider_config.clone()).map_err(|err| {
|
||||
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||
})?;
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![
|
||||
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
|
||||
Message::user(
|
||||
serde_json::to_string_pretty(&serde_json::json!({
|
||||
"scope_key": scope_key,
|
||||
"candidates": plan,
|
||||
}))
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
],
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(1200),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let mut last_error = None;
|
||||
let mut response = None;
|
||||
|
||||
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||||
.iter()
|
||||
.copied()
|
||||
.map(Some)
|
||||
.chain(std::iter::once(None))
|
||||
.enumerate()
|
||||
{
|
||||
match provider.chat(request.clone()).await {
|
||||
Ok(success) => {
|
||||
response = Some(success);
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
let error_text = err.to_string();
|
||||
let should_retry =
|
||||
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||||
last_error = Some(error_text.clone());
|
||||
|
||||
if should_retry {
|
||||
tracing::warn!(
|
||||
scope_key = %scope_key,
|
||||
attempt = attempt + 1,
|
||||
retry_in_ms = delay_ms.unwrap_or_default(),
|
||||
error = %error_text,
|
||||
"Memory maintenance model request failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(AgentError::Other(format!(
|
||||
"memory maintenance model error: {}",
|
||||
error_text
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = response.ok_or_else(|| {
|
||||
AgentError::Other(format!(
|
||||
"memory maintenance model error: {}",
|
||||
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||||
))
|
||||
})?;
|
||||
|
||||
let raw_content = strip_json_code_fence(&response.content);
|
||||
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||||
|
||||
let output: MemoryMaintenanceModelOutput =
|
||||
serde_json::from_str(json_candidate).map_err(|err| {
|
||||
tracing::error!(
|
||||
scope_key = %scope_key,
|
||||
error = %err,
|
||||
raw_len = raw_content.len(),
|
||||
raw_preview = %preview_text(raw_content, 400),
|
||||
json_candidate_len = json_candidate.len(),
|
||||
json_candidate_preview = %preview_text(json_candidate, 400),
|
||||
"Memory maintenance JSON decode failed"
|
||||
);
|
||||
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
||||
})?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) async fn run_for_scope(
|
||||
&self,
|
||||
scope_key: &str,
|
||||
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let output = self.summarize_plan(scope_key, &plan).await?;
|
||||
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
|
||||
|
||||
Ok(Some(output))
|
||||
}
|
||||
|
||||
pub(crate) async fn run_for_all_scopes(
|
||||
&self,
|
||||
updated_since: Option<i64>,
|
||||
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||
let scope_keys = if let Some(cutoff) = updated_since {
|
||||
self.store
|
||||
.list_memory_scope_keys_updated_since("user", cutoff)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!(
|
||||
"list memory scope keys updated since error: {}",
|
||||
err
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
self.store.list_memory_scope_keys("user").map_err(|err| {
|
||||
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||||
})?
|
||||
};
|
||||
let mut results = Vec::new();
|
||||
|
||||
for scope_key in scope_keys {
|
||||
let Some(output) = self.run_for_scope(&scope_key).await? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
results.push(MemoryMaintenanceScopeResult { scope_key, output });
|
||||
}
|
||||
|
||||
let combined_markdown = combine_managed_memory_markdown(
|
||||
&results
|
||||
.iter()
|
||||
.map(|result| result.output.managed_markdown.clone())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
if !combined_markdown.is_empty() {
|
||||
upsert_managed_agent_memory_summary(&combined_markdown)?;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||
let mut plan = MemoryMaintenancePlan::default();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for memory in memories {
|
||||
let normalized_content = memory.content.trim();
|
||||
if normalized_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let dedupe_key = format!(
|
||||
"{}\u{1f}{}\u{1f}{}",
|
||||
memory.namespace.trim().to_ascii_lowercase(),
|
||||
memory.memory_key.trim().to_ascii_lowercase(),
|
||||
normalized_content
|
||||
);
|
||||
if !seen.insert(dedupe_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate = MemoryMaintenanceCandidate {
|
||||
id: memory.id.clone(),
|
||||
namespace: memory.namespace.clone(),
|
||||
key: memory.memory_key.clone(),
|
||||
content: normalized_content.to_string(),
|
||||
};
|
||||
|
||||
match memory_maintenance_category(&memory.namespace) {
|
||||
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
||||
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
||||
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
||||
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
||||
}
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
|
||||
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
||||
match namespace.trim().to_ascii_lowercase().as_str() {
|
||||
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
||||
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
||||
"patterns" | "behavior" | "habits" | "workflow" => {
|
||||
MemoryMaintenanceCategory::BehaviorPatterns
|
||||
}
|
||||
_ => MemoryMaintenanceCategory::Other,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|
||||
let normalized = error.to_ascii_lowercase();
|
||||
normalized.contains("error sending request for url")
|
||||
|| normalized.contains("504")
|
||||
|| normalized.contains("gateway timeout")
|
||||
|| normalized.contains("stream timeout")
|
||||
|| normalized.contains("timed out")
|
||||
|| normalized.contains("timeout")
|
||||
}
|
||||
|
||||
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
|
||||
let trimmed = content.trim();
|
||||
if let Some(rest) = trimmed.strip_prefix("```json") {
|
||||
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
|
||||
}
|
||||
if let Some(rest) = trimmed.strip_prefix("```") {
|
||||
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
|
||||
}
|
||||
trimmed
|
||||
}
|
||||
|
||||
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
|
||||
let mut start = None;
|
||||
let mut depth = 0usize;
|
||||
let mut in_string = false;
|
||||
let mut escaped = false;
|
||||
|
||||
for (index, ch) in content.char_indices() {
|
||||
if in_string {
|
||||
if escaped {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
match ch {
|
||||
'\\' => escaped = true,
|
||||
'"' => in_string = false,
|
||||
_ => {}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'"' => in_string = true,
|
||||
'{' => {
|
||||
if start.is_none() {
|
||||
start = Some(index);
|
||||
}
|
||||
depth += 1;
|
||||
}
|
||||
'}' => {
|
||||
if depth == 0 {
|
||||
continue;
|
||||
}
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
let start = start?;
|
||||
let end = index + ch.len_utf8();
|
||||
return Some(content[start..end].trim());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
|
||||
let normalized_chunks = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.trim())
|
||||
.filter(|chunk| !chunk.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut combined = Vec::new();
|
||||
for (index, chunk) in normalized_chunks.iter().enumerate() {
|
||||
let chunk_lines = chunk
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|line| !line.is_empty())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let is_subset_of_other =
|
||||
normalized_chunks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.any(|(other_index, other)| {
|
||||
if index == other_index {
|
||||
return false;
|
||||
}
|
||||
|
||||
let other_lines = other
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|line| !line.is_empty())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines)
|
||||
});
|
||||
|
||||
if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) {
|
||||
combined.push((*chunk).to_string());
|
||||
}
|
||||
}
|
||||
|
||||
combined.join("\n\n")
|
||||
}
|
||||
|
||||
pub(crate) fn apply_memory_maintenance_output(
|
||||
store: &SessionStore,
|
||||
scope_key: &str,
|
||||
plan: &MemoryMaintenancePlan,
|
||||
output: &MemoryMaintenanceModelOutput,
|
||||
) -> Result<(), AgentError> {
|
||||
let all_candidates = plan
|
||||
.user_facts
|
||||
.iter()
|
||||
.chain(plan.preferences.iter())
|
||||
.chain(plan.behavior_patterns.iter())
|
||||
.chain(plan.others.iter())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let candidates_by_id = all_candidates
|
||||
.iter()
|
||||
.map(|candidate| (candidate.id.as_str(), candidate))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut deleted_ids = HashSet::new();
|
||||
|
||||
for merge in &output.merges {
|
||||
if merge.source_ids.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let source_candidates = merge
|
||||
.source_ids
|
||||
.iter()
|
||||
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||||
.collect::<Vec<_>>();
|
||||
if source_candidates.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let existing_target_id = source_candidates
|
||||
.iter()
|
||||
.find(|candidate| {
|
||||
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||||
})
|
||||
.map(|candidate| candidate.id.clone());
|
||||
|
||||
store
|
||||
.put_memory(&crate::storage::MemoryUpsert {
|
||||
scope_kind: "user".to_string(),
|
||||
scope_key: scope_key.to_string(),
|
||||
namespace: merge.namespace.trim().to_string(),
|
||||
memory_key: merge.memory_key.trim().to_string(),
|
||||
content: merge.content.trim().to_string(),
|
||||
source_type: "memory_maintenance".to_string(),
|
||||
source_session_id: None,
|
||||
source_message_id: None,
|
||||
source_message_seq: None,
|
||||
source_channel_name: None,
|
||||
source_chat_id: None,
|
||||
})
|
||||
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||||
|
||||
for candidate in source_candidates {
|
||||
if existing_target_id
|
||||
.as_ref()
|
||||
.is_some_and(|target_id| target_id == &candidate.id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if deleted_ids.insert(candidate.id.clone()) {
|
||||
store
|
||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for memory_id in &output.low_value_ids {
|
||||
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||||
if deleted_ids.insert(candidate.id.clone()) {
|
||||
store
|
||||
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||
.map_err(|err| {
|
||||
AgentError::Other(format!("delete low value memory error: {}", err))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||
if content.chars().count() > max_chars {
|
||||
preview.push_str("...");
|
||||
}
|
||||
preview.replace('\n', "\\n")
|
||||
}
|
||||
@ -1,14 +1,10 @@
|
||||
pub mod execution;
|
||||
pub mod http;
|
||||
pub mod memory_maintenance;
|
||||
pub mod processor;
|
||||
pub mod prompt;
|
||||
pub mod session;
|
||||
pub mod ws;
|
||||
|
||||
use axum::{Router, routing};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundDispatcher};
|
||||
@ -18,8 +14,7 @@ use crate::config::LLMProviderConfig;
|
||||
use crate::logging;
|
||||
use crate::scheduler::Scheduler;
|
||||
use crate::skills::SkillRuntime;
|
||||
use processor::InboundProcessor;
|
||||
use session::SessionManager;
|
||||
use session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
pub struct GatewayState {
|
||||
pub config: Config,
|
||||
@ -66,17 +61,74 @@ impl GatewayState {
|
||||
|
||||
/// Start the message processing loops
|
||||
pub async fn start_message_processing(&self) {
|
||||
let bus_for_inbound = self.bus.clone();
|
||||
let bus_for_outbound = self.bus.clone();
|
||||
let inbound_processor =
|
||||
InboundProcessor::new(self.bus.clone(), self.session_manager.clone());
|
||||
tokio::spawn(inbound_processor.run());
|
||||
let session_manager = self.session_manager.clone();
|
||||
|
||||
// Spawn inbound message processor
|
||||
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
|
||||
tokio::spawn(async move {
|
||||
tracing::info!("Inbound processor started");
|
||||
loop {
|
||||
let inbound = bus_for_inbound.consume_inbound().await;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
channel = %inbound.channel,
|
||||
chat_id = %inbound.chat_id,
|
||||
sender = %inbound.sender_id,
|
||||
content = %inbound.content,
|
||||
media_count = %inbound.media.len(),
|
||||
"Processing inbound message"
|
||||
);
|
||||
if !inbound.media.is_empty() {
|
||||
for (i, m) in inbound.media.iter().enumerate() {
|
||||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process via session manager
|
||||
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
||||
bus_for_inbound.clone(),
|
||||
inbound.channel.clone(),
|
||||
inbound.chat_id.clone(),
|
||||
inbound.forwarded_metadata.clone(),
|
||||
session_manager.show_tool_results(),
|
||||
));
|
||||
match session_manager.handle_message(
|
||||
&inbound.channel,
|
||||
&inbound.sender_id,
|
||||
&inbound.chat_id,
|
||||
&inbound.content,
|
||||
inbound.media,
|
||||
Some(live_emitter),
|
||||
).await {
|
||||
Ok(outbound_messages) => {
|
||||
// Forward channel-specific metadata from inbound to outbound.
|
||||
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
|
||||
// without gateway needing channel-specific code.
|
||||
for mut outbound in outbound_messages {
|
||||
outbound.metadata.extend(inbound.forwarded_metadata.clone());
|
||||
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
||||
tracing::error!(error = %e, "Failed to publish outbound");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to handle message");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn outbound dispatcher
|
||||
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
|
||||
let channel_manager = self.channel_manager.clone();
|
||||
|
||||
for (name, channel) in channel_manager.channels().await {
|
||||
dispatcher.register_channel(&name, channel).await;
|
||||
// Register channels with dispatcher
|
||||
if let Some(channel) = channel_manager.get_channel("feishu").await {
|
||||
dispatcher.register_channel("feishu", channel).await;
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
@ -86,10 +138,7 @@ impl GatewayState {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let config = Config::load_default()?;
|
||||
let timezone = config.time.parse_timezone()?;
|
||||
|
||||
@ -103,10 +152,7 @@ pub async fn run(
|
||||
let provider_config = state.config.get_provider_config("default")?;
|
||||
|
||||
// Initialize and start channels
|
||||
state
|
||||
.channel_manager
|
||||
.init(&state.config, provider_config.clone())
|
||||
.await?;
|
||||
state.channel_manager.init(&state.config, provider_config.clone()).await?;
|
||||
state.channel_manager.start_all().await?;
|
||||
|
||||
// Start message processing (inbound processor + outbound dispatcher)
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::bus::MessageBus;
|
||||
|
||||
use super::session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct InboundProcessor {
|
||||
bus: Arc<MessageBus>,
|
||||
session_manager: SessionManager,
|
||||
}
|
||||
|
||||
impl InboundProcessor {
|
||||
pub fn new(bus: Arc<MessageBus>, session_manager: SessionManager) -> Self {
|
||||
Self {
|
||||
bus,
|
||||
session_manager,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(self) {
|
||||
tracing::info!("Inbound processor started");
|
||||
|
||||
loop {
|
||||
let inbound = self.bus.consume_inbound().await;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
tracing::debug!(
|
||||
channel = %inbound.channel,
|
||||
chat_id = %inbound.chat_id,
|
||||
sender = %inbound.sender_id,
|
||||
content = %inbound.content,
|
||||
media_count = %inbound.media.len(),
|
||||
"Processing inbound message"
|
||||
);
|
||||
if !inbound.media.is_empty() {
|
||||
for (i, media) in inbound.media.iter().enumerate() {
|
||||
tracing::debug!(media_index = i, media_type = %media.media_type, path = %media.path, "Media item");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
||||
self.bus.clone(),
|
||||
inbound.channel.clone(),
|
||||
inbound.chat_id.clone(),
|
||||
inbound.forwarded_metadata.clone(),
|
||||
self.session_manager.show_tool_results(),
|
||||
));
|
||||
|
||||
match self
|
||||
.session_manager
|
||||
.handle_message(
|
||||
&inbound.channel,
|
||||
&inbound.sender_id,
|
||||
&inbound.chat_id,
|
||||
&inbound.content,
|
||||
inbound.media,
|
||||
Some(live_emitter),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(outbound_messages) => {
|
||||
for mut outbound in outbound_messages {
|
||||
outbound.metadata.extend(inbound.forwarded_metadata.clone());
|
||||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||||
tracing::error!(error = %error, "Failed to publish outbound");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::error!(error = %error, "Failed to handle message");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,149 +0,0 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
|
||||
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
|
||||
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
|
||||
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
|
||||
|
||||
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
if !path.exists() {
|
||||
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&path)
|
||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(trimmed.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
let existing = if path.exists() {
|
||||
fs::read_to_string(&path)
|
||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
|
||||
} else {
|
||||
DEFAULT_AGENT_PROMPT.to_string()
|
||||
};
|
||||
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
|
||||
write_agent_prompt(&path, &updated)
|
||||
}
|
||||
|
||||
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
|
||||
let managed_block = render_managed_agent_memory_block(markdown_body);
|
||||
|
||||
if let (Some(start), Some(end)) = (
|
||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
|
||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
|
||||
) {
|
||||
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
|
||||
let mut updated = String::new();
|
||||
updated.push_str(existing[..start].trim_end());
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(&managed_block);
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(existing[end..].trim_start());
|
||||
return updated.trim().to_string() + "\n";
|
||||
}
|
||||
|
||||
if let Some(reply_rules_index) = existing.find("## 回复规则") {
|
||||
let mut updated = String::new();
|
||||
updated.push_str(existing[..reply_rules_index].trim_end());
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(&managed_block);
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(existing[reply_rules_index..].trim_start());
|
||||
return updated.trim().to_string() + "\n";
|
||||
}
|
||||
|
||||
let mut updated = existing.trim_end().to_string();
|
||||
if !updated.is_empty() {
|
||||
updated.push_str("\n\n");
|
||||
}
|
||||
updated.push_str(&managed_block);
|
||||
updated.push('\n');
|
||||
updated
|
||||
}
|
||||
|
||||
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
|
||||
format!(
|
||||
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
|
||||
markdown_body.trim()
|
||||
)
|
||||
}
|
||||
|
||||
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
let temp_path = path.with_extension("md.tmp");
|
||||
fs::write(&temp_path, content)
|
||||
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
|
||||
fs::rename(&temp_path, path)
|
||||
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
|
||||
let home = dirs::home_dir()
|
||||
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
||||
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
|
||||
let original =
|
||||
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
|
||||
let updated = upsert_managed_agent_memory_block(
|
||||
original,
|
||||
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
|
||||
);
|
||||
|
||||
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
|
||||
let reply_rules_pos = updated.find("## 回复规则").unwrap();
|
||||
assert!(managed_pos < reply_rules_pos);
|
||||
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
|
||||
assert!(updated.contains("用户在做AI产品"));
|
||||
assert!(updated.contains("偏好简洁表达"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_replaces_existing_block() {
|
||||
let original = format!(
|
||||
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n"
|
||||
);
|
||||
|
||||
let updated = upsert_managed_agent_memory_block(&original, "new");
|
||||
|
||||
assert!(updated.contains("new"));
|
||||
assert!(!updated.contains("old"));
|
||||
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1);
|
||||
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_trims_summary_body() {
|
||||
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n");
|
||||
|
||||
assert!(updated.contains("\n\nsummary\n"));
|
||||
assert!(!updated.contains("\n\nsummary\n\n\n"));
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,18 +1,15 @@
|
||||
use super::{
|
||||
GatewayState,
|
||||
session::{Session, handle_in_chat_command, schedule_background_history_compaction},
|
||||
};
|
||||
use crate::agent::EmittedMessageHandler;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::message::{ToolMessageState, format_tool_call_content};
|
||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use crate::agent::EmittedMessageHandler;
|
||||
use crate::bus::message::{format_tool_call_content, ToolMessageState};
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}};
|
||||
|
||||
struct WsToolCallEmitter {
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
@ -123,9 +120,7 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
&runtime_session_id,
|
||||
&mut current_session_id,
|
||||
inbound,
|
||||
)
|
||||
.await
|
||||
{
|
||||
).await {
|
||||
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||||
let _ = session
|
||||
.lock()
|
||||
@ -187,14 +182,17 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
});
|
||||
}
|
||||
|
||||
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
||||
outbound.extend(tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| WsOutbound::ToolCall {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||
role: message.role.clone(),
|
||||
}));
|
||||
})
|
||||
);
|
||||
outbound
|
||||
} else {
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
@ -204,11 +202,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
}]
|
||||
}
|
||||
}
|
||||
"tool" => match message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed)
|
||||
{
|
||||
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
|
||||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
||||
id: message.id.clone(),
|
||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
@ -236,10 +230,7 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage
|
||||
|
||||
show_tool_results
|
||||
|| matches!(
|
||||
message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed),
|
||||
message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed),
|
||||
ToolMessageState::PendingUserAction
|
||||
)
|
||||
}
|
||||
@ -252,12 +243,7 @@ async fn handle_inbound(
|
||||
inbound: WsInbound,
|
||||
) -> Result<(), crate::agent::AgentError> {
|
||||
match inbound {
|
||||
WsInbound::UserInput {
|
||||
content,
|
||||
chat_id,
|
||||
sender_id,
|
||||
..
|
||||
} => {
|
||||
WsInbound::UserInput { content, chat_id, sender_id, .. } => {
|
||||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||||
let (history, agent, user_tx) = {
|
||||
@ -266,9 +252,7 @@ async fn handle_inbound(
|
||||
session_guard.ensure_persistent_session(&chat_id)?;
|
||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||
|
||||
if let Some(command_response) =
|
||||
handle_in_chat_command(&mut session_guard, &chat_id, &content)?
|
||||
{
|
||||
if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? {
|
||||
let _ = session_guard
|
||||
.send(WsOutbound::AssistantResponse {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
@ -302,17 +286,13 @@ async fn handle_inbound(
|
||||
match agent.process(history).await {
|
||||
Ok(result) => {
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard
|
||||
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||
for outbound in result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| {
|
||||
!message.is_assistant_tool_call_message()
|
||||
&& should_display_message_to_user(
|
||||
state.config.gateway.show_tool_results,
|
||||
message,
|
||||
)
|
||||
&& should_display_message_to_user(state.config.gateway.show_tool_results, message)
|
||||
})
|
||||
.flat_map(ws_outbound_from_chat_message)
|
||||
{
|
||||
@ -321,10 +301,7 @@ async fn handle_inbound(
|
||||
|
||||
drop(session_guard);
|
||||
|
||||
if let Err(error) =
|
||||
schedule_background_history_compaction(session.clone(), chat_id.clone())
|
||||
.await
|
||||
{
|
||||
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()).await {
|
||||
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
|
||||
}
|
||||
}
|
||||
@ -341,19 +318,16 @@ async fn handle_inbound(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::ClearHistory {
|
||||
session_id,
|
||||
chat_id,
|
||||
} => {
|
||||
let target = session_id
|
||||
.or(chat_id)
|
||||
.unwrap_or_else(|| current_session_id.clone());
|
||||
WsInbound::ClearHistory { session_id, chat_id } => {
|
||||
let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
|
||||
state.session_manager.clear_session_messages(&target)?;
|
||||
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.remove_history(&target);
|
||||
let _ = session_guard
|
||||
.send(WsOutbound::HistoryCleared { session_id: target })
|
||||
.send(WsOutbound::HistoryCleared {
|
||||
session_id: target,
|
||||
})
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
@ -478,15 +452,17 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::agent::EmittedMessageHandler;
|
||||
use super::{
|
||||
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
|
||||
WsToolCallEmitter,
|
||||
resolve_ws_sender_id,
|
||||
should_display_message_to_user,
|
||||
ws_outbound_from_chat_message,
|
||||
};
|
||||
use crate::agent::EmittedMessageHandler;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::message::ToolMessageState;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::providers::ToolCall;
|
||||
use crate::protocol::WsOutbound;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@ -505,17 +481,11 @@ mod tests {
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
match &outbound[0] {
|
||||
WsOutbound::ToolCall {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||
assert_eq!(content, "### calculator\n- expression: 1 + 1");
|
||||
}
|
||||
other => panic!("unexpected outbound variant: {:?}", other),
|
||||
}
|
||||
@ -581,14 +551,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
||||
assert_eq!(
|
||||
resolve_ws_sender_id(Some("user-42"), "runtime-1"),
|
||||
"user-42"
|
||||
);
|
||||
assert_eq!(
|
||||
resolve_ws_sender_id(Some(" user-42 "), "runtime-1"),
|
||||
"user-42"
|
||||
);
|
||||
assert_eq!(resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42");
|
||||
assert_eq!(resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -609,10 +573,8 @@ mod tests {
|
||||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
|
||||
assert!(tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
18
src/lib.rs
18
src/lib.rs
@ -1,16 +1,16 @@
|
||||
pub mod agent;
|
||||
pub mod bus;
|
||||
pub mod channels;
|
||||
pub mod cli;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod text;
|
||||
pub mod providers;
|
||||
pub mod bus;
|
||||
pub mod cli;
|
||||
pub mod agent;
|
||||
pub mod gateway;
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
pub mod channels;
|
||||
pub mod logging;
|
||||
pub mod observability;
|
||||
pub mod protocol;
|
||||
pub mod providers;
|
||||
pub mod scheduler;
|
||||
pub mod skills;
|
||||
pub mod storage;
|
||||
pub mod text;
|
||||
pub mod tools;
|
||||
pub mod skills;
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
use std::path::PathBuf;
|
||||
use chrono::Utc;
|
||||
use chrono_tz::Tz;
|
||||
use std::path::PathBuf;
|
||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||
use tracing_subscriber::{
|
||||
EnvFilter, fmt, fmt::time::FormatTime, layer::SubscriberExt, util::SubscriberInitExt,
|
||||
fmt,
|
||||
fmt::time::FormatTime,
|
||||
layer::SubscriberExt,
|
||||
util::SubscriberInitExt,
|
||||
EnvFilter,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
@ -12,17 +16,8 @@ struct ConfiguredTimestamp {
|
||||
}
|
||||
|
||||
impl FormatTime for ConfiguredTimestamp {
|
||||
fn format_time(
|
||||
&self,
|
||||
writer: &mut tracing_subscriber::fmt::format::Writer<'_>,
|
||||
) -> std::fmt::Result {
|
||||
write!(
|
||||
writer,
|
||||
"{}",
|
||||
Utc::now()
|
||||
.with_timezone(&self.timezone)
|
||||
.to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
|
||||
)
|
||||
fn format_time(&self, writer: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result {
|
||||
write!(writer, "{}", Utc::now().with_timezone(&self.timezone).to_rfc3339_opts(chrono::SecondsFormat::Millis, true))
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,19 +41,20 @@ pub fn init_logging(timezone: Tz) {
|
||||
// Create log directory if it doesn't exist
|
||||
if !log_dir.exists() {
|
||||
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||
eprintln!(
|
||||
"Warning: Failed to create log directory {}: {}",
|
||||
log_dir.display(),
|
||||
e
|
||||
);
|
||||
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create file appender with daily rotation
|
||||
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
|
||||
let file_appender = RollingFileAppender::new(
|
||||
Rotation::DAILY,
|
||||
&log_dir,
|
||||
"picobot.log",
|
||||
);
|
||||
|
||||
// Build subscriber with both console and file output
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let file_layer = fmt::layer()
|
||||
.with_writer(file_appender)
|
||||
@ -84,7 +80,8 @@ pub fn init_logging(timezone: Tz) {
|
||||
|
||||
/// Initialize logging without file output (console only)
|
||||
pub fn init_logging_console_only(timezone: Tz) {
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let console_layer = fmt::layer()
|
||||
.with_timer(ConfiguredTimestamp { timezone })
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use clap::{CommandFactory, Parser};
|
||||
use clap::{Parser, CommandFactory};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "picobot")]
|
||||
|
||||
@ -26,7 +26,10 @@ pub enum ObserverEvent {
|
||||
success: bool,
|
||||
},
|
||||
/// Emitted when the agent starts processing.
|
||||
AgentStart { provider: String, model: String },
|
||||
AgentStart {
|
||||
provider: String,
|
||||
model: String,
|
||||
},
|
||||
/// Emitted when the agent finishes processing.
|
||||
AgentEnd {
|
||||
provider: String,
|
||||
@ -113,11 +116,7 @@ impl ToolExecutionOutcome {
|
||||
}
|
||||
|
||||
/// Create a failed outcome with duration.
|
||||
pub fn failure_with_duration(
|
||||
output: String,
|
||||
error_reason: Option<String>,
|
||||
duration: Duration,
|
||||
) -> Self {
|
||||
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
||||
Self {
|
||||
output,
|
||||
success: false,
|
||||
|
||||
@ -43,7 +43,9 @@ pub enum WsInbound {
|
||||
include_archived: bool,
|
||||
},
|
||||
#[serde(rename = "load_session")]
|
||||
LoadSession { session_id: String },
|
||||
LoadSession {
|
||||
session_id: String,
|
||||
},
|
||||
#[serde(rename = "rename_session")]
|
||||
RenameSession {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
@ -68,11 +70,7 @@ pub enum WsInbound {
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsOutbound {
|
||||
#[serde(rename = "assistant_response")]
|
||||
AssistantResponse {
|
||||
id: String,
|
||||
content: String,
|
||||
role: String,
|
||||
},
|
||||
AssistantResponse { id: String, content: String, role: String },
|
||||
#[serde(rename = "tool_call")]
|
||||
ToolCall {
|
||||
id: String,
|
||||
|
||||
@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::traits::Usage;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use super::traits::Usage;
|
||||
|
||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||
let mut details = vec![error.to_string()];
|
||||
@ -20,10 +20,7 @@ fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||
details.join("\ncaused by: ")
|
||||
}
|
||||
|
||||
fn serialize_content_blocks<S>(
|
||||
blocks: &[serde_json::Value],
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
@ -31,15 +28,14 @@ where
|
||||
}
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
blocks.iter().map(|b| match b {
|
||||
ContentBlock::Text { text } => {
|
||||
serde_json::json!({ "type": "text", "text": text })
|
||||
}
|
||||
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
||||
})
|
||||
.collect()
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
convert_image_url_to_anthropic(&image_url.url)
|
||||
}
|
||||
}).collect()
|
||||
}
|
||||
|
||||
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||
@ -151,13 +147,9 @@ struct AnthropicResponse {
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum AnthropicContent {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
Text { text: String },
|
||||
#[allow(dead_code)]
|
||||
Thinking {
|
||||
thinking: String,
|
||||
},
|
||||
Thinking { thinking: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod traits;
|
||||
pub mod openai;
|
||||
pub mod anthropic;
|
||||
|
||||
pub use self::anthropic::AnthropicProvider;
|
||||
pub use self::openai::OpenAIProvider;
|
||||
pub use self::anthropic::AnthropicProvider;
|
||||
|
||||
use crate::config::LLMProviderConfig;
|
||||
pub use traits::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
|
||||
ToolFunction, Usage,
|
||||
};
|
||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
||||
|
||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||
match config.provider_type.as_str() {
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value, json};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::traits::Usage;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use super::traits::Usage;
|
||||
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
|
||||
"tool_call_arguments_json",
|
||||
"mock_response_content",
|
||||
];
|
||||
|
||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||
let mut details = vec![error.to_string()];
|
||||
@ -29,17 +32,12 @@ fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
return Value::String(text.clone());
|
||||
}
|
||||
}
|
||||
Value::Array(
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| match b {
|
||||
Value::Array(blocks.iter().map(|b| match b {
|
||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||
ContentBlock::ImageUrl { image_url } => {
|
||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub struct OpenAIProvider {
|
||||
@ -124,9 +122,7 @@ impl OpenAIProvider {
|
||||
|
||||
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
|
||||
self.model_extra.iter().filter(|(key, _)| {
|
||||
!INTERNAL_MODEL_EXTRA_KEYS
|
||||
.iter()
|
||||
.any(|internal| internal == &key.as_str())
|
||||
!INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str())
|
||||
})
|
||||
}
|
||||
|
||||
@ -269,11 +265,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||
for (j, item) in content.iter().enumerate() {
|
||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||||
if let Some(url_str) = item
|
||||
.get("image_url")
|
||||
.and_then(|u| u.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
||||
let prefix: String = url_str.chars().take(20).collect();
|
||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||
}
|
||||
@ -427,10 +419,7 @@ mod tests {
|
||||
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||
assert_eq!(tool_calls[0]["type"], "function");
|
||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||
assert_eq!(
|
||||
tool_calls[0]["function"]["arguments"],
|
||||
"{\"expression\":\"1+1\"}"
|
||||
);
|
||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -444,7 +433,10 @@ mod tests {
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::from([("tool_call_arguments_json".to_string(), Value::Bool(true))]),
|
||||
HashMap::from([(
|
||||
"tool_call_arguments_json".to_string(),
|
||||
Value::Bool(true),
|
||||
)]),
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
@ -469,10 +461,7 @@ mod tests {
|
||||
let messages = body["messages"].as_array().unwrap();
|
||||
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tool_calls[0]["function"]["arguments"],
|
||||
json!({"expression": "1+1"})
|
||||
);
|
||||
assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
|
||||
assert!(body.get("tool_call_arguments_json").is_none());
|
||||
}
|
||||
|
||||
@ -512,10 +501,7 @@ mod tests {
|
||||
let messages = body["messages"].as_array().unwrap();
|
||||
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
tool_calls[0]["function"]["arguments"],
|
||||
"{\"expression\":\"1+1\"}"
|
||||
);
|
||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -531,10 +517,7 @@ mod tests {
|
||||
None,
|
||||
HashMap::from([
|
||||
("tool_call_arguments_json".to_string(), Value::Bool(true)),
|
||||
(
|
||||
"mock_response_content".to_string(),
|
||||
Value::String("stub".to_string()),
|
||||
),
|
||||
("mock_response_content".to_string(), Value::String("stub".to_string())),
|
||||
("parallel_tool_calls".to_string(), Value::Bool(true)),
|
||||
]),
|
||||
);
|
||||
@ -607,10 +590,7 @@ mod tests {
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
response.choices[0].message.reasoning_content.as_deref(),
|
||||
Some("hidden reasoning")
|
||||
);
|
||||
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::bus::message::ContentBlock;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::bus::message::ContentBlock;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
@ -61,11 +61,7 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(
|
||||
tool_call_id: impl Into<String>,
|
||||
tool_name: impl Into<String>,
|
||||
content: impl Into<String>,
|
||||
) -> Self {
|
||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "tool".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
|
||||
@ -8,11 +8,11 @@ use tokio::sync::watch;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::config::{
|
||||
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
|
||||
SchedulerMisfirePolicy, SchedulerSchedule,
|
||||
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerMisfirePolicy,
|
||||
SchedulerSchedule,
|
||||
};
|
||||
use crate::gateway::session::ScheduledAgentTaskOptions;
|
||||
use crate::gateway::session::SessionManager;
|
||||
use crate::gateway::session::ScheduledAgentTaskOptions;
|
||||
use crate::storage::{
|
||||
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
|
||||
};
|
||||
@ -76,11 +76,8 @@ impl Scheduler {
|
||||
|
||||
fn sync_config_jobs(&self) -> anyhow::Result<()> {
|
||||
let now = Utc::now();
|
||||
for job in self.config.effective_jobs(&crate::config::TimeConfig {
|
||||
timezone: self.timezone.name().to_string(),
|
||||
}) {
|
||||
let runtime =
|
||||
RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
|
||||
for job in self.config.effective_jobs(&crate::config::TimeConfig { timezone: self.timezone.name().to_string() }) {
|
||||
let runtime = RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
|
||||
self.store.upsert_scheduler_job(&runtime.to_upsert())?;
|
||||
}
|
||||
Ok(())
|
||||
@ -91,9 +88,7 @@ impl Scheduler {
|
||||
let jobs = self.store.list_scheduler_jobs(true)?;
|
||||
|
||||
for record in jobs {
|
||||
let Some(mut job) =
|
||||
RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)?
|
||||
else {
|
||||
let Some(mut job) = RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
@ -183,12 +178,8 @@ impl Scheduler {
|
||||
}
|
||||
SchedulerJobKind::SilentAgentTask => {
|
||||
let execution_chat_id = resolve_execution_chat_id(job)?;
|
||||
if let Err(error) =
|
||||
execute_agent_task(&self.session_manager, job, &execution_chat_id).await
|
||||
{
|
||||
if let Err(notify_error) =
|
||||
self.notify_silent_agent_task_failure(job, &error).await
|
||||
{
|
||||
if let Err(error) = execute_agent_task(&self.session_manager, job, &execution_chat_id).await {
|
||||
if let Err(notify_error) = self.notify_silent_agent_task_failure(job, &error).await {
|
||||
tracing::error!(
|
||||
job_id = %job.id,
|
||||
error = %notify_error,
|
||||
@ -217,13 +208,10 @@ impl Scheduler {
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
|
||||
metadata.insert(
|
||||
"scheduler_job_kind".to_string(),
|
||||
"silent_agent_task".to_string(),
|
||||
);
|
||||
metadata.insert("scheduler_job_kind".to_string(), "silent_agent_task".to_string());
|
||||
|
||||
self.bus
|
||||
.publish_outbound(OutboundMessage::error_notification(
|
||||
.publish_outbound(OutboundMessage::assistant(
|
||||
channel,
|
||||
chat_id,
|
||||
format!(
|
||||
@ -316,24 +304,14 @@ impl RuntimeJob {
|
||||
}
|
||||
};
|
||||
|
||||
let schedule = deserialize_schedule(
|
||||
&record.schedule,
|
||||
record.interval_secs,
|
||||
record.startup_delay_secs,
|
||||
)?;
|
||||
let schedule = deserialize_schedule(&record.schedule, record.interval_secs, record.startup_delay_secs)?;
|
||||
let now = Utc::now();
|
||||
let next_fire_at = match (record.enabled, record.state.clone(), record.next_fire_at) {
|
||||
(false, _, _) => None,
|
||||
(_, SchedulerJobState::Paused, _) => None,
|
||||
(_, SchedulerJobState::Completed, _) => None,
|
||||
(_, _, some_next) if some_next.is_some() => some_next,
|
||||
_ => compute_initial_next_fire_at(
|
||||
&schedule,
|
||||
now,
|
||||
record.last_fired_at,
|
||||
misfire_policy,
|
||||
timezone,
|
||||
)?,
|
||||
_ => compute_initial_next_fire_at(&schedule, now, record.last_fired_at, misfire_policy, timezone)?,
|
||||
};
|
||||
|
||||
Ok(Some(Self {
|
||||
@ -360,10 +338,7 @@ impl RuntimeJob {
|
||||
fn is_due(&self, now: DateTime<Utc>) -> bool {
|
||||
self.enabled
|
||||
&& self.state == SchedulerJobState::Scheduled
|
||||
&& self
|
||||
.next_fire_at
|
||||
.map(|value| value <= now.timestamp_millis())
|
||||
.unwrap_or(false)
|
||||
&& self.next_fire_at.map(|value| value <= now.timestamp_millis()).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn after_execution(
|
||||
@ -396,8 +371,7 @@ impl RuntimeJob {
|
||||
let reference_ms = self.next_fire_at.or(self.last_fired_at);
|
||||
self.state = SchedulerJobState::Scheduled;
|
||||
self.completed_at = None;
|
||||
self.next_fire_at =
|
||||
compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?;
|
||||
self.next_fire_at = compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -410,8 +384,7 @@ impl RuntimeJob {
|
||||
SchedulerJobKind::AgentTask => "agent_task".to_string(),
|
||||
SchedulerJobKind::SilentAgentTask => "silent_agent_task".to_string(),
|
||||
},
|
||||
schedule: serde_json::to_value(&self.schedule)
|
||||
.unwrap_or_else(|_| serde_json::json!({})),
|
||||
schedule: serde_json::to_value(&self.schedule).unwrap_or_else(|_| serde_json::json!({})),
|
||||
interval_secs: self.interval_secs,
|
||||
startup_delay_secs: self.startup_delay_secs,
|
||||
target: serde_json::to_value(&self.target).unwrap_or_else(|_| serde_json::json!({})),
|
||||
@ -457,36 +430,21 @@ fn compute_initial_next_fire_at(
|
||||
timezone: Tz,
|
||||
) -> anyhow::Result<Option<i64>> {
|
||||
match last_fired_at {
|
||||
Some(last_fired_at) => {
|
||||
compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone)
|
||||
}
|
||||
Some(last_fired_at) => compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone),
|
||||
None => match schedule {
|
||||
SchedulerSchedule::Delay { seconds } => Ok(Some(
|
||||
(now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis(),
|
||||
)),
|
||||
SchedulerSchedule::Delay { seconds } => Ok(Some((now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis())),
|
||||
SchedulerSchedule::Interval {
|
||||
seconds,
|
||||
startup_delay_secs,
|
||||
} => {
|
||||
let delay = if *startup_delay_secs > 0 {
|
||||
*startup_delay_secs
|
||||
} else {
|
||||
*seconds
|
||||
};
|
||||
Ok(Some(
|
||||
(now + ChronoDuration::seconds(delay as i64)).timestamp_millis(),
|
||||
))
|
||||
}
|
||||
SchedulerSchedule::At { timestamp } => {
|
||||
Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis()))
|
||||
let delay = if *startup_delay_secs > 0 { *startup_delay_secs } else { *seconds };
|
||||
Ok(Some((now + ChronoDuration::seconds(delay as i64)).timestamp_millis()))
|
||||
}
|
||||
SchedulerSchedule::At { timestamp } => Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis())),
|
||||
SchedulerSchedule::Cron { expression } => {
|
||||
let schedule = parse_scheduler_cron(expression)?;
|
||||
let local_now = now.with_timezone(&timezone);
|
||||
Ok(schedule
|
||||
.after(&local_now)
|
||||
.next()
|
||||
.map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
||||
Ok(schedule.after(&local_now).next().map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -525,10 +483,7 @@ fn compute_next_fire_at(
|
||||
.map(|value| value.with_timezone(&timezone))
|
||||
.unwrap_or_else(|| now.with_timezone(&timezone)),
|
||||
};
|
||||
Ok(schedule
|
||||
.after(&anchor)
|
||||
.next()
|
||||
.map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
||||
Ok(schedule.after(&anchor).next().map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -570,14 +525,12 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
|
||||
.payload
|
||||
.get("content")
|
||||
.and_then(|value| value.as_str())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("outbound scheduler job payload.content must be a string")
|
||||
})?;
|
||||
.ok_or_else(|| anyhow::anyhow!("outbound scheduler job payload.content must be a string"))?;
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
|
||||
|
||||
Ok(OutboundMessage::scheduler_notification(
|
||||
Ok(OutboundMessage::assistant(
|
||||
channel,
|
||||
chat_id,
|
||||
content.to_string(),
|
||||
@ -586,10 +539,7 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
|
||||
))
|
||||
}
|
||||
|
||||
async fn execute_internal_event(
|
||||
session_manager: &SessionManager,
|
||||
job: &RuntimeJob,
|
||||
) -> anyhow::Result<()> {
|
||||
async fn execute_internal_event(session_manager: &SessionManager, job: &RuntimeJob) -> anyhow::Result<()> {
|
||||
let event = job
|
||||
.payload
|
||||
.get("event")
|
||||
@ -649,10 +599,7 @@ async fn execute_agent_task(
|
||||
.map_err(|error| anyhow::anyhow!(error.to_string()))
|
||||
}
|
||||
|
||||
fn required_notification_chat_id<'a>(
|
||||
job: &'a RuntimeJob,
|
||||
kind_name: &str,
|
||||
) -> anyhow::Result<&'a str> {
|
||||
fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> anyhow::Result<&'a str> {
|
||||
job.target
|
||||
.chat_id
|
||||
.as_deref()
|
||||
@ -661,9 +608,7 @@ fn required_notification_chat_id<'a>(
|
||||
|
||||
fn resolve_execution_chat_id(job: &RuntimeJob) -> anyhow::Result<String> {
|
||||
match job.kind {
|
||||
SchedulerJobKind::AgentTask => {
|
||||
Ok(required_notification_chat_id(job, "agent_task")?.to_string())
|
||||
}
|
||||
SchedulerJobKind::AgentTask => Ok(required_notification_chat_id(job, "agent_task")?.to_string()),
|
||||
SchedulerJobKind::SilentAgentTask => Ok(job
|
||||
.target
|
||||
.session_chat_id
|
||||
@ -688,9 +633,7 @@ fn summarize_scheduler_error(error: &anyhow::Error) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_scheduled_agent_task_options(
|
||||
job: &RuntimeJob,
|
||||
) -> anyhow::Result<ScheduledAgentTaskOptions> {
|
||||
fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<ScheduledAgentTaskOptions> {
|
||||
let sender_id = job
|
||||
.payload
|
||||
.get("sender_id")
|
||||
@ -722,9 +665,7 @@ fn parse_scheduled_agent_task_options(
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_metadata_map(
|
||||
value: Option<&serde_json::Value>,
|
||||
) -> anyhow::Result<HashMap<String, String>> {
|
||||
fn parse_metadata_map(value: Option<&serde_json::Value>) -> anyhow::Result<HashMap<String, String>> {
|
||||
let Some(value) = value else {
|
||||
return Ok(HashMap::new());
|
||||
};
|
||||
@ -744,7 +685,7 @@ fn parse_metadata_map(
|
||||
return Err(anyhow::anyhow!(
|
||||
"agent_task payload.metadata field '{}' must be a string, number, bool, or null",
|
||||
key
|
||||
));
|
||||
))
|
||||
}
|
||||
};
|
||||
metadata.insert(key.clone(), stringified);
|
||||
@ -789,19 +730,12 @@ mod agent_task_tests {
|
||||
updated_at: 1_700_000_000_000,
|
||||
};
|
||||
|
||||
let job = RuntimeJob::from_record(
|
||||
&record,
|
||||
SchedulerMisfirePolicy::Skip,
|
||||
chrono_tz::Asia::Shanghai,
|
||||
)
|
||||
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
|
||||
assert_eq!(
|
||||
job.payload.get("prompt").and_then(|value| value.as_str()),
|
||||
Some("请总结今天待办")
|
||||
);
|
||||
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -837,19 +771,12 @@ mod agent_task_tests {
|
||||
updated_at: 1_700_000_000_000,
|
||||
};
|
||||
|
||||
let job = RuntimeJob::from_record(
|
||||
&record,
|
||||
SchedulerMisfirePolicy::Skip,
|
||||
chrono_tz::Asia::Shanghai,
|
||||
)
|
||||
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(job.kind, SchedulerJobKind::SilentAgentTask);
|
||||
assert_eq!(
|
||||
job.target.session_chat_id.as_deref(),
|
||||
Some("scheduler/agent.daily_summary.background")
|
||||
);
|
||||
assert_eq!(job.target.session_chat_id.as_deref(), Some("scheduler/agent.daily_summary.background"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -898,18 +825,9 @@ mod agent_task_tests {
|
||||
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
||||
assert!(options.fresh_session);
|
||||
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
||||
assert_eq!(
|
||||
options.metadata.get("job_type").map(String::as_str),
|
||||
Some("daily_summary")
|
||||
);
|
||||
assert_eq!(
|
||||
options.metadata.get("priority").map(String::as_str),
|
||||
Some("1")
|
||||
);
|
||||
assert_eq!(
|
||||
options.metadata.get("urgent").map(String::as_str),
|
||||
Some("false")
|
||||
);
|
||||
assert_eq!(options.metadata.get("job_type").map(String::as_str), Some("daily_summary"));
|
||||
assert_eq!(options.metadata.get("priority").map(String::as_str), Some("1"));
|
||||
assert_eq!(options.metadata.get("urgent").map(String::as_str), Some("false"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -962,12 +880,12 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use crate::bus::MessageBus;
|
||||
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
|
||||
use crate::gateway::session::SessionManager;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn test_provider_config() -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
@ -980,7 +898,6 @@ mod tests {
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: None,
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 4,
|
||||
tool_result_max_chars: 20_000,
|
||||
@ -1004,10 +921,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn runtime_job_skip_policy_advances_from_now() {
|
||||
let now = Utc
|
||||
.timestamp_millis_opt(1_700_000_000_000)
|
||||
.single()
|
||||
.unwrap();
|
||||
let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap();
|
||||
let next = compute_next_fire_at(
|
||||
&SchedulerSchedule::Interval {
|
||||
seconds: 60,
|
||||
@ -1026,10 +940,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn runtime_job_catch_up_policy_moves_past_now() {
|
||||
let now = Utc
|
||||
.timestamp_millis_opt(1_700_000_000_000)
|
||||
.single()
|
||||
.unwrap();
|
||||
let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap();
|
||||
let next = compute_next_fire_at(
|
||||
&SchedulerSchedule::Interval {
|
||||
seconds: 60,
|
||||
@ -1078,21 +989,14 @@ mod tests {
|
||||
updated_at: 1_700_000_000_000,
|
||||
};
|
||||
|
||||
let job = RuntimeJob::from_record(
|
||||
&record,
|
||||
SchedulerMisfirePolicy::Skip,
|
||||
chrono_tz::Asia::Shanghai,
|
||||
)
|
||||
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
job.schedule,
|
||||
SchedulerSchedule::Interval {
|
||||
assert_eq!(job.schedule, SchedulerSchedule::Interval {
|
||||
seconds: 120,
|
||||
startup_delay_secs: 10,
|
||||
}
|
||||
);
|
||||
});
|
||||
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
|
||||
}
|
||||
|
||||
@ -1146,10 +1050,7 @@ mod tests {
|
||||
|
||||
scheduler.process_tick().await.unwrap();
|
||||
|
||||
let saved = store
|
||||
.get_scheduler_job("massage_reminder")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap();
|
||||
assert!(saved.next_fire_at.is_some());
|
||||
assert_eq!(saved.run_count, 0);
|
||||
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
||||
@ -1179,10 +1080,7 @@ mod tests {
|
||||
assert_eq!(saved.kind, "internal_event");
|
||||
assert!(saved.enabled);
|
||||
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
||||
assert_eq!(
|
||||
saved.payload.get("event").and_then(|value| value.as_str()),
|
||||
Some("memory_maintenance")
|
||||
);
|
||||
assert_eq!(saved.payload.get("event").and_then(|value| value.as_str()), Some("memory_maintenance"));
|
||||
assert_eq!(
|
||||
saved.schedule,
|
||||
serde_json::json!({
|
||||
@ -1190,13 +1088,7 @@ mod tests {
|
||||
"expression": "0 */4 * * *"
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
saved
|
||||
.payload
|
||||
.get("local_time")
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("every_4_hours")
|
||||
);
|
||||
assert_eq!(saved.payload.get("local_time").and_then(|value| value.as_str()), Some("every_4_hours"));
|
||||
assert!(saved.next_fire_at.is_some());
|
||||
}
|
||||
|
||||
@ -1263,10 +1155,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn cron_schedule_uses_configured_timezone() {
|
||||
let now = Utc
|
||||
.with_ymd_and_hms(2026, 4, 23, 18, 0, 0)
|
||||
.single()
|
||||
.unwrap();
|
||||
let now = Utc.with_ymd_and_hms(2026, 4, 23, 18, 0, 0).single().unwrap();
|
||||
let next = compute_next_fire_at(
|
||||
&SchedulerSchedule::Cron {
|
||||
expression: "0 3 * * *".to_string(),
|
||||
@ -1280,11 +1169,6 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let next_utc = ts_millis_to_utc(next).unwrap();
|
||||
assert_eq!(
|
||||
next_utc,
|
||||
Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0)
|
||||
.single()
|
||||
.unwrap()
|
||||
);
|
||||
assert_eq!(next_utc, Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0).single().unwrap());
|
||||
}
|
||||
}
|
||||
@ -89,10 +89,7 @@ impl SkillRuntime {
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.is_empty()
|
||||
self.catalog.read().expect("skills rwlock poisoned").is_empty()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
@ -100,53 +97,31 @@ impl SkillRuntime {
|
||||
}
|
||||
|
||||
pub fn system_index_prompt(&self) -> Option<String> {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.system_index_prompt()
|
||||
self.catalog.read().expect("skills rwlock poisoned").system_index_prompt()
|
||||
}
|
||||
|
||||
pub fn discovery_event_payload(&self) -> serde_json::Value {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.discovery_event_payload()
|
||||
self.catalog.read().expect("skills rwlock poisoned").discovery_event_payload()
|
||||
}
|
||||
|
||||
pub fn offered_event_payload(&self) -> serde_json::Value {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.offered_event_payload()
|
||||
self.catalog.read().expect("skills rwlock poisoned").offered_event_payload()
|
||||
}
|
||||
|
||||
pub fn skill_tool_definition(&self) -> Option<Tool> {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.skill_tool_definition()
|
||||
self.catalog.read().expect("skills rwlock poisoned").skill_tool_definition()
|
||||
}
|
||||
|
||||
pub fn activation_payload(&self, name: &str) -> Result<String, String> {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.activation_payload(name)
|
||||
self.catalog.read().expect("skills rwlock poisoned").activation_payload(name)
|
||||
}
|
||||
|
||||
pub fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String> {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.activation_event_payload(name)
|
||||
self.catalog.read().expect("skills rwlock poisoned").activation_event_payload(name)
|
||||
}
|
||||
|
||||
pub fn list_skills(&self) -> Vec<Skill> {
|
||||
self.catalog
|
||||
.read()
|
||||
.expect("skills rwlock poisoned")
|
||||
.skills
|
||||
.clone()
|
||||
self.catalog.read().expect("skills rwlock poisoned").skills.clone()
|
||||
}
|
||||
|
||||
pub fn get_skill(&self, name: &str) -> Option<Skill> {
|
||||
@ -168,11 +143,7 @@ impl SkillRuntime {
|
||||
validate_skill_name(name)?;
|
||||
let path = skill_file_path(scope, name)?;
|
||||
if path.exists() {
|
||||
return Err(format!(
|
||||
"skill '{}' already exists at {}",
|
||||
name,
|
||||
path.display()
|
||||
));
|
||||
return Err(format!("skill '{}' already exists at {}", name, path.display()));
|
||||
}
|
||||
|
||||
write_skill_file(&path, name, description, body)?;
|
||||
@ -209,20 +180,14 @@ impl SkillRuntime {
|
||||
Ok(skill)
|
||||
}
|
||||
|
||||
pub fn delete_skill(
|
||||
&self,
|
||||
scope: SkillScope,
|
||||
name: &str,
|
||||
reload: bool,
|
||||
) -> Result<PathBuf, String> {
|
||||
pub fn delete_skill(&self, scope: SkillScope, name: &str, reload: bool) -> Result<PathBuf, String> {
|
||||
validate_skill_name(name)?;
|
||||
let dir = skill_dir_path(scope, name)?;
|
||||
if !dir.exists() {
|
||||
return Err(format!("skill '{}' not found at {}", name, dir.display()));
|
||||
}
|
||||
|
||||
fs::remove_dir_all(&dir)
|
||||
.map_err(|err| format!("failed to delete skill directory: {}", err))?;
|
||||
fs::remove_dir_all(&dir).map_err(|err| format!("failed to delete skill directory: {}", err))?;
|
||||
if reload {
|
||||
let _ = self.reload()?;
|
||||
}
|
||||
@ -474,8 +439,7 @@ fn validate_skill_name(name: &str) -> Result<(), String> {
|
||||
}
|
||||
|
||||
pub fn project_skills_root() -> Result<PathBuf, String> {
|
||||
let cwd =
|
||||
std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?;
|
||||
let cwd = std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?;
|
||||
Ok(cwd.join(".picobot").join("skills"))
|
||||
}
|
||||
|
||||
@ -502,9 +466,7 @@ fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
|
||||
|
||||
fn root_for_scope(scope: SkillScope) -> Result<PathBuf, String> {
|
||||
match scope {
|
||||
SkillScope::User => {
|
||||
user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string())
|
||||
}
|
||||
SkillScope::User => user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string()),
|
||||
SkillScope::Project => project_skills_root(),
|
||||
}
|
||||
}
|
||||
@ -546,8 +508,7 @@ fn render_skill_file(name: &str, description: &str, body: &str) -> Result<String
|
||||
fn write_skill_file(path: &Path, name: &str, description: &str, body: &str) -> Result<(), String> {
|
||||
let content = render_skill_file(name, description, body)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| format!("failed to create skill directory: {}", err))?;
|
||||
fs::create_dir_all(parent).map_err(|err| format!("failed to create skill directory: {}", err))?;
|
||||
}
|
||||
fs::write(path, content).map_err(|err| format!("failed to write skill file: {}", err))
|
||||
}
|
||||
@ -595,10 +556,11 @@ struct SkillFrontmatter {
|
||||
}
|
||||
|
||||
fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
|
||||
let content = fs::read_to_string(path).map_err(|e| format!("failed to read file: {}", e))?;
|
||||
let content = fs::read_to_string(path)
|
||||
.map_err(|e| format!("failed to read file: {}", e))?;
|
||||
|
||||
let (frontmatter_raw, body) =
|
||||
split_frontmatter(&content).ok_or_else(|| "missing YAML frontmatter block".to_string())?;
|
||||
let (frontmatter_raw, body) = split_frontmatter(&content)
|
||||
.ok_or_else(|| "missing YAML frontmatter block".to_string())?;
|
||||
|
||||
let frontmatter: SkillFrontmatter = serde_yaml::from_str(frontmatter_raw)
|
||||
.map_err(|e| format!("invalid YAML frontmatter: {}", e))?;
|
||||
@ -614,7 +576,11 @@ fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| "unknown-skill".to_string());
|
||||
|
||||
let name = frontmatter.name.unwrap_or(dir_name).trim().to_string();
|
||||
let name = frontmatter
|
||||
.name
|
||||
.unwrap_or(dir_name)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
Ok(Skill {
|
||||
name,
|
||||
@ -690,10 +656,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let skills = load_skills_from_root(
|
||||
&dir.path().join(".picobot").join("skills"),
|
||||
SkillSource::Project,
|
||||
);
|
||||
let skills = load_skills_from_root(&dir.path().join(".picobot").join("skills"), SkillSource::Project);
|
||||
let catalog = SkillCatalog {
|
||||
skills,
|
||||
max_index_chars: 4000,
|
||||
@ -744,13 +707,7 @@ mod tests {
|
||||
assert_eq!(runtime.len(), 0);
|
||||
|
||||
let created = runtime
|
||||
.create_skill(
|
||||
SkillScope::Project,
|
||||
"demo-skill",
|
||||
"demo desc",
|
||||
"line 1",
|
||||
true,
|
||||
)
|
||||
.create_skill(SkillScope::Project, "demo-skill", "demo desc", "line 1", true)
|
||||
.unwrap();
|
||||
assert_eq!(created.name, "demo-skill");
|
||||
assert_eq!(runtime.len(), 1);
|
||||
@ -765,12 +722,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(updated.description, "updated desc");
|
||||
assert!(
|
||||
runtime
|
||||
.activation_payload("demo-skill")
|
||||
.unwrap()
|
||||
.contains("line 2")
|
||||
);
|
||||
assert!(runtime.activation_payload("demo-skill").unwrap().contains("line 2"));
|
||||
|
||||
let deleted_path = runtime
|
||||
.delete_skill(SkillScope::Project, "demo-skill", true)
|
||||
@ -807,11 +759,7 @@ mod tests {
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let _guard = CurrentDirGuard::enter(temp_dir.path());
|
||||
|
||||
let agent_skill_dir = temp_dir
|
||||
.path()
|
||||
.join(".agents")
|
||||
.join("skills")
|
||||
.join("demo-agent");
|
||||
let agent_skill_dir = temp_dir.path().join(".agents").join("skills").join("demo-agent");
|
||||
fs::create_dir_all(&agent_skill_dir).unwrap();
|
||||
fs::write(
|
||||
agent_skill_dir.join("SKILL.md"),
|
||||
|
||||
@ -391,8 +391,7 @@ impl SessionStore {
|
||||
)?;
|
||||
|
||||
drop(conn);
|
||||
self.get_session(&id)?
|
||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||
self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||
}
|
||||
|
||||
pub fn ensure_channel_session(
|
||||
@ -420,8 +419,7 @@ impl SessionStore {
|
||||
)?;
|
||||
drop(conn);
|
||||
|
||||
self.get_session(&session_id)?
|
||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||
self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||
}
|
||||
|
||||
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
|
||||
@ -497,10 +495,7 @@ impl SessionStore {
|
||||
|
||||
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
)?;
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
|
||||
Ok(())
|
||||
}
|
||||
@ -508,10 +503,7 @@ impl SessionStore {
|
||||
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
)?;
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
|
||||
conn.execute(
|
||||
"
|
||||
UPDATE sessions
|
||||
@ -557,11 +549,7 @@ impl SessionStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append_message(
|
||||
&self,
|
||||
session_id: &str,
|
||||
message: &ChatMessage,
|
||||
) -> Result<(), StorageError> {
|
||||
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
@ -572,11 +560,7 @@ impl SessionStore {
|
||||
)?;
|
||||
|
||||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||
let tool_calls_json = message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map(serde_json::to_string)
|
||||
.transpose()?;
|
||||
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
||||
tx.execute(
|
||||
"
|
||||
INSERT INTO messages (
|
||||
@ -646,8 +630,7 @@ impl SessionStore {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let delta_messages =
|
||||
load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
|
||||
let delta_messages = load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
|
||||
let mut next_seq = current_max_seq + 1;
|
||||
let now = current_timestamp();
|
||||
let mut inserted_count = 0_i64;
|
||||
@ -799,7 +782,8 @@ impl SessionStore {
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
let (id, created_at) = existing.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
|
||||
let (id, created_at) = existing
|
||||
.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
|
||||
|
||||
tx.execute(
|
||||
"
|
||||
@ -897,10 +881,7 @@ impl SessionStore {
|
||||
LIMIT ?4
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(
|
||||
params![scope_kind, scope_key, namespace, limit],
|
||||
map_memory_record,
|
||||
)?;
|
||||
let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
@ -959,9 +940,7 @@ impl SessionStore {
|
||||
",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| {
|
||||
row.get::<_, String>(0)
|
||||
})?;
|
||||
let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| row.get::<_, String>(0))?;
|
||||
let mut scope_keys = Vec::new();
|
||||
for row in rows {
|
||||
scope_keys.push(row?);
|
||||
@ -1031,10 +1010,7 @@ impl SessionStore {
|
||||
Ok(changed > 0)
|
||||
}
|
||||
|
||||
pub fn upsert_scheduler_job(
|
||||
&self,
|
||||
input: &SchedulerJobUpsert,
|
||||
) -> Result<SchedulerJobRecord, StorageError> {
|
||||
pub fn upsert_scheduler_job(&self, input: &SchedulerJobUpsert) -> Result<SchedulerJobRecord, StorageError> {
|
||||
let now = current_timestamp();
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
conn.execute(
|
||||
@ -1091,10 +1067,7 @@ impl SessionStore {
|
||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||
}
|
||||
|
||||
pub fn get_scheduler_job(
|
||||
&self,
|
||||
job_id: &str,
|
||||
) -> Result<Option<SchedulerJobRecord>, StorageError> {
|
||||
pub fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
@ -1112,10 +1085,7 @@ impl SessionStore {
|
||||
.map_err(StorageError::from)
|
||||
}
|
||||
|
||||
pub fn list_scheduler_jobs(
|
||||
&self,
|
||||
enabled_only: bool,
|
||||
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
||||
pub fn list_scheduler_jobs(&self, enabled_only: bool) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||
let sql = if enabled_only {
|
||||
"
|
||||
@ -1225,10 +1195,7 @@ impl SessionStore {
|
||||
LIMIT ?5
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(
|
||||
params![query, scope_kind, scope_key, namespace, limit],
|
||||
map_memory_record,
|
||||
)?;
|
||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
@ -1247,10 +1214,7 @@ impl SessionStore {
|
||||
LIMIT ?4
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(
|
||||
params![query, scope_kind, scope_key, limit],
|
||||
map_memory_record,
|
||||
)?;
|
||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
@ -1292,10 +1256,7 @@ impl SessionStore {
|
||||
LIMIT ?5
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(
|
||||
params![query, scope_kind, scope_key, namespace, limit],
|
||||
map_memory_record,
|
||||
)?;
|
||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
@ -1314,10 +1275,7 @@ impl SessionStore {
|
||||
LIMIT ?4
|
||||
",
|
||||
)?;
|
||||
let rows = stmt.query_map(
|
||||
params![query, scope_kind, scope_key, limit],
|
||||
map_memory_record,
|
||||
)?;
|
||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
|
||||
for row in rows {
|
||||
memories.push(row?);
|
||||
}
|
||||
@ -1389,7 +1347,11 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
|
||||
fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEventRecord> {
|
||||
let payload_json: String = row.get(4)?;
|
||||
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(err))
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
4,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(SkillEventRecord {
|
||||
@ -1429,13 +1391,25 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<Schedul
|
||||
let last_status: Option<String> = row.get(9)?;
|
||||
|
||||
let schedule = serde_json::from_str(&schedule_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(2, rusqlite::types::Type::Text, Box::new(err))
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
2,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
let target = serde_json::from_str(&target_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(err))
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
5,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(err))
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
6,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(SchedulerJobRecord {
|
||||
@ -1498,10 +1472,7 @@ fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
}
|
||||
|
||||
if !has_column(conn, "messages", "reasoning_content")? {
|
||||
add_column_if_missing(
|
||||
conn,
|
||||
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
|
||||
)?;
|
||||
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN reasoning_content TEXT")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -1523,11 +1494,17 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
}
|
||||
|
||||
if !has_column(conn, "scheduler_jobs", "last_status")? {
|
||||
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT", [])?;
|
||||
conn.execute(
|
||||
"ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "scheduler_jobs", "last_error")? {
|
||||
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT", [])?;
|
||||
conn.execute(
|
||||
"ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "scheduler_jobs", "run_count")? {
|
||||
@ -1538,7 +1515,10 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
}
|
||||
|
||||
if !has_column(conn, "scheduler_jobs", "max_runs")? {
|
||||
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER", [])?;
|
||||
conn.execute(
|
||||
"ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
if !has_column(conn, "scheduler_jobs", "paused_at")? {
|
||||
@ -1558,11 +1538,7 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn has_column(
|
||||
conn: &Connection,
|
||||
table_name: &str,
|
||||
column_name: &str,
|
||||
) -> Result<bool, StorageError> {
|
||||
fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<bool, StorageError> {
|
||||
let pragma = format!("PRAGMA table_info({})", table_name);
|
||||
let mut stmt = conn.prepare(&pragma)?;
|
||||
let mut rows = stmt.query([])?;
|
||||
@ -1581,10 +1557,7 @@ fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageErro
|
||||
match conn.execute(sql, []) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
|
||||
if message.contains("duplicate column name") =>
|
||||
{
|
||||
Ok(())
|
||||
}
|
||||
if message.contains("duplicate column name") => Ok(()),
|
||||
Err(error) => Err(StorageError::Database(error)),
|
||||
}
|
||||
}
|
||||
@ -1608,11 +1581,7 @@ fn insert_message_with_seq(
|
||||
message: &ChatMessage,
|
||||
) -> Result<(), StorageError> {
|
||||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||
let tool_calls_json = message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map(serde_json::to_string)
|
||||
.transpose()?;
|
||||
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
||||
conn.execute(
|
||||
"
|
||||
INSERT INTO messages (
|
||||
@ -1669,12 +1638,9 @@ fn load_messages_between(
|
||||
",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(
|
||||
params![session_id, start_seq_exclusive, end_seq_inclusive],
|
||||
|row| {
|
||||
let rows = stmt.query_map(params![session_id, start_seq_exclusive, end_seq_inclusive], |row| {
|
||||
let media_refs_json: String = row.get(5)?;
|
||||
let media_refs: Vec<String> =
|
||||
serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
media_refs_json.len(),
|
||||
rusqlite::types::Type::Text,
|
||||
@ -1708,8 +1674,7 @@ fn load_messages_between(
|
||||
tool_state: None,
|
||||
tool_calls,
|
||||
})
|
||||
},
|
||||
)?;
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for row in rows {
|
||||
@ -1901,10 +1866,7 @@ mod tests {
|
||||
assert_eq!(messages[0].role, "assistant");
|
||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||
assert_eq!(
|
||||
messages[0].tool_calls.as_ref().unwrap()[0].name,
|
||||
"calculator"
|
||||
);
|
||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1912,17 +1874,17 @@ mod tests {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("reasoning")).unwrap();
|
||||
|
||||
let assistant = ChatMessage::assistant_with_reasoning("final answer", "hidden reasoning");
|
||||
let assistant = ChatMessage::assistant_with_reasoning(
|
||||
"final answer",
|
||||
"hidden reasoning",
|
||||
);
|
||||
|
||||
store.append_message(&session.id, &assistant).unwrap();
|
||||
|
||||
let messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].content, "final answer");
|
||||
assert_eq!(
|
||||
messages[0].reasoning_content.as_deref(),
|
||||
Some("hidden reasoning")
|
||||
);
|
||||
assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1930,12 +1892,8 @@ mod tests {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("reset")).unwrap();
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("before"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::assistant("context"))
|
||||
.unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("before")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap();
|
||||
store.reset_session(&session.id).unwrap();
|
||||
|
||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||
@ -1951,9 +1909,7 @@ mod tests {
|
||||
assert_eq!(all_messages[0].content, "before");
|
||||
assert_eq!(all_messages[1].content, "context");
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("after"))
|
||||
.unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("after")).unwrap();
|
||||
let active_messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(active_messages.len(), 1);
|
||||
assert_eq!(active_messages[0].content, "after");
|
||||
@ -2054,33 +2010,19 @@ mod tests {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("count-users")).unwrap();
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::system("agent"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u1"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::assistant("a1"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u2"))
|
||||
.unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
|
||||
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
|
||||
store.reset_session(&session.id).unwrap();
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::system("agent-again"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u3"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u4"))
|
||||
.unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
|
||||
|
||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||
}
|
||||
@ -2110,20 +2052,12 @@ mod tests {
|
||||
store.append_message(&session.id, message).unwrap();
|
||||
}
|
||||
|
||||
let snapshot_end_seq = store
|
||||
.get_session(&session.id)
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.message_count;
|
||||
let snapshot_end_seq = store.get_session(&session.id).unwrap().unwrap().message_count;
|
||||
let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec();
|
||||
let preserved_system_messages = vec![agent_prompt];
|
||||
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::user("u5"))
|
||||
.unwrap();
|
||||
store
|
||||
.append_message(&session.id, &ChatMessage::assistant("a5"))
|
||||
.unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::user("u5")).unwrap();
|
||||
store.append_message(&session.id, &ChatMessage::assistant("a5")).unwrap();
|
||||
|
||||
let summary_message = ChatMessage::system("[Compressed History]\n\nsummary");
|
||||
let compacted = store
|
||||
@ -2143,15 +2077,9 @@ mod tests {
|
||||
assert_eq!(active_messages.len(), 10);
|
||||
assert_eq!(active_messages[0].role, "system");
|
||||
assert_eq!(active_messages[0].content, "agent");
|
||||
assert_eq!(
|
||||
active_messages[0].system_context.as_deref(),
|
||||
Some(SYSTEM_CONTEXT_AGENT_PROMPT)
|
||||
);
|
||||
assert_eq!(active_messages[0].system_context.as_deref(), Some(SYSTEM_CONTEXT_AGENT_PROMPT));
|
||||
assert_eq!(active_messages[1].role, "system");
|
||||
assert_eq!(
|
||||
active_messages[1].content,
|
||||
"[Compressed History]\n\nsummary"
|
||||
);
|
||||
assert_eq!(active_messages[1].content, "[Compressed History]\n\nsummary");
|
||||
assert_eq!(active_messages[2].content, "u2");
|
||||
assert_eq!(active_messages[3].content, "a2");
|
||||
assert_eq!(active_messages[8].content, "u5");
|
||||
@ -2200,7 +2128,12 @@ mod tests {
|
||||
let session = store.create_cli_session(Some("skill-events")).unwrap();
|
||||
|
||||
store
|
||||
.append_skill_event(None, "discovered", None, &serde_json::json!({"count": 2}))
|
||||
.append_skill_event(
|
||||
None,
|
||||
"discovered",
|
||||
None,
|
||||
&serde_json::json!({"count": 2}),
|
||||
)
|
||||
.unwrap();
|
||||
store
|
||||
.append_skill_event(
|
||||
@ -2450,26 +2383,13 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let scope_keys = store.list_memory_scope_keys("user").unwrap();
|
||||
assert_eq!(
|
||||
scope_keys,
|
||||
vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]
|
||||
);
|
||||
assert_eq!(scope_keys, vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]);
|
||||
|
||||
let full_scope = store
|
||||
.list_memories_for_scope("user", "feishu:user-1")
|
||||
.unwrap();
|
||||
let full_scope = store.list_memories_for_scope("user", "feishu:user-1").unwrap();
|
||||
assert_eq!(full_scope.len(), 2);
|
||||
assert!(
|
||||
full_scope
|
||||
.iter()
|
||||
.all(|memory| memory.scope_key == "feishu:user-1")
|
||||
);
|
||||
assert!(full_scope.iter().all(|memory| memory.scope_key == "feishu:user-1"));
|
||||
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
|
||||
assert!(
|
||||
full_scope
|
||||
.iter()
|
||||
.any(|memory| memory.memory_key == "workflow")
|
||||
);
|
||||
assert!(full_scope.iter().any(|memory| memory.memory_key == "workflow"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -15,8 +15,7 @@ use crate::tools::traits::{Tool, ToolResult};
|
||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||
const USER_ACTION_HINT: &str =
|
||||
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||
const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||
|
||||
pub struct BashTool {
|
||||
timeout_secs: u64,
|
||||
@ -209,10 +208,7 @@ impl Tool for BashTool {
|
||||
.map(|d| Path::new(d))
|
||||
.unwrap_or_else(|| Path::new("."));
|
||||
|
||||
match self
|
||||
.run_command(command, cwd, timeout_secs, interactive)
|
||||
.await
|
||||
{
|
||||
match self.run_command(command, cwd, timeout_secs, interactive).await {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
@ -370,7 +366,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_pwd_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "pwd" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
@ -378,10 +377,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_ls_command() {
|
||||
let tool = BashTool::new();
|
||||
let result = tool
|
||||
.execute(json!({ "command": "ls -la /tmp" }))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
@ -659,7 +659,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_evaluate_missing_expression() {
|
||||
let tool = CalculatorTool::new();
|
||||
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
|
||||
let result = tool
|
||||
.execute(json!({"function": "evaluate"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
|
||||
@ -268,8 +268,8 @@ impl Tool for FileEditTool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_edit_simple() {
|
||||
|
||||
@ -218,7 +218,7 @@ impl Tool for FileReadTool {
|
||||
// Try to read as binary and encode as base64
|
||||
match std::fs::read(&resolved) {
|
||||
Ok(bytes) => {
|
||||
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
let encoded = STANDARD.encode(&bytes);
|
||||
let mime = mime_guess::from_path(&resolved)
|
||||
.first_or_octet_stream()
|
||||
@ -248,8 +248,8 @@ impl Tool for FileReadTool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_simple_file() {
|
||||
@ -308,7 +308,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_is_directory() {
|
||||
let tool = FileReadTool::new();
|
||||
let result = tool.execute(json!({ "path": "." })).await.unwrap();
|
||||
let result = tool
|
||||
.execute(json!({ "path": "." }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Not a file"));
|
||||
|
||||
@ -195,7 +195,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_write_missing_path() {
|
||||
let tool = FileWriteTool::new();
|
||||
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
|
||||
let result = tool
|
||||
.execute(json!({ "content": "Hello" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("path"));
|
||||
|
||||
@ -50,7 +50,10 @@ impl HttpRequestTool {
|
||||
}
|
||||
|
||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||
return Err(format!("Host '{}' is not in allowed_domains", host));
|
||||
return Err(format!(
|
||||
"Host '{}' is not in allowed_domains",
|
||||
host
|
||||
));
|
||||
}
|
||||
|
||||
Ok(url.to_string())
|
||||
@ -77,7 +80,9 @@ impl HttpRequestTool {
|
||||
for (key, value) in obj {
|
||||
if let Some(str_val) = value.as_str() {
|
||||
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = reqwest::header::HeaderValue::from_str(str_val) {
|
||||
if let Ok(val) =
|
||||
reqwest::header::HeaderValue::from_str(str_val)
|
||||
{
|
||||
header_map.insert(name, val);
|
||||
}
|
||||
}
|
||||
@ -188,9 +193,7 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
||||
|
||||
allowed_domains.iter().any(|domain| {
|
||||
host == domain
|
||||
|| host
|
||||
.strip_suffix(domain)
|
||||
.is_some_and(|prefix| prefix.ends_with('.'))
|
||||
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
||||
})
|
||||
}
|
||||
|
||||
@ -201,11 +204,7 @@ fn is_private_host(host: &str) -> bool {
|
||||
}
|
||||
|
||||
// Check .local TLD
|
||||
if host
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|label| label == "local")
|
||||
{
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -227,7 +226,9 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||
|| v4.is_broadcast()
|
||||
|| v4.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,7 +280,10 @@ impl Tool for HttpRequestTool {
|
||||
}
|
||||
};
|
||||
|
||||
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
||||
let method_str = args
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
|
||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||
let body = args.get("body").and_then(|v| v.as_str());
|
||||
|
||||
@ -94,7 +94,7 @@ impl Tool for MemoryManageTool {
|
||||
return Ok(error_result(&format!(
|
||||
"memory '{}.{}' not found",
|
||||
input.namespace, input.memory_key
|
||||
)));
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -108,14 +108,9 @@ impl Tool for MemoryManageTool {
|
||||
None => return Ok(error_result("Missing required parameter: key")),
|
||||
};
|
||||
|
||||
let deleted = self
|
||||
.store
|
||||
.delete_memory("user", &scope_key, namespace, key)?;
|
||||
let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?;
|
||||
if !deleted {
|
||||
return Ok(error_result(&format!(
|
||||
"memory '{}.{}' not found",
|
||||
namespace, key
|
||||
)));
|
||||
return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key)));
|
||||
}
|
||||
|
||||
json!({
|
||||
|
||||
@ -90,9 +90,7 @@ impl Tool for MemorySearchTool {
|
||||
.get("limit")
|
||||
.and_then(|value| value.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
let memories = self
|
||||
.store
|
||||
.list_memories("user", &scope_key, namespace, limit)?;
|
||||
let memories = self.store.list_memories("user", &scope_key, namespace, limit)?;
|
||||
json!({
|
||||
"count": memories.len(),
|
||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
||||
@ -137,12 +135,7 @@ impl Tool for MemorySearchTool {
|
||||
|
||||
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
||||
Some(memory) => memory_to_json(memory),
|
||||
None => {
|
||||
return Ok(error_result(&format!(
|
||||
"memory '{}.{}' not found",
|
||||
namespace, key
|
||||
)));
|
||||
}
|
||||
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
|
||||
}
|
||||
}
|
||||
_ => return Ok(error_result("Unsupported action")),
|
||||
|
||||
@ -5,7 +5,9 @@ use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::config::SchedulerSchedule;
|
||||
use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore};
|
||||
use crate::storage::{
|
||||
SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore,
|
||||
};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct SchedulerManageTool {
|
||||
@ -33,7 +35,11 @@ impl Tool for SchedulerManageTool {
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
let mut allowed_agents = self.known_agents.iter().cloned().collect::<Vec<_>>();
|
||||
let mut allowed_agents = self
|
||||
.known_agents
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
allowed_agents.sort();
|
||||
let agent_hint = if allowed_agents.is_empty() {
|
||||
"agent_task payload.agent may be omitted or set to 'default'.".to_string()
|
||||
@ -219,15 +225,8 @@ fn build_upsert(
|
||||
startup_delay_secs,
|
||||
target,
|
||||
payload,
|
||||
enabled: args
|
||||
.get("enabled")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(true),
|
||||
state: if args
|
||||
.get("enabled")
|
||||
.and_then(|value| value.as_bool())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
|
||||
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
|
||||
SchedulerJobState::Scheduled
|
||||
} else {
|
||||
SchedulerJobState::Paused
|
||||
@ -253,28 +252,14 @@ fn enrich_target_from_context(
|
||||
};
|
||||
|
||||
if !has_non_empty_string(&object, "channel") {
|
||||
if let Some(channel_name) = context
|
||||
.channel_name
|
||||
.as_ref()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
{
|
||||
object.insert(
|
||||
"channel".to_string(),
|
||||
serde_json::Value::String(channel_name.clone()),
|
||||
);
|
||||
if let Some(channel_name) = context.channel_name.as_ref().filter(|value| !value.trim().is_empty()) {
|
||||
object.insert("channel".to_string(), serde_json::Value::String(channel_name.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if !has_non_empty_string(&object, "chat_id") {
|
||||
if let Some(chat_id) = context
|
||||
.chat_id
|
||||
.as_ref()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
{
|
||||
object.insert(
|
||||
"chat_id".to_string(),
|
||||
serde_json::Value::String(chat_id.clone()),
|
||||
);
|
||||
if let Some(chat_id) = context.chat_id.as_ref().filter(|value| !value.trim().is_empty()) {
|
||||
object.insert("chat_id".to_string(), serde_json::Value::String(chat_id.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -289,10 +274,7 @@ fn has_non_empty_string(object: &serde_json::Map<String, serde_json::Value>, fie
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn validate_agent_task_payload(
|
||||
payload: &serde_json::Value,
|
||||
known_agents: &HashSet<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<()> {
|
||||
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
|
||||
anyhow::bail!("agent_task payload.prompt is required and must be a string")
|
||||
};
|
||||
@ -317,8 +299,7 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> St
|
||||
configured_agents.sort();
|
||||
|
||||
let configured_hint = if configured_agents.is_empty() {
|
||||
"No named agents are configured; use payload.agent='default' or omit payload.agent."
|
||||
.to_string()
|
||||
"No named agents are configured; use payload.agent='default' or omit payload.agent.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.",
|
||||
@ -328,7 +309,9 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> St
|
||||
|
||||
format!(
|
||||
"Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.",
|
||||
agent_name, configured_hint, agent_name,
|
||||
agent_name,
|
||||
configured_hint,
|
||||
agent_name,
|
||||
)
|
||||
}
|
||||
|
||||
@ -534,10 +517,7 @@ mod tests {
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(saved.kind, "silent_agent_task");
|
||||
assert_eq!(
|
||||
saved.target["session_chat_id"],
|
||||
"scheduler/agent.daily_summary.background"
|
||||
);
|
||||
assert_eq!(saved.target["session_chat_id"], "scheduler/agent.daily_summary.background");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -674,9 +654,7 @@ mod tests {
|
||||
assert!(!result.success);
|
||||
assert_eq!(
|
||||
result.error.as_deref(),
|
||||
Some(
|
||||
"Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}"
|
||||
)
|
||||
Some("Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}")
|
||||
);
|
||||
}
|
||||
|
||||
@ -690,9 +668,7 @@ mod tests {
|
||||
.as_str()
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
payload_description.contains("avoid repeating schedule phrases or execution times")
|
||||
);
|
||||
assert!(payload_description.contains("avoid repeating schedule phrases or execution times"));
|
||||
assert!(payload_description.contains("每天9点"));
|
||||
assert!(payload_description.contains("每小时"));
|
||||
}
|
||||
|
||||
@ -408,10 +408,7 @@ impl SchemaCleanr {
|
||||
|
||||
match non_null.len() {
|
||||
0 => Value::String("null".to_string()),
|
||||
1 => non_null
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or(Value::String("null".to_string())),
|
||||
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
||||
_ => Value::Array(non_null),
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -83,11 +83,7 @@ impl Tool for SkillManageTool {
|
||||
let scope = match args.get("scope").and_then(|v| v.as_str()) {
|
||||
Some(value) => match SkillScope::parse(value) {
|
||||
Some(scope) => scope,
|
||||
None => {
|
||||
return Ok(error_result(
|
||||
"scope must be 'project' or 'user'; .agents sources are discovery-only",
|
||||
));
|
||||
}
|
||||
None => return Ok(error_result("scope must be 'project' or 'user'; .agents sources are discovery-only")),
|
||||
},
|
||||
None => SkillScope::Project,
|
||||
};
|
||||
@ -95,7 +91,9 @@ impl Tool for SkillManageTool {
|
||||
let name = args.get("name").and_then(|v| v.as_str());
|
||||
|
||||
let result = match action {
|
||||
"list" => list_skills_payload(&self.skills),
|
||||
"list" => {
|
||||
list_skills_payload(&self.skills)
|
||||
}
|
||||
"get" => {
|
||||
let name = match name {
|
||||
Some(name) => name,
|
||||
@ -129,10 +127,7 @@ impl Tool for SkillManageTool {
|
||||
};
|
||||
let body = args.get("body").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
match self
|
||||
.skills
|
||||
.create_skill(scope, name, description, body, reload)
|
||||
{
|
||||
match self.skills.create_skill(scope, name, description, body, reload) {
|
||||
Ok(skill) => json!({
|
||||
"status": "created",
|
||||
"name": skill.name,
|
||||
@ -154,10 +149,7 @@ impl Tool for SkillManageTool {
|
||||
return Ok(error_result("update requires description or body"));
|
||||
}
|
||||
|
||||
match self
|
||||
.skills
|
||||
.update_skill(scope, name, description, body, reload)
|
||||
{
|
||||
match self.skills.update_skill(scope, name, description, body, reload) {
|
||||
Ok(skill) => json!({
|
||||
"status": "updated",
|
||||
"name": skill.name,
|
||||
|
||||
@ -99,7 +99,9 @@ fn execute_time_request(
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or(default_timezone);
|
||||
let timezone = timezone_name.parse::<chrono_tz::Tz>().map_err(|_| {
|
||||
format!("Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai")
|
||||
format!(
|
||||
"Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai"
|
||||
)
|
||||
})?;
|
||||
|
||||
let now_local = now_utc.with_timezone(&timezone);
|
||||
@ -166,14 +168,13 @@ fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
|
||||
let direction = direction.ok_or_else(|| {
|
||||
"Missing required parameter: direction when requesting a relative time".to_string()
|
||||
})?;
|
||||
let amount = amount.and_then(Value::as_u64).ok_or_else(|| {
|
||||
"Missing required parameter: amount when requesting a relative time".to_string()
|
||||
})?;
|
||||
let amount = amount
|
||||
.and_then(Value::as_u64)
|
||||
.ok_or_else(|| "Missing required parameter: amount when requesting a relative time".to_string())?;
|
||||
let amount = u32::try_from(amount)
|
||||
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
|
||||
let unit = unit.ok_or_else(|| {
|
||||
"Missing required parameter: unit when requesting a relative time".to_string()
|
||||
})?;
|
||||
let unit = unit
|
||||
.ok_or_else(|| "Missing required parameter: unit when requesting a relative time".to_string())?;
|
||||
|
||||
Ok(Some(OffsetRequest {
|
||||
direction: OffsetDirection::parse(direction)?,
|
||||
@ -187,18 +188,10 @@ fn apply_offset(
|
||||
offset: &OffsetRequest,
|
||||
) -> Result<DateTime<chrono_tz::Tz>, String> {
|
||||
match (offset.direction, offset.unit) {
|
||||
(OffsetDirection::Future, TimeUnit::Minute) => {
|
||||
Ok(now_local + Duration::minutes(i64::from(offset.amount)))
|
||||
}
|
||||
(OffsetDirection::Past, TimeUnit::Minute) => {
|
||||
Ok(now_local - Duration::minutes(i64::from(offset.amount)))
|
||||
}
|
||||
(OffsetDirection::Future, TimeUnit::Hour) => {
|
||||
Ok(now_local + Duration::hours(i64::from(offset.amount)))
|
||||
}
|
||||
(OffsetDirection::Past, TimeUnit::Hour) => {
|
||||
Ok(now_local - Duration::hours(i64::from(offset.amount)))
|
||||
}
|
||||
(OffsetDirection::Future, TimeUnit::Minute) => Ok(now_local + Duration::minutes(i64::from(offset.amount))),
|
||||
(OffsetDirection::Past, TimeUnit::Minute) => Ok(now_local - Duration::minutes(i64::from(offset.amount))),
|
||||
(OffsetDirection::Future, TimeUnit::Hour) => Ok(now_local + Duration::hours(i64::from(offset.amount))),
|
||||
(OffsetDirection::Past, TimeUnit::Hour) => Ok(now_local - Duration::hours(i64::from(offset.amount))),
|
||||
(OffsetDirection::Future, TimeUnit::Day) => now_local
|
||||
.checked_add_days(Days::new(u64::from(offset.amount)))
|
||||
.ok_or_else(|| "Failed to add days to the current time".to_string()),
|
||||
@ -446,8 +439,8 @@ mod tests {
|
||||
|
||||
assert!(result.success);
|
||||
let payload: Value = serde_json::from_str(&result.output).unwrap();
|
||||
let result_time =
|
||||
chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap()).unwrap();
|
||||
let result_time = chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(result_time.hour(), 12);
|
||||
assert_eq!(result_time.minute(), 30);
|
||||
}
|
||||
|
||||
@ -239,11 +239,7 @@ fn is_private_host(host: &str) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
if host
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|label| label == "local")
|
||||
{
|
||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -252,9 +248,7 @@ fn is_private_host(host: &str) -> bool {
|
||||
std::net::IpAddr::V4(v4) => {
|
||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => {
|
||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||
}
|
||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use picobot::config::{Config, LLMProviderConfig};
|
||||
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
||||
use std::collections::HashMap;
|
||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
||||
use picobot::config::{Config, LLMProviderConfig};
|
||||
|
||||
fn load_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
@ -23,10 +23,11 @@ fn load_config() -> Option<LLMProviderConfig> {
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 20,
|
||||
token_limit: 128_000,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_summary_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
})
|
||||
}
|
||||
@ -43,7 +44,8 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_simple_completion() {
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||
@ -57,7 +59,8 @@ async fn test_openai_simple_completion() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_conversation() {
|
||||
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -81,9 +84,7 @@ async fn test_openai_conversation() {
|
||||
async fn test_config_load() {
|
||||
// Test that config.json can be loaded and provider config created
|
||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||
let provider_config = config
|
||||
.get_provider_config("default")
|
||||
.expect("Failed to get provider config");
|
||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
||||
|
||||
assert_eq!(provider_config.provider_type, "openai");
|
||||
assert_eq!(provider_config.name, "aliyun");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||
use picobot::providers::{ChatCompletionRequest, Message};
|
||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||
|
||||
/// Test that message with special characters is properly escaped
|
||||
#[test]
|
||||
@ -19,9 +19,7 @@ fn test_message_special_characters() {
|
||||
#[test]
|
||||
fn test_multiline_system_prompt() {
|
||||
let messages = vec![
|
||||
Message::system(
|
||||
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
|
||||
),
|
||||
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
|
||||
Message::user("Hi"),
|
||||
];
|
||||
|
||||
@ -35,7 +33,10 @@ fn test_multiline_system_prompt() {
|
||||
#[test]
|
||||
fn test_chat_request_serialization() {
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
|
||||
messages: vec![
|
||||
Message::system("You are helpful"),
|
||||
Message::user("Hello"),
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(100),
|
||||
tools: None,
|
||||
@ -135,12 +136,7 @@ fn test_tool_call_outbound_serialization() {
|
||||
|
||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||
match decoded {
|
||||
WsOutbound::ToolCall {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
..
|
||||
} => {
|
||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
@ -165,12 +161,7 @@ fn test_tool_result_outbound_serialization() {
|
||||
|
||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||
match decoded {
|
||||
WsOutbound::ToolResult {
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert!(content.contains('2'));
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use picobot::config::LLMProviderConfig;
|
||||
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
||||
use std::collections::HashMap;
|
||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
||||
use picobot::config::LLMProviderConfig;
|
||||
|
||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
dotenv::from_filename("tests/test.env").ok()?;
|
||||
@ -23,10 +23,11 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 20,
|
||||
token_limit: 128_000,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_summary_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
})
|
||||
}
|
||||
@ -54,7 +55,8 @@ fn make_weather_tool() -> Tool {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_tool_call() {
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -68,11 +70,7 @@ async fn test_openai_tool_call() {
|
||||
let response = provider.chat(request).await.unwrap();
|
||||
|
||||
// Should have tool calls
|
||||
assert!(
|
||||
!response.tool_calls.is_empty(),
|
||||
"Expected tool call, got: {}",
|
||||
response.content
|
||||
);
|
||||
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
||||
|
||||
let tool_call = &response.tool_calls[0];
|
||||
assert_eq!(tool_call.name, "get_weather");
|
||||
@ -82,7 +80,8 @@ async fn test_openai_tool_call() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_tool_call_with_manual_execution() {
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
@ -95,7 +94,8 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
};
|
||||
|
||||
let response1 = provider.chat(request1).await.unwrap();
|
||||
let tool_call = response1.tool_calls.first().expect("Expected tool call");
|
||||
let tool_call = response1.tool_calls.first()
|
||||
.expect("Expected tool call");
|
||||
assert_eq!(tool_call.name, "get_weather");
|
||||
|
||||
// Second request with tool result
|
||||
@ -118,7 +118,8 @@ async fn test_openai_tool_call_with_manual_execution() {
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_openai_no_tool_when_not_provided() {
|
||||
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||
let config = load_openai_config()
|
||||
.expect("Please configure tests/test.env with valid API keys");
|
||||
|
||||
let provider = create_provider(config).expect("Failed to create provider");
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user