Compare commits

...

4 Commits

Author SHA1 Message Date
fa3354db9c feat: add context_window_tokens to model configuration and update related logic
- Introduced context_window_tokens in ModelConfig and LLMProviderConfig structs.
- Updated context window estimation logic in ContextCompressor to use context_window_tokens.
- Modified tests to accommodate new context_window_tokens field.
- Refactored memory maintenance logic into a new memory_maintenance.rs file for better organization.
- Ensured backward compatibility by providing default values where necessary.

Co-authored-by: Copilot <copilot@github.com>
2026-04-28 11:29:06 +08:00
b2c8d76820 feat: 添加最终结果处理和调度压缩功能,重构会话管理逻辑以优化代码结构 2026-04-28 10:58:01 +08:00
33f5a4cbd2 feat: 添加执行服务和提示管理功能,重构相关模块以优化代码结构 2026-04-28 10:51:54 +08:00
73dab09bfe Refactor code for improved readability and consistency
- Adjusted formatting and indentation in various files for better clarity.
- Consolidated multi-line statements into single lines where appropriate.
- Enhanced error handling messages for better debugging.
- Added a new InboundProcessor struct to handle inbound messages more effectively.
- Updated test cases to ensure they align with the new code structure.
2026-04-28 10:33:31 +08:00
50 changed files with 3029 additions and 1811 deletions

View File

@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
1. 系统先对当前活动历史做一个近似 token 估算。
估算规则不是调用 tokenizer而是按“约每 4 个字符约等于 1 token并再乘以 1.2 安全系数”计算。
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
这里的上下文窗口来自 agent 对应模型配置里的 token_limit。
这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens未配置时按 128000 估算
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
当前默认会完整保留最近 3 个 user turn。
4. 一旦满足条件,压缩器会先按 user 消息切分 turn再确定“旧历史”和“最近保留段”的分界点。

View File

@ -1,16 +1,16 @@
use async_trait::async_trait;
use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage;
use crate::bus::message::ContentBlock;
use crate::bus::message::ToolMessageState;
use crate::config::LLMProviderConfig;
use crate::observability::{
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState,
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::{ToolContext, ToolRegistry};
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use std::io::Read;
@ -19,18 +19,13 @@ use std::time::Instant;
/// Minimum characters to keep when truncating
const TRUNCATION_SUFFIX_LEN: usize = 200;
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str =
include_str!("memory_tool_usage_system_prompt.md");
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = include_str!("memory_tool_usage_system_prompt.md");
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
"工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
];
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
@ -115,14 +110,15 @@ fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len));
format!(
"...\n\n[Output truncated - {} characters removed]\n\n{}",
truncated_start_len,
tail
truncated_start_len, tail
)
}
}
fn parse_pending_tool_output(output: &str) -> Option<String> {
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
output
.strip_prefix(PENDING_USER_ACTION_MARKER)
.map(|rest| rest.trim().to_string())
}
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
@ -341,7 +337,10 @@ impl AgentLoop {
})
}
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
pub fn with_tools(
provider_config: LLMProviderConfig,
tools: Arc<ToolRegistry>,
) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
let provider = create_provider(provider_config.clone())
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
@ -416,9 +415,16 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
pub async fn process(
&self,
mut messages: Vec<ChatMessage>,
) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)]
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
tracing::debug!(
history_len = messages.len(),
max_iterations = self.max_iterations,
"Starting agent process"
);
// Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
@ -441,7 +447,11 @@ impl AgentLoop {
if let Some(skill_tool) = self.skills.skill_tool_definition() {
tool_defs.push(skill_tool);
}
let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) };
let tools = if tool_defs.is_empty() {
None
} else {
Some(tool_defs)
};
let request = ChatCompletionRequest {
messages: messages_for_llm,
@ -461,7 +471,8 @@ impl AgentLoop {
error_details = %format_error_chain(e.as_ref()),
"LLM request failed"
);
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
let assistant_message =
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
@ -480,7 +491,8 @@ impl AgentLoop {
// If no tool calls, this is the final response
if response.tool_calls.is_empty() {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
@ -493,24 +505,35 @@ impl AgentLoop {
}
// Execute tool calls
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
tracing::info!(
iteration,
count = response.tool_calls.len(),
"Tool calls detected, executing tools"
);
// Add assistant message with tool calls
let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() {
ChatMessage::assistant_with_tool_calls_and_reasoning(
response.content.clone(),
response.tool_calls.clone(),
reasoning_content,
)
} else {
ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
)
};
let assistant_message =
if let Some(reasoning_content) = response.reasoning_content.clone() {
ChatMessage::assistant_with_tool_calls_and_reasoning(
response.content.clone(),
response.tool_calls.clone(),
reasoning_content,
)
} else {
ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
)
};
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
self.emit_live_tool_call_message(
emitted_messages
.last()
.expect("assistant message just pushed")
.clone(),
)
.await;
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
@ -519,7 +542,9 @@ impl AgentLoop {
// Log function call with name and arguments
let args_str = match &tool_call.arguments {
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
other => {
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
}
};
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
@ -595,7 +620,11 @@ impl AgentLoop {
// Loop continues to next iteration with updated messages
#[cfg(debug_assertions)]
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
tracing::debug!(
iteration,
message_count = messages.len(),
"Tool execution complete, continuing to next iteration"
);
}
// Max iterations reached - ask LLM for a summary based on completed work
@ -604,7 +633,7 @@ impl AgentLoop {
// Add a message asking for summary
let summary_request = ChatMessage::user(
"You have reached the maximum number of tool call iterations. \
Please provide your best answer based on the work completed so far."
Please provide your best answer based on the work completed so far.",
);
messages.push(summary_request);
@ -624,7 +653,8 @@ impl AgentLoop {
match (*self.provider).chat(request).await {
Ok(response) => {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
@ -745,10 +775,7 @@ impl AgentLoop {
}
// Apply duration
ToolExecutionOutcome {
duration,
..result
}
ToolExecutionOutcome { duration, ..result }
}
/// Internal tool execution without event tracking.
@ -790,10 +817,7 @@ impl AgentLoop {
"arguments": normalized_arguments,
}),
);
ToolExecutionOutcome::failure(
format!("Error: {}", err),
Some(err),
)
ToolExecutionOutcome::failure(format!("Error: {}", err), Some(err))
}
};
}
@ -809,7 +833,10 @@ impl AgentLoop {
}
};
match tool.execute_with_context(&self.tool_context, normalized_arguments.clone()).await {
match tool
.execute_with_context(&self.tool_context, normalized_arguments.clone())
.await
{
Ok(result) => {
if result.success {
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
@ -827,10 +854,7 @@ impl AgentLoop {
output = %result.output,
"Tool returned an error result"
);
ToolExecutionOutcome::failure(
format!("Error: {}", error),
Some(error),
)
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
}
}
Err(e) => {
@ -842,10 +866,7 @@ impl AgentLoop {
error_details = %format!("{:#}", e),
"Tool execution failed"
);
ToolExecutionOutcome::failure(
format!("Error: {}", e),
Some(e.to_string()),
)
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
}
}
}
@ -863,7 +884,9 @@ impl AgentLoop {
return;
};
if let Err(err) = store.append_skill_event(Some(session_id), event_type, skill_name, &payload) {
if let Err(err) =
store.append_skill_event(Some(session_id), event_type, skill_name, &payload)
{
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
}
}
@ -942,28 +965,37 @@ mod tests {
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].id,
"call_1"
);
assert_eq!(
provider_message.tool_calls.as_ref().unwrap()[0].name,
"calculator"
);
}
#[test]
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
let chat_message = ChatMessage::assistant_with_reasoning(
"final answer",
"hidden chain of thought",
);
let chat_message =
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
let provider_message = chat_message_to_llm_message(&chat_message);
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought"));
assert_eq!(
provider_message.reasoning_content.as_deref(),
Some("hidden chain of thought")
);
}
#[test]
fn test_memory_prompt_requires_proactive_memory_search() {
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search"));
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索"));
assert!(
MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索")
);
}
#[test]
@ -1001,9 +1033,13 @@ mod tests {
#[test]
fn test_normalize_tool_arguments_keeps_plain_string() {
let normalized = normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
let normalized =
normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
assert_eq!(normalized, serde_json::Value::String("plain text".to_string()));
assert_eq!(
normalized,
serde_json::Value::String("plain text".to_string())
);
}
#[test]
@ -1028,7 +1064,9 @@ mod tests {
assert_eq!(blocks.len(), 2);
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")));
assert!(
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
);
}
}

View File

@ -1,11 +1,9 @@
use crate::bus::{
ChatMessage,
SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_HISTORY_COMPACTION,
ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
};
use crate::config::LLMProviderConfig;
use crate::providers::{create_provider, ChatCompletionRequest, LLMProvider, Message};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::text::{char_count, take_prefix_chars};
use crate::agent::AgentError;
@ -62,6 +60,7 @@ pub struct ContextCompressor {
}
impl ContextCompressor {
#[cfg(test)]
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
const SUMMARY_RATIO: f64 = 0.1;
const CHARS_PER_TOKEN: f64 = 2.5;
@ -221,7 +220,9 @@ Be concise, aim for {} characters or less.
.await;
}
let per_chunk_target = (target_chars / layer.len().max(1)).max(500).min(target_chars);
let per_chunk_target = (target_chars / layer.len().max(1))
.max(500)
.min(target_chars);
let mut summaries = Vec::with_capacity(layer.len());
for chunk in &layer {
summaries.push(
@ -241,7 +242,9 @@ Be concise, aim for {} characters or less.
let merged = summaries.join("\n\n");
if char_count(&merged) <= target_chars {
return self.summarize_transcript(provider, &merged, target_chars).await;
return self
.summarize_transcript(provider, &merged, target_chars)
.await;
}
layer = Self::split_text_chunks(&merged, target_chars);
@ -314,7 +317,10 @@ Be concise, aim for {} characters or less.
|| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
}
fn split_prefix_messages(&self, history: &[ChatMessage]) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
fn split_prefix_messages(
&self,
history: &[ChatMessage],
) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
let preserved_system_messages = history
.iter()
.filter(|message| self.should_preserve_system_message(message))
@ -343,7 +349,8 @@ Be concise, aim for {} characters or less.
return Ok(None);
}
let preserved_turn_start = turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
let preserved_turn_start =
turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
if preserved_turn_start == 0 {
return Ok(None);
}
@ -357,10 +364,10 @@ Be concise, aim for {} characters or less.
Ok(Some(HistoryCompactionPlan {
preserved_system_messages,
summary_message: ChatMessage::system_with_context(format!(
"[Compressed History]\n\n{}",
summary
), Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string())),
summary_message: ChatMessage::system_with_context(
format!("[Compressed History]\n\n{}", summary),
Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string()),
),
preserved_messages: history[preserved_turn_start..].to_vec(),
compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns,
preserved_turns: self.config.retain_last_user_turns,
@ -392,7 +399,10 @@ Be concise, aim for {} characters or less.
"Starting context compression"
);
let current_history = match self.build_compaction_plan(&history, provider_config).await? {
let current_history = match self
.build_compaction_plan(&history, provider_config)
.await?
{
Some(plan) => {
let mut compressed = Vec::with_capacity(
plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1,
@ -429,8 +439,12 @@ Be concise, aim for {} characters or less.
let transcript = Self::build_transcript(messages);
let result = if char_count(&transcript) <= self.config.summary_max_chars {
self.summarize_transcript(provider.as_ref(), &transcript, self.config.summary_max_chars)
.await
self.summarize_transcript(
provider.as_ref(),
&transcript,
self.config.summary_max_chars,
)
.await
} else {
self.summarize_chunked_transcript(provider.as_ref(), messages, &transcript)
.await
@ -440,7 +454,10 @@ Be concise, aim for {} characters or less.
Ok(summary) => Ok(summary),
Err(e) => {
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
Ok(take_prefix_chars(&transcript, self.config.summary_max_chars))
Ok(take_prefix_chars(
&transcript,
self.config.summary_max_chars,
))
}
}
}
@ -463,7 +480,11 @@ mod tests {
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
// raw = 19, with 1.2x = ~23
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
assert!(
tokens > 18 && tokens < 30,
"Expected ~23 tokens, got {}",
tokens
);
}
#[test]
@ -487,21 +508,39 @@ mod tests {
];
let turns = compressor.user_turn_ranges(&history);
assert_eq!(turns, vec![
UserTurnRange { start: 1, end_exclusive: 4 },
UserTurnRange { start: 4, end_exclusive: 6 },
UserTurnRange { start: 6, end_exclusive: 7 },
]);
assert_eq!(
turns,
vec![
UserTurnRange {
start: 1,
end_exclusive: 4
},
UserTurnRange {
start: 4,
end_exclusive: 6
},
UserTurnRange {
start: 6,
end_exclusive: 7
},
]
);
}
#[test]
fn test_split_prefix_messages_preserves_key_system_messages() {
let compressor = ContextCompressor::new(50);
let prefix = vec![
ChatMessage::system_with_context("agent prompt", Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string())),
ChatMessage::system_with_context(
"agent prompt",
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
),
ChatMessage::user("u1"),
ChatMessage::assistant("a1"),
ChatMessage::system_with_context("scheduled prompt", Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string())),
ChatMessage::system_with_context(
"scheduled prompt",
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
),
];
let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix);
@ -519,10 +558,22 @@ mod tests {
#[test]
fn test_summary_char_budget_for_context_window_scales_and_clamps() {
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(4_096), 1_500);
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(65_536), 16_384);
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(128_000), 32_000);
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(400_000), 50_000);
assert_eq!(
ContextCompressor::summary_char_budget_for_context_window(4_096),
1_500
);
assert_eq!(
ContextCompressor::summary_char_budget_for_context_window(65_536),
16_384
);
assert_eq!(
ContextCompressor::summary_char_budget_for_context_window(128_000),
32_000
);
assert_eq!(
ContextCompressor::summary_char_budget_for_context_window(400_000),
50_000
);
}
#[test]

View File

@ -1,5 +1,5 @@
pub mod agent_loop;
pub mod context_compressor;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler};
pub use context_compressor::ContextCompressor;

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;
use crate::bus::{MessageBus, OutboundMessage};
use crate::channels::base::{Channel, ChannelError};
@ -22,7 +22,10 @@ impl OutboundDispatcher {
/// Register a channel with the dispatcher
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
self.channels.write().await.insert(name.to_string(), channel);
self.channels
.write()
.await
.insert(name.to_string(), channel);
}
/// Run the dispatcher loop - consumes from bus and dispatches to channels

View File

@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::providers::ToolCall;
@ -34,7 +34,9 @@ pub struct ImageUrlBlock {
impl ContentBlock {
pub fn text(content: impl Into<String>) -> Self {
Self::Text { text: content.into() }
Self::Text {
text: content.into(),
}
}
pub fn image_url(url: impl Into<String>) -> Self {
@ -50,10 +52,10 @@ impl ContentBlock {
#[derive(Debug, Clone)]
pub struct MediaItem {
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub path: String, // Local file path
pub media_type: String, // "image", "audio", "file", "video"
pub mime_type: Option<String>,
pub original_key: Option<String>, // Feishu file_key for download
pub original_key: Option<String>, // Feishu file_key for download
}
impl MediaItem {
@ -76,7 +78,7 @@ pub struct ChatMessage {
pub id: String,
pub role: String,
pub content: String,
pub media_refs: Vec<String>, // Paths to media files for context
pub media_refs: Vec<String>, // Paths to media files for context
pub timestamp: i64,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_context: Option<String>,
@ -150,7 +152,10 @@ impl ChatMessage {
message
}
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
pub fn assistant_with_tool_calls(
content: impl Into<String>,
tool_calls: Vec<ToolCall>,
) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
@ -199,8 +204,17 @@ impl ChatMessage {
}
}
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed)
pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self::tool_with_state(
tool_call_id,
tool_name,
content,
ToolMessageState::Completed,
)
}
pub fn tool_with_state(
@ -287,6 +301,8 @@ pub enum OutboundEventKind {
ToolCall,
ToolResult,
ToolPending,
SchedulerNotification,
ErrorNotification,
}
impl OutboundMessage {
@ -316,6 +332,30 @@ impl OutboundMessage {
}
}
pub fn scheduler_notification(
channel: impl Into<String>,
chat_id: impl Into<String>,
content: impl Into<String>,
reply_to: Option<String>,
metadata: HashMap<String, String>,
) -> Self {
let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata);
message.event_kind = OutboundEventKind::SchedulerNotification;
message
}
pub fn error_notification(
channel: impl Into<String>,
chat_id: impl Into<String>,
content: impl Into<String>,
reply_to: Option<String>,
metadata: HashMap<String, String>,
) -> Self {
let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata);
message.event_kind = OutboundEventKind::ErrorNotification;
message
}
pub fn tool_call(
channel: impl Into<String>,
chat_id: impl Into<String>,
@ -417,20 +457,17 @@ impl OutboundMessage {
));
}
outbound.extend(tool_calls
.iter()
.map(|tool_call| {
Self::tool_call(
channel.to_string(),
chat_id.to_string(),
tool_call.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
reply_to.clone(),
metadata.clone(),
)
})
);
outbound.extend(tool_calls.iter().map(|tool_call| {
Self::tool_call(
channel.to_string(),
chat_id.to_string(),
tool_call.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
reply_to.clone(),
metadata.clone(),
)
}));
outbound
} else {
vec![Self::assistant(
@ -442,7 +479,11 @@ impl OutboundMessage {
)]
}
}
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
"tool" => match message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed)
{
ToolMessageState::Completed => vec![Self::tool_result(
channel.to_string(),
chat_id.to_string(),
@ -467,7 +508,10 @@ impl OutboundMessage {
}
}
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
pub(crate) fn format_tool_call_content(
tool_name: &str,
tool_arguments: &serde_json::Value,
) -> String {
match tool_arguments {
serde_json::Value::Object(map) if map.is_empty() => tool_name.to_string(),
other => format!("{}\nargs: {}", tool_name, format_tool_arguments_json(other)),
@ -544,21 +588,25 @@ mod tests {
],
);
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 2);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
assert_eq!(outbound[0].content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
assert_eq!(
outbound[0].tool_arguments.as_ref().unwrap()["expression"],
"1 + 1"
);
assert_eq!(
outbound[0].content,
"calculator\nargs: {\"expression\":\"1 + 1\"}"
);
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
assert_eq!(outbound[1].content, "file_read\nargs: {\"path\":\"README.md\"}");
assert_eq!(
outbound[1].content,
"file_read\nargs: {\"path\":\"README.md\"}"
);
}
#[test]
@ -572,13 +620,8 @@ mod tests {
}],
);
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 2);
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
@ -591,13 +634,8 @@ mod tests {
fn test_from_chat_message_includes_tool_result() {
let message = ChatMessage::tool("call-9", "calculator", "2");
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
@ -612,13 +650,8 @@ mod tests {
ToolMessageState::PendingUserAction,
);
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
let outbound =
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);

View File

@ -3,18 +3,13 @@ pub mod message;
pub use dispatcher::OutboundDispatcher;
pub use message::{
ChatMessage,
ContentBlock,
InboundMessage,
MediaItem,
OutboundMessage,
SYSTEM_CONTEXT_AGENT_PROMPT,
SYSTEM_CONTEXT_HISTORY_COMPACTION,
ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage,
SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{Mutex, mpsc};
// ============================================================================
// MessageBus - Async message queue for Channel <-> Agent communication
@ -52,7 +47,8 @@ impl MessageBus {
/// Consume an inbound message (Agent -> Bus)
pub async fn consume_inbound(&self) -> InboundMessage {
let msg = self.inbound_rx
let msg = self
.inbound_rx
.lock()
.await
.recv()

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::bus::{MessageBus, OutboundMessage};
use crate::bus::MessageBus;
use crate::channels::base::{Channel, ChannelError};
use crate::channels::feishu::FeishuChannel;
use crate::config::Config;
@ -28,12 +28,18 @@ impl ChannelManager {
}
/// Initialize all Channel instances from config
pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
pub async fn init(
&self,
config: &Config,
_provider_config: crate::config::LLMProviderConfig,
) -> Result<(), ChannelError> {
// Initialize Feishu channel if enabled
if let Some(feishu_config) = config.channels.get("feishu") {
if feishu_config.enabled {
let channel = FeishuChannel::new(feishu_config.clone(), _provider_config)
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
let channel =
FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| {
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
})?;
self.channels
.write()
@ -75,13 +81,12 @@ impl ChannelManager {
self.channels.read().await.get(name).cloned()
}
/// Dispatch an outbound message to the appropriate channel
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let channel_name = &msg.channel;
if let Some(channel) = self.get_channel(channel_name).await {
channel.send(msg).await
} else {
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
}
pub async fn channels(&self) -> Vec<(String, Arc<dyn Channel + Send + Sync>)> {
self.channels
.read()
.await
.iter()
.map(|(name, channel)| (name.clone(), channel.clone()))
.collect()
}
}

View File

@ -3,5 +3,5 @@ pub mod feishu;
pub mod manager;
pub use base::{Channel, ChannelError};
pub use manager::ChannelManager;
pub use feishu::FeishuChannel;
pub use manager::ChannelManager;

View File

@ -1,4 +1,4 @@
use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
pub struct CliChannel {
read: BufReader<tokio::io::Stdin>,

View File

@ -49,18 +49,27 @@ impl InputHandler {
}
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
self.channel.write_line(content).await.map_err(InputError::IoError)
self.channel
.write_line(content)
.await
.map_err(InputError::IoError)
}
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
self.channel.write_response(content).await.map_err(InputError::IoError)
self.channel
.write_response(content)
.await
.map_err(InputError::IoError)
}
fn handle_special_commands(&self, line: &str) -> Option<InputCommand> {
let trimmed = line.trim();
let mut parts = trimmed.splitn(2, char::is_whitespace);
let command = parts.next()?;
let arg = parts.next().map(str::trim).filter(|value| !value.is_empty());
let arg = parts
.next()
.map(str::trim)
.filter(|value| !value.is_empty());
match command {
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
@ -105,14 +114,26 @@ mod tests {
fn test_special_command_parsing() {
let handler = InputHandler::new();
assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit));
assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear));
assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None)));
assert_eq!(
handler.handle_special_commands("/quit"),
Some(InputCommand::Exit)
);
assert_eq!(
handler.handle_special_commands("/clear"),
Some(InputCommand::Clear)
);
assert_eq!(
handler.handle_special_commands("/new"),
Some(InputCommand::New(None))
);
assert_eq!(
handler.handle_special_commands("/new planning"),
Some(InputCommand::New(Some("planning".to_string())))
);
assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions));
assert_eq!(
handler.handle_special_commands("/sessions"),
Some(InputCommand::Sessions)
);
assert_eq!(
handler.handle_special_commands("/use abc123"),
Some(InputCommand::Use("abc123".to_string()))
@ -121,8 +142,14 @@ mod tests {
handler.handle_special_commands("/rename project alpha"),
Some(InputCommand::Rename("project alpha".to_string()))
);
assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive));
assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete));
assert_eq!(
handler.handle_special_commands("/archive"),
Some(InputCommand::Archive)
);
assert_eq!(
handler.handle_special_commands("/delete"),
Some(InputCommand::Delete)
);
assert_eq!(handler.handle_special_commands("/unknown"), None);
assert_eq!(handler.handle_special_commands("/use"), None);
}

View File

@ -5,7 +5,10 @@ use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::cli::{InputCommand, InputEvent, InputHandler};
fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String {
fn format_session_list(
sessions: &[crate::protocol::SessionSummary],
current_session_id: Option<&str>,
) -> String {
if sessions.is_empty() {
return "No sessions found.".to_string();
}
@ -25,11 +28,7 @@ fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_ses
};
lines.push(format!(
"{} {} | {} | {} messages{}",
marker,
session.session_id,
session.title,
session.message_count,
archived,
marker, session.session_id, session.title, session.message_count, archived,
));
}

View File

@ -123,7 +123,9 @@ fn default_allow_from() -> Vec<String> {
fn default_media_dir() -> String {
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
home.join(".picobot/media/feishu").to_string_lossy().to_string()
home.join(".picobot/media/feishu")
.to_string_lossy()
.to_string()
}
fn default_reaction_emoji() -> String {
@ -157,6 +159,8 @@ pub struct ModelConfig {
pub temperature: Option<f32>,
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub context_window_tokens: Option<u32>,
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
@ -199,7 +203,10 @@ pub struct GatewayConfig {
pub show_tool_results: bool,
#[serde(default, rename = "session_ttl_hours")]
pub session_ttl_hours: Option<u64>,
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
#[serde(
default = "default_agent_prompt_reinject_every",
rename = "agent_prompt_reinject_every"
)]
pub agent_prompt_reinject_every: u64,
}
@ -388,7 +395,10 @@ impl SchedulerSchedule {
}
pub fn is_one_shot(&self) -> bool {
matches!(self, SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. })
matches!(
self,
SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. }
)
}
pub fn normalized_for_storage(&self) -> Self {
@ -518,6 +528,7 @@ pub struct LLMProviderConfig {
pub model_id: String,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub context_window_tokens: Option<u32>,
pub model_extra: HashMap<String, serde_json::Value>,
pub max_tool_iterations: usize,
pub tool_result_max_chars: usize,
@ -526,7 +537,7 @@ pub struct LLMProviderConfig {
impl LLMProviderConfig {
pub fn context_window_tokens(&self) -> usize {
self.max_tokens
self.context_window_tokens
.map(|value| value as usize)
.unwrap_or(128_000)
}
@ -581,13 +592,19 @@ impl Config {
}
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
let agent = self.agents.get(agent_name)
let agent = self
.agents
.get(agent_name)
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
let provider = self.providers.get(&agent.provider)
let provider = self
.providers
.get(&agent.provider)
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
let model = self.models.get(&agent.model)
let model = self
.models
.get(&agent.model)
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
Ok(LLMProviderConfig {
@ -600,6 +617,7 @@ impl Config {
model_id: model.model_id.clone(),
temperature: model.temperature,
max_tokens: model.max_tokens,
context_window_tokens: model.context_window_tokens,
model_extra: model.extra.clone(),
max_tool_iterations: agent.max_tool_iterations,
tool_result_max_chars: agent.tool_result_max_chars,
@ -621,11 +639,17 @@ pub enum ConfigError {
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
ConfigError::ConfigNotFound(path) => write!(
f,
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
path
),
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
ConfigError::InvalidSchedulerJob(message) => write!(f, "Invalid scheduler job: {}", message),
ConfigError::InvalidSchedulerJob(message) => {
write!(f, "Invalid scheduler job: {}", message)
}
ConfigError::InvalidTimezone(message) => write!(f, "Invalid timezone: {}", message),
}
}
@ -661,18 +685,19 @@ fn resolve_env_placeholders(content: &str) -> String {
re.replace_all(content, |caps: &regex::Captures| {
let var_name = &caps[1];
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
}).to_string()
})
.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
fn write_test_config() -> tempfile::NamedTempFile {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
fn write_test_config() -> tempfile::NamedTempFile {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
@ -708,15 +733,15 @@ mod tests {
"agent_prompt_reinject_every": 120
}
}"#,
)
.unwrap();
file
}
)
.unwrap();
file
}
#[test]
fn test_config_load() {
let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let file = write_test_config();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
// Check providers
assert!(config.providers.contains_key("volcengine"));
@ -876,7 +901,10 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.time.timezone, "Asia/Shanghai");
assert_eq!(config.time.parse_timezone().unwrap(), chrono_tz::Asia::Shanghai);
assert_eq!(
config.time.parse_timezone().unwrap(),
chrono_tz::Asia::Shanghai
);
}
#[test]
@ -983,7 +1011,10 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap();
assert_eq!(config.agents["default"].max_tool_iterations, 100);
assert_eq!(config.agents["default"].tool_result_max_chars, 20_000);
assert_eq!(config.agents["default"].context_tool_result_trim_chars, 2_000);
assert_eq!(
config.agents["default"].context_tool_result_trim_chars,
2_000
);
}
#[test]
@ -1029,7 +1060,44 @@ mod tests {
}
#[test]
fn test_provider_config_summary_budget_scales_with_model_max_tokens() {
fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {}
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus",
"context_window_tokens": 4096
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.context_window_tokens(), 4096);
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
}
#[test]
fn test_provider_config_max_tokens_does_not_change_context_window() {
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
@ -1061,8 +1129,9 @@ mod tests {
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
assert_eq!(provider_config.context_window_tokens(), 4096);
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
assert_eq!(provider_config.max_tokens, Some(4096));
assert_eq!(provider_config.context_window_tokens(), 128_000);
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
}
#[test]
@ -1159,7 +1228,10 @@ mod tests {
assert!(config.scheduler.enabled);
assert_eq!(config.scheduler.tick_resolution_ms, 1_000);
assert_eq!(config.scheduler.worker_queue_capacity, 64);
assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::Skip);
assert_eq!(
config.scheduler.misfire_policy,
SchedulerMisfirePolicy::Skip
);
assert!(config.scheduler.jobs.is_empty());
let effective_jobs = config.scheduler.effective_jobs(&config.time);
@ -1273,7 +1345,10 @@ mod tests {
assert!(config.scheduler.enabled);
assert_eq!(config.scheduler.tick_resolution_ms, 500);
assert_eq!(config.scheduler.worker_queue_capacity, 8);
assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::CatchUp);
assert_eq!(
config.scheduler.misfire_policy,
SchedulerMisfirePolicy::CatchUp
);
assert_eq!(config.scheduler.jobs.len(), 1);
let job = &config.scheduler.jobs[0];
@ -1284,11 +1359,17 @@ mod tests {
assert_eq!(job.startup_delay_secs, 5);
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
assert_eq!(job.payload.get("content").and_then(|value| value.as_str()), Some("heartbeat"));
assert_eq!(job.resolved_schedule().unwrap(), SchedulerSchedule::Interval {
seconds: 60,
startup_delay_secs: 5,
});
assert_eq!(
job.payload.get("content").and_then(|value| value.as_str()),
Some("heartbeat")
);
assert_eq!(
job.resolved_schedule().unwrap(),
SchedulerSchedule::Interval {
seconds: 60,
startup_delay_secs: 5,
}
);
}
#[test]
@ -1362,21 +1443,30 @@ mod tests {
config.scheduler.jobs[0].resolved_schedule().unwrap(),
SchedulerSchedule::Delay { seconds: 30 }
);
assert_eq!(config.scheduler.jobs[0].kind, SchedulerJobKind::InternalEvent);
assert_eq!(
config.scheduler.jobs[0].kind,
SchedulerJobKind::InternalEvent
);
assert_eq!(
config.scheduler.jobs[1].resolved_schedule().unwrap(),
SchedulerSchedule::At {
timestamp: "2026-04-23T09:00:00+00:00".to_string(),
}
);
assert_eq!(config.scheduler.jobs[1].kind, SchedulerJobKind::OutboundMessage);
assert_eq!(
config.scheduler.jobs[1].kind,
SchedulerJobKind::OutboundMessage
);
assert_eq!(
config.scheduler.jobs[2].resolved_schedule().unwrap(),
SchedulerSchedule::Cron {
expression: "0 9 * * *".to_string(),
}
);
assert_eq!(config.scheduler.jobs[2].kind, SchedulerJobKind::InternalEvent);
assert_eq!(
config.scheduler.jobs[2].kind,
SchedulerJobKind::InternalEvent
);
}
#[test]
@ -1433,7 +1523,10 @@ mod tests {
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办"));
assert_eq!(
job.payload.get("prompt").and_then(|value| value.as_str()),
Some("请总结今天待办")
);
}
#[test]
@ -1495,29 +1588,40 @@ mod tests {
job.target.session_chat_id.as_deref(),
Some("scheduler/agent.daily_summary.background")
);
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请后台总结今天待办"));
assert_eq!(
job.payload.get("prompt").and_then(|value| value.as_str()),
Some("请后台总结今天待办")
);
}
#[test]
fn test_scheduler_schedule_validation_rejects_invalid_values() {
assert!(SchedulerSchedule::Delay { seconds: 0 }
.validate("delay.job")
.is_err());
assert!(SchedulerSchedule::Interval {
seconds: 0,
startup_delay_secs: 0,
}
.validate("interval.job")
.is_err());
assert!(SchedulerSchedule::At {
timestamp: "bad timestamp".to_string(),
}
.validate("at.job")
.is_err());
assert!(SchedulerSchedule::Cron {
expression: "bad cron".to_string(),
}
.validate("cron.job")
.is_err());
assert!(
SchedulerSchedule::Delay { seconds: 0 }
.validate("delay.job")
.is_err()
);
assert!(
SchedulerSchedule::Interval {
seconds: 0,
startup_delay_secs: 0,
}
.validate("interval.job")
.is_err()
);
assert!(
SchedulerSchedule::At {
timestamp: "bad timestamp".to_string(),
}
.validate("at.job")
.is_err()
);
assert!(
SchedulerSchedule::Cron {
expression: "bad cron".to_string(),
}
.validate("cron.job")
.is_err()
);
}
}

252
src/gateway/execution.rs Normal file
View File

@ -0,0 +1,252 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::{AgentError, AgentProcessResult};
use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, OutboundMessage};
use crate::config::LLMProviderConfig;
use tokio::sync::Mutex;
use super::session::{Session, schedule_background_history_compaction};
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
match system_prompt
.map(str::trim)
.filter(|value| !value.is_empty())
{
Some(system_prompt) => format!(
"{}\n\n任务专属要求:{}",
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
),
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
}
}
pub(crate) fn select_provider_config(
default_provider_config: &LLMProviderConfig,
provider_configs: &HashMap<String, LLMProviderConfig>,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(default_provider_config.clone()),
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) struct AgentExecutionService {
show_tool_results: bool,
}
pub(crate) struct FinalizeAgentResultRequest<'a> {
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) user_message: &'a ChatMessage,
pub(crate) result: AgentProcessResult,
pub(crate) metadata: &'a HashMap<String, String>,
pub(crate) suppress_live_tool_calls: bool,
pub(crate) execution_kind: &'a str,
}
pub(crate) struct FinalizedAgentResult {
pub(crate) outbound_messages: Vec<OutboundMessage>,
pub(crate) should_schedule_compaction: bool,
}
impl AgentExecutionService {
pub(crate) fn new(show_tool_results: bool) -> Self {
Self { show_tool_results }
}
pub(crate) fn finalize_result(
&self,
session: &mut Session,
request: FinalizeAgentResultRequest<'_>,
) -> Result<FinalizedAgentResult, AgentError> {
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
session.stale_result_diagnostics(request.chat_id);
tracing::warn!(
channel = %request.channel_name,
chat_id = %request.chat_id,
user_message_id = %request.user_message.id,
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
execution_kind = %request.execution_kind,
"Skipping stale agent result because a newer user message is already present"
);
return Ok(FinalizedAgentResult {
outbound_messages: Vec::new(),
should_schedule_compaction: false,
});
}
session
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
let outbound_messages = request
.result
.emitted_messages
.iter()
.filter(|message| {
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
&& should_display_message_to_user(self.show_tool_results, message)
})
.flat_map(|message| {
OutboundMessage::from_chat_message(
request.channel_name,
request.chat_id,
None,
request.metadata,
message,
)
})
.collect();
Ok(FinalizedAgentResult {
outbound_messages,
should_schedule_compaction: true,
})
}
pub(crate) async fn finalize_result_and_schedule_compaction(
&self,
session: Arc<Mutex<Session>>,
request: FinalizeAgentResultRequest<'_>,
) -> Result<Vec<OutboundMessage>, AgentError> {
let channel_name = request.channel_name.to_string();
let chat_id = request.chat_id.to_string();
let execution_kind = request.execution_kind.to_string();
let finalized_result = {
let mut session_guard = session.lock().await;
self.finalize_result(&mut session_guard, request)?
};
if finalized_result.should_schedule_compaction {
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.clone()).await
{
tracing::warn!(
channel = %channel_name,
chat_id = %chat_id,
execution_kind = %execution_kind,
error = %error,
"Failed to schedule background history compaction"
);
}
}
Ok(finalized_result.outbound_messages)
}
}
pub(crate) fn should_display_message_to_user(
show_tool_results: bool,
message: &ChatMessage,
) -> bool {
if message.role != "tool" {
return true;
}
show_tool_results
|| matches!(
message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed),
ToolMessageState::PendingUserAction
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::ChatMessage;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_provider_config_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]);
let selected =
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_provider_config_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::new();
let selected =
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
#[test]
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
assert!(prompt.contains("任务专属要求:只汇报异常"));
}
#[test]
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
let prompt = compose_scheduled_task_system_prompt(Some(" "));
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
assert!(!prompt.contains("任务专属要求"));
}
#[test]
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
let message = ChatMessage::tool_with_state(
"call-1",
"approval",
"需要用户确认",
ToolMessageState::PendingUserAction,
);
assert!(should_display_message_to_user(false, &message));
}
#[test]
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
let message = ChatMessage::tool("call-1", "calculator", "2");
assert!(!should_display_message_to_user(false, &message));
assert!(should_display_message_to_user(true, &message));
}
}

View File

@ -0,0 +1,517 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::agent::AgentError;
use crate::config::LLMProviderConfig;
use crate::providers::{ChatCompletionRequest, Message, create_provider};
use crate::storage::{MemoryRecord, SessionStore};
use super::prompt::upsert_managed_agent_memory_summary;
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryMaintenanceCategory {
UserFacts,
Preferences,
BehaviorPatterns,
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceCandidate {
pub(crate) id: String,
pub(crate) namespace: String,
pub(crate) key: String,
pub(crate) content: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenancePlan {
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceMerge {
pub(crate) source_ids: Vec<String>,
pub(crate) namespace: String,
pub(crate) memory_key: String,
pub(crate) content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceConflict {
pub(crate) source_ids: Vec<String>,
pub(crate) note: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceModelOutput {
pub(crate) user_facts: Vec<String>,
pub(crate) preferences: Vec<String>,
pub(crate) behavior_patterns: Vec<String>,
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
pub(crate) low_value_ids: Vec<String>,
pub(crate) managed_markdown: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MemoryMaintenanceScopeResult {
pub(crate) scope_key: String,
pub(crate) output: MemoryMaintenanceModelOutput,
}
pub(crate) struct MemoryMaintenanceService {
store: Arc<SessionStore>,
provider_config: LLMProviderConfig,
}
impl MemoryMaintenanceService {
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
Self {
store,
provider_config,
}
}
pub(crate) fn build_plan_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
let memories = self
.store
.list_memories_for_scope("user", scope_key)
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
if memories.is_empty() {
return Ok(None);
}
Ok(Some(build_memory_maintenance_plan(&memories)))
}
pub(crate) async fn summarize_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
self.summarize_plan(scope_key, &plan).await.map(Some)
}
async fn summarize_plan(
&self,
scope_key: &str,
plan: &MemoryMaintenancePlan,
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
let provider = create_provider(self.provider_config.clone()).map_err(|err| {
AgentError::Other(format!("create maintenance provider error: {}", err))
})?;
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
Message::user(
serde_json::to_string_pretty(&serde_json::json!({
"scope_key": scope_key,
"candidates": plan,
}))
.unwrap_or_else(|_| "{}".to_string()),
),
],
temperature: Some(0.0),
max_tokens: Some(1200),
tools: None,
};
let mut last_error = None;
let mut response = None;
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
.iter()
.copied()
.map(Some)
.chain(std::iter::once(None))
.enumerate()
{
match provider.chat(request.clone()).await {
Ok(success) => {
response = Some(success);
break;
}
Err(err) => {
let error_text = err.to_string();
let should_retry =
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
last_error = Some(error_text.clone());
if should_retry {
tracing::warn!(
scope_key = %scope_key,
attempt = attempt + 1,
retry_in_ms = delay_ms.unwrap_or_default(),
error = %error_text,
"Memory maintenance model request failed, retrying"
);
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
.await;
continue;
}
return Err(AgentError::Other(format!(
"memory maintenance model error: {}",
error_text
)));
}
}
}
let response = response.ok_or_else(|| {
AgentError::Other(format!(
"memory maintenance model error: {}",
last_error.unwrap_or_else(|| "unknown provider error".to_string())
))
})?;
let raw_content = strip_json_code_fence(&response.content);
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
let output: MemoryMaintenanceModelOutput =
serde_json::from_str(json_candidate).map_err(|err| {
tracing::error!(
scope_key = %scope_key,
error = %err,
raw_len = raw_content.len(),
raw_preview = %preview_text(raw_content, 400),
json_candidate_len = json_candidate.len(),
json_candidate_preview = %preview_text(json_candidate, 400),
"Memory maintenance JSON decode failed"
);
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
})?;
Ok(output)
}
pub(crate) async fn run_for_scope(
&self,
scope_key: &str,
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
return Ok(None);
};
let output = self.summarize_plan(scope_key, &plan).await?;
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
Ok(Some(output))
}
pub(crate) async fn run_for_all_scopes(
&self,
updated_since: Option<i64>,
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
let scope_keys = if let Some(cutoff) = updated_since {
self.store
.list_memory_scope_keys_updated_since("user", cutoff)
.map_err(|err| {
AgentError::Other(format!(
"list memory scope keys updated since error: {}",
err
))
})?
} else {
self.store.list_memory_scope_keys("user").map_err(|err| {
AgentError::Other(format!("list memory scope keys error: {}", err))
})?
};
let mut results = Vec::new();
for scope_key in scope_keys {
let Some(output) = self.run_for_scope(&scope_key).await? else {
continue;
};
results.push(MemoryMaintenanceScopeResult { scope_key, output });
}
let combined_markdown = combine_managed_memory_markdown(
&results
.iter()
.map(|result| result.output.managed_markdown.clone())
.collect::<Vec<_>>(),
);
if !combined_markdown.is_empty() {
upsert_managed_agent_memory_summary(&combined_markdown)?;
}
Ok(results)
}
}
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
let mut plan = MemoryMaintenancePlan::default();
let mut seen = HashSet::new();
for memory in memories {
let normalized_content = memory.content.trim();
if normalized_content.is_empty() {
continue;
}
let dedupe_key = format!(
"{}\u{1f}{}\u{1f}{}",
memory.namespace.trim().to_ascii_lowercase(),
memory.memory_key.trim().to_ascii_lowercase(),
normalized_content
);
if !seen.insert(dedupe_key) {
continue;
}
let candidate = MemoryMaintenanceCandidate {
id: memory.id.clone(),
namespace: memory.namespace.clone(),
key: memory.memory_key.clone(),
content: normalized_content.to_string(),
};
match memory_maintenance_category(&memory.namespace) {
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
}
}
plan
}
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
match namespace.trim().to_ascii_lowercase().as_str() {
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
"patterns" | "behavior" | "habits" | "workflow" => {
MemoryMaintenanceCategory::BehaviorPatterns
}
_ => MemoryMaintenanceCategory::Other,
}
}
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
let normalized = error.to_ascii_lowercase();
normalized.contains("error sending request for url")
|| normalized.contains("504")
|| normalized.contains("gateway timeout")
|| normalized.contains("stream timeout")
|| normalized.contains("timed out")
|| normalized.contains("timeout")
}
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
let trimmed = content.trim();
if let Some(rest) = trimmed.strip_prefix("```json") {
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
}
if let Some(rest) = trimmed.strip_prefix("```") {
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
}
trimmed
}
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
let mut start = None;
let mut depth = 0usize;
let mut in_string = false;
let mut escaped = false;
for (index, ch) in content.char_indices() {
if in_string {
if escaped {
escaped = false;
continue;
}
match ch {
'\\' => escaped = true,
'"' => in_string = false,
_ => {}
}
continue;
}
match ch {
'"' => in_string = true,
'{' => {
if start.is_none() {
start = Some(index);
}
depth += 1;
}
'}' => {
if depth == 0 {
continue;
}
depth -= 1;
if depth == 0 {
let start = start?;
let end = index + ch.len_utf8();
return Some(content[start..end].trim());
}
}
_ => {}
}
}
None
}
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
let normalized_chunks = chunks
.iter()
.map(|chunk| chunk.trim())
.filter(|chunk| !chunk.is_empty())
.collect::<Vec<_>>();
let mut combined = Vec::new();
for (index, chunk) in normalized_chunks.iter().enumerate() {
let chunk_lines = chunk
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.collect::<HashSet<_>>();
let is_subset_of_other =
normalized_chunks
.iter()
.enumerate()
.any(|(other_index, other)| {
if index == other_index {
return false;
}
let other_lines = other
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.collect::<HashSet<_>>();
chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines)
});
if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) {
combined.push((*chunk).to_string());
}
}
combined.join("\n\n")
}
pub(crate) fn apply_memory_maintenance_output(
store: &SessionStore,
scope_key: &str,
plan: &MemoryMaintenancePlan,
output: &MemoryMaintenanceModelOutput,
) -> Result<(), AgentError> {
let all_candidates = plan
.user_facts
.iter()
.chain(plan.preferences.iter())
.chain(plan.behavior_patterns.iter())
.chain(plan.others.iter())
.cloned()
.collect::<Vec<_>>();
let candidates_by_id = all_candidates
.iter()
.map(|candidate| (candidate.id.as_str(), candidate))
.collect::<HashMap<_, _>>();
let mut deleted_ids = HashSet::new();
for merge in &output.merges {
if merge.source_ids.is_empty() {
continue;
}
let source_candidates = merge
.source_ids
.iter()
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
.collect::<Vec<_>>();
if source_candidates.is_empty() {
continue;
}
let existing_target_id = source_candidates
.iter()
.find(|candidate| {
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
})
.map(|candidate| candidate.id.clone());
store
.put_memory(&crate::storage::MemoryUpsert {
scope_kind: "user".to_string(),
scope_key: scope_key.to_string(),
namespace: merge.namespace.trim().to_string(),
memory_key: merge.memory_key.trim().to_string(),
content: merge.content.trim().to_string(),
source_type: "memory_maintenance".to_string(),
source_session_id: None,
source_message_id: None,
source_message_seq: None,
source_channel_name: None,
source_chat_id: None,
})
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
for candidate in source_candidates {
if existing_target_id
.as_ref()
.is_some_and(|target_id| target_id == &candidate.id)
{
continue;
}
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete merged source memory error: {}", err))
})?;
}
}
}
for memory_id in &output.low_value_ids {
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
if deleted_ids.insert(candidate.id.clone()) {
store
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
.map_err(|err| {
AgentError::Other(format!("delete low value memory error: {}", err))
})?;
}
}
}
Ok(())
}
fn preview_text(content: &str, max_chars: usize) -> String {
let mut preview = content.chars().take(max_chars).collect::<String>();
if content.chars().count() > max_chars {
preview.push_str("...");
}
preview.replace('\n', "\\n")
}

View File

@ -1,10 +1,14 @@
pub mod execution;
pub mod http;
pub mod memory_maintenance;
pub mod processor;
pub mod prompt;
pub mod session;
pub mod ws;
use axum::{Router, routing};
use std::collections::HashMap;
use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::bus::{MessageBus, OutboundDispatcher};
@ -14,7 +18,8 @@ use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
use session::{BusToolCallEmitter, SessionManager};
use processor::InboundProcessor;
use session::SessionManager;
pub struct GatewayState {
pub config: Config,
@ -61,74 +66,17 @@ impl GatewayState {
/// Start the message processing loops
pub async fn start_message_processing(&self) {
let bus_for_inbound = self.bus.clone();
let bus_for_outbound = self.bus.clone();
let session_manager = self.session_manager.clone();
// Spawn inbound message processor
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
tokio::spawn(async move {
tracing::info!("Inbound processor started");
loop {
let inbound = bus_for_inbound.consume_inbound().await;
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
sender = %inbound.sender_id,
content = %inbound.content,
media_count = %inbound.media.len(),
"Processing inbound message"
);
if !inbound.media.is_empty() {
for (i, m) in inbound.media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
}
}
}
// Process via session manager
let live_emitter = Arc::new(BusToolCallEmitter::new(
bus_for_inbound.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
session_manager.show_tool_results(),
));
match session_manager.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
).await {
Ok(outbound_messages) => {
// Forward channel-specific metadata from inbound to outbound.
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
// without gateway needing channel-specific code.
for mut outbound in outbound_messages {
outbound.metadata.extend(inbound.forwarded_metadata.clone());
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
tracing::error!(error = %e, "Failed to publish outbound");
}
}
}
Err(e) => {
tracing::error!(error = %e, "Failed to handle message");
}
}
}
});
let inbound_processor =
InboundProcessor::new(self.bus.clone(), self.session_manager.clone());
tokio::spawn(inbound_processor.run());
// Spawn outbound dispatcher
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
let channel_manager = self.channel_manager.clone();
// Register channels with dispatcher
if let Some(channel) = channel_manager.get_channel("feishu").await {
dispatcher.register_channel("feishu", channel).await;
for (name, channel) in channel_manager.channels().await {
dispatcher.register_channel(&name, channel).await;
}
tokio::spawn(async move {
@ -138,7 +86,10 @@ impl GatewayState {
}
}
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
pub async fn run(
host: Option<String>,
port: Option<u16>,
) -> Result<(), Box<dyn std::error::Error>> {
let config = Config::load_default()?;
let timezone = config.time.parse_timezone()?;
@ -152,7 +103,10 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
let provider_config = state.config.get_provider_config("default")?;
// Initialize and start channels
state.channel_manager.init(&state.config, provider_config.clone()).await?;
state
.channel_manager
.init(&state.config, provider_config.clone())
.await?;
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + outbound dispatcher)

77
src/gateway/processor.rs Normal file
View File

@ -0,0 +1,77 @@
use std::sync::Arc;
use crate::bus::MessageBus;
use super::session::{BusToolCallEmitter, SessionManager};
#[derive(Clone)]
pub struct InboundProcessor {
bus: Arc<MessageBus>,
session_manager: SessionManager,
}
impl InboundProcessor {
pub fn new(bus: Arc<MessageBus>, session_manager: SessionManager) -> Self {
Self {
bus,
session_manager,
}
}
pub async fn run(self) {
tracing::info!("Inbound processor started");
loop {
let inbound = self.bus.consume_inbound().await;
#[cfg(debug_assertions)]
{
tracing::debug!(
channel = %inbound.channel,
chat_id = %inbound.chat_id,
sender = %inbound.sender_id,
content = %inbound.content,
media_count = %inbound.media.len(),
"Processing inbound message"
);
if !inbound.media.is_empty() {
for (i, media) in inbound.media.iter().enumerate() {
tracing::debug!(media_index = i, media_type = %media.media_type, path = %media.path, "Media item");
}
}
}
let live_emitter = Arc::new(BusToolCallEmitter::new(
self.bus.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
self.session_manager.show_tool_results(),
));
match self
.session_manager
.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
)
.await
{
Ok(outbound_messages) => {
for mut outbound in outbound_messages {
outbound.metadata.extend(inbound.forwarded_metadata.clone());
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, "Failed to publish outbound");
}
}
}
Err(error) => {
tracing::error!(error = %error, "Failed to handle message");
}
}
}
}
}

149
src/gateway/prompt.rs Normal file
View File

@ -0,0 +1,149 @@
use std::fs;
use std::path::{Path, PathBuf};
use crate::agent::AgentError;
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
if !path.exists() {
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
}
let content = fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
let trimmed = content.trim();
if trimmed.is_empty() {
return Ok(None);
}
Ok(Some(trimmed.to_string()))
}
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
let path = agent_prompt_path()?;
let existing = if path.exists() {
fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
} else {
DEFAULT_AGENT_PROMPT.to_string()
};
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
write_agent_prompt(&path, &updated)
}
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
let managed_block = render_managed_agent_memory_block(markdown_body);
if let (Some(start), Some(end)) = (
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
) {
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
let mut updated = String::new();
updated.push_str(existing[..start].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[end..].trim_start());
return updated.trim().to_string() + "\n";
}
if let Some(reply_rules_index) = existing.find("## 回复规则") {
let mut updated = String::new();
updated.push_str(existing[..reply_rules_index].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[reply_rules_index..].trim_start());
return updated.trim().to_string() + "\n";
}
let mut updated = existing.trim_end().to_string();
if !updated.is_empty() {
updated.push_str("\n\n");
}
updated.push_str(&managed_block);
updated.push('\n');
updated
}
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
format!(
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
markdown_body.trim()
)
}
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
let temp_path = path.with_extension("md.tmp");
fs::write(&temp_path, content)
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
fs::rename(&temp_path, path)
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
Ok(())
}
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
let home = dirs::home_dir()
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
let original =
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
let updated = upsert_managed_agent_memory_block(
original,
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
);
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
let reply_rules_pos = updated.find("## 回复规则").unwrap();
assert!(managed_pos < reply_rules_pos);
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
assert!(updated.contains("用户在做AI产品"));
assert!(updated.contains("偏好简洁表达"));
}
#[test]
fn test_upsert_managed_agent_memory_block_replaces_existing_block() {
let original = format!(
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n"
);
let updated = upsert_managed_agent_memory_block(&original, "new");
assert!(updated.contains("new"));
assert!(!updated.contains("old"));
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1);
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1);
}
#[test]
fn test_upsert_managed_agent_memory_block_trims_summary_body() {
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n");
assert!(updated.contains("\n\nsummary\n"));
assert!(!updated.contains("\n\nsummary\n\n\n"));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,18 @@
use std::sync::Arc;
use super::{
GatewayState,
session::{Session, handle_in_chat_command, schedule_background_history_compaction},
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::{ToolMessageState, format_tool_call_content};
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use async_trait::async_trait;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::agent::EmittedMessageHandler;
use crate::bus::message::{format_tool_call_content, ToolMessageState};
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
struct WsToolCallEmitter {
sender: mpsc::Sender<WsOutbound>,
@ -120,7 +123,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
&runtime_session_id,
&mut current_session_id,
inbound,
).await {
)
.await
{
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
let _ = session
.lock()
@ -182,17 +187,14 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
});
}
outbound.extend(tool_calls
.iter()
.map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
})
);
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
id: message.id.clone(),
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
}));
outbound
} else {
vec![WsOutbound::AssistantResponse {
@ -202,7 +204,11 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
}]
}
}
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
"tool" => match message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed)
{
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
@ -230,7 +236,10 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage
show_tool_results
|| matches!(
message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed),
message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed),
ToolMessageState::PendingUserAction
)
}
@ -243,7 +252,12 @@ async fn handle_inbound(
inbound: WsInbound,
) -> Result<(), crate::agent::AgentError> {
match inbound {
WsInbound::UserInput { content, chat_id, sender_id, .. } => {
WsInbound::UserInput {
content,
chat_id,
sender_id,
..
} => {
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
let (history, agent, user_tx) = {
@ -252,7 +266,9 @@ async fn handle_inbound(
session_guard.ensure_persistent_session(&chat_id)?;
session_guard.ensure_chat_loaded(&chat_id)?;
if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? {
if let Some(command_response) =
handle_in_chat_command(&mut session_guard, &chat_id, &content)?
{
let _ = session_guard
.send(WsOutbound::AssistantResponse {
id: uuid::Uuid::new_v4().to_string(),
@ -286,13 +302,17 @@ async fn handle_inbound(
match agent.process(history).await {
Ok(result) => {
let mut session_guard = session.lock().await;
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
session_guard
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result
.emitted_messages
.iter()
.filter(|message| {
!message.is_assistant_tool_call_message()
&& should_display_message_to_user(state.config.gateway.show_tool_results, message)
&& should_display_message_to_user(
state.config.gateway.show_tool_results,
message,
)
})
.flat_map(ws_outbound_from_chat_message)
{
@ -301,7 +321,10 @@ async fn handle_inbound(
drop(session_guard);
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()).await {
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.clone())
.await
{
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
}
}
@ -318,16 +341,19 @@ async fn handle_inbound(
Ok(())
}
WsInbound::ClearHistory { session_id, chat_id } => {
let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
WsInbound::ClearHistory {
session_id,
chat_id,
} => {
let target = session_id
.or(chat_id)
.unwrap_or_else(|| current_session_id.clone());
state.session_manager.clear_session_messages(&target)?;
let mut session_guard = session.lock().await;
session_guard.remove_history(&target);
let _ = session_guard
.send(WsOutbound::HistoryCleared {
session_id: target,
})
.send(WsOutbound::HistoryCleared { session_id: target })
.await;
Ok(())
}
@ -452,17 +478,15 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
#[cfg(test)]
mod tests {
use crate::agent::EmittedMessageHandler;
use super::{
WsToolCallEmitter,
resolve_ws_sender_id,
should_display_message_to_user,
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
ws_outbound_from_chat_message,
};
use crate::agent::EmittedMessageHandler;
use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::providers::ToolCall;
use crate::protocol::WsOutbound;
use crate::providers::ToolCall;
use serde_json::json;
use tokio::sync::mpsc;
@ -481,11 +505,17 @@ mod tests {
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "### calculator\n- expression: 1 + 1");
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
@ -551,8 +581,14 @@ mod tests {
#[test]
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
assert_eq!(resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42");
assert_eq!(resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42");
assert_eq!(
resolve_ws_sender_id(Some("user-42"), "runtime-1"),
"user-42"
);
assert_eq!(
resolve_ws_sender_id(Some(" user-42 "), "runtime-1"),
"user-42"
);
}
#[test]
@ -573,8 +609,10 @@ mod tests {
.handle(ChatMessage::tool("call-1", "calculator", "2"))
.await;
assert!(tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
.await
.is_err());
assert!(
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
.await
.is_err()
);
}
}

View File

@ -1,16 +1,16 @@
pub mod config;
pub mod text;
pub mod providers;
pub mod bus;
pub mod cli;
pub mod agent;
pub mod gateway;
pub mod client;
pub mod protocol;
pub mod bus;
pub mod channels;
pub mod cli;
pub mod client;
pub mod config;
pub mod gateway;
pub mod logging;
pub mod observability;
pub mod protocol;
pub mod providers;
pub mod scheduler;
pub mod storage;
pub mod tools;
pub mod skills;
pub mod storage;
pub mod text;
pub mod tools;

View File

@ -1,13 +1,9 @@
use std::path::PathBuf;
use chrono::Utc;
use chrono_tz::Tz;
use std::path::PathBuf;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{
fmt,
fmt::time::FormatTime,
layer::SubscriberExt,
util::SubscriberInitExt,
EnvFilter,
EnvFilter, fmt, fmt::time::FormatTime, layer::SubscriberExt, util::SubscriberInitExt,
};
#[derive(Clone, Copy, Debug)]
@ -16,8 +12,17 @@ struct ConfiguredTimestamp {
}
impl FormatTime for ConfiguredTimestamp {
fn format_time(&self, writer: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result {
write!(writer, "{}", Utc::now().with_timezone(&self.timezone).to_rfc3339_opts(chrono::SecondsFormat::Millis, true))
fn format_time(
&self,
writer: &mut tracing_subscriber::fmt::format::Writer<'_>,
) -> std::fmt::Result {
write!(
writer,
"{}",
Utc::now()
.with_timezone(&self.timezone)
.to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
)
}
}
@ -41,20 +46,19 @@ pub fn init_logging(timezone: Tz) {
// Create log directory if it doesn't exist
if !log_dir.exists() {
if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
eprintln!(
"Warning: Failed to create log directory {}: {}",
log_dir.display(),
e
);
}
}
// Create file appender with daily rotation
let file_appender = RollingFileAppender::new(
Rotation::DAILY,
&log_dir,
"picobot.log",
);
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
// Build subscriber with both console and file output
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let file_layer = fmt::layer()
.with_writer(file_appender)
@ -80,8 +84,7 @@ pub fn init_logging(timezone: Tz) {
/// Initialize logging without file output (console only)
pub fn init_logging_console_only(timezone: Tz) {
let env_filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("info"));
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer()
.with_timer(ConfiguredTimestamp { timezone })

View File

@ -1,4 +1,4 @@
use clap::{Parser, CommandFactory};
use clap::{CommandFactory, Parser};
#[derive(Parser)]
#[command(name = "picobot")]

View File

@ -26,10 +26,7 @@ pub enum ObserverEvent {
success: bool,
},
/// Emitted when the agent starts processing.
AgentStart {
provider: String,
model: String,
},
AgentStart { provider: String, model: String },
/// Emitted when the agent finishes processing.
AgentEnd {
provider: String,
@ -116,7 +113,11 @@ impl ToolExecutionOutcome {
}
/// Create a failed outcome with duration.
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
pub fn failure_with_duration(
output: String,
error_reason: Option<String>,
duration: Duration,
) -> Self {
Self {
output,
success: false,

View File

@ -43,9 +43,7 @@ pub enum WsInbound {
include_archived: bool,
},
#[serde(rename = "load_session")]
LoadSession {
session_id: String,
},
LoadSession { session_id: String },
#[serde(rename = "rename_session")]
RenameSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
@ -70,7 +68,11 @@ pub enum WsInbound {
#[serde(tag = "type")]
pub enum WsOutbound {
#[serde(rename = "assistant_response")]
AssistantResponse { id: String, content: String, role: String },
AssistantResponse {
id: String,
content: String,
role: String,
},
#[serde(rename = "tool_call")]
ToolCall {
id: String,

View File

@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::bus::message::ContentBlock;
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()];
@ -20,7 +20,10 @@ fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
details.join("\ncaused by: ")
}
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
fn serialize_content_blocks<S>(
blocks: &[serde_json::Value],
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
@ -28,14 +31,15 @@ where
}
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => {
convert_image_url_to_anthropic(&image_url.url)
}
}).collect()
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => {
serde_json::json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
})
.collect()
}
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
@ -147,9 +151,13 @@ struct AnthropicResponse {
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent {
Text { text: String },
Text {
text: String,
},
#[allow(dead_code)]
Thinking { thinking: String },
Thinking {
thinking: String,
},
#[serde(rename = "tool_use")]
ToolUse {
id: String,

View File

@ -1,12 +1,15 @@
pub mod traits;
pub mod openai;
pub mod anthropic;
pub mod openai;
pub mod traits;
pub use self::openai::OpenAIProvider;
pub use self::anthropic::AnthropicProvider;
pub use self::openai::OpenAIProvider;
use crate::config::LLMProviderConfig;
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
ToolFunction, Usage,
};
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
match config.provider_type.as_str() {

View File

@ -1,18 +1,15 @@
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::bus::message::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
"tool_call_arguments_json",
"mock_response_content",
];
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()];
@ -32,12 +29,17 @@ fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
return Value::String(text.clone());
}
}
Value::Array(blocks.iter().map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
}).collect())
Value::Array(
blocks
.iter()
.map(|b| match b {
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
ContentBlock::ImageUrl { image_url } => {
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
}
})
.collect(),
)
}
pub struct OpenAIProvider {
@ -122,7 +124,9 @@ impl OpenAIProvider {
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
self.model_extra.iter().filter(|(key, _)| {
!INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str())
!INTERNAL_MODEL_EXTRA_KEYS
.iter()
.any(|internal| internal == &key.as_str())
})
}
@ -265,7 +269,11 @@ impl LLMProvider for OpenAIProvider {
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
for (j, item) in content.iter().enumerate() {
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
if let Some(url_str) = item
.get("image_url")
.and_then(|u| u.get("url"))
.and_then(|v| v.as_str())
{
let prefix: String = url_str.chars().take(20).collect();
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
}
@ -419,7 +427,10 @@ mod tests {
assert_eq!(tool_calls[0]["id"], "call_1");
assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
assert_eq!(
tool_calls[0]["function"]["arguments"],
"{\"expression\":\"1+1\"}"
);
}
#[test]
@ -433,10 +444,7 @@ mod tests {
"gpt-test".to_string(),
None,
None,
HashMap::from([(
"tool_call_arguments_json".to_string(),
Value::Bool(true),
)]),
HashMap::from([("tool_call_arguments_json".to_string(), Value::Bool(true))]),
);
let request = ChatCompletionRequest {
@ -461,7 +469,10 @@ mod tests {
let messages = body["messages"].as_array().unwrap();
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
assert_eq!(
tool_calls[0]["function"]["arguments"],
json!({"expression": "1+1"})
);
assert!(body.get("tool_call_arguments_json").is_none());
}
@ -501,7 +512,10 @@ mod tests {
let messages = body["messages"].as_array().unwrap();
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
assert_eq!(
tool_calls[0]["function"]["arguments"],
"{\"expression\":\"1+1\"}"
);
}
#[test]
@ -517,7 +531,10 @@ mod tests {
None,
HashMap::from([
("tool_call_arguments_json".to_string(), Value::Bool(true)),
("mock_response_content".to_string(), Value::String("stub".to_string())),
(
"mock_response_content".to_string(),
Value::String("stub".to_string()),
),
("parallel_tool_calls".to_string(), Value::Bool(true)),
]),
);
@ -590,7 +607,10 @@ mod tests {
}))
.unwrap();
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
assert_eq!(
response.choices[0].message.reasoning_content.as_deref(),
Some("hidden reasoning")
);
}
#[test]

View File

@ -1,6 +1,6 @@
use crate::bus::message::ContentBlock;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::bus::message::ContentBlock;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
@ -61,7 +61,11 @@ impl Message {
}
}
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
pub fn tool(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self {
role: "tool".to_string(),
content: vec![ContentBlock::text(content)],

View File

@ -8,11 +8,11 @@ use tokio::sync::watch;
use crate::bus::{MessageBus, OutboundMessage};
use crate::config::{
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerMisfirePolicy,
SchedulerSchedule,
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
SchedulerMisfirePolicy, SchedulerSchedule,
};
use crate::gateway::session::SessionManager;
use crate::gateway::session::ScheduledAgentTaskOptions;
use crate::gateway::session::SessionManager;
use crate::storage::{
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
};
@ -76,8 +76,11 @@ impl Scheduler {
fn sync_config_jobs(&self) -> anyhow::Result<()> {
let now = Utc::now();
for job in self.config.effective_jobs(&crate::config::TimeConfig { timezone: self.timezone.name().to_string() }) {
let runtime = RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
for job in self.config.effective_jobs(&crate::config::TimeConfig {
timezone: self.timezone.name().to_string(),
}) {
let runtime =
RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
self.store.upsert_scheduler_job(&runtime.to_upsert())?;
}
Ok(())
@ -88,7 +91,9 @@ impl Scheduler {
let jobs = self.store.list_scheduler_jobs(true)?;
for record in jobs {
let Some(mut job) = RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)? else {
let Some(mut job) =
RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)?
else {
continue;
};
@ -178,8 +183,12 @@ impl Scheduler {
}
SchedulerJobKind::SilentAgentTask => {
let execution_chat_id = resolve_execution_chat_id(job)?;
if let Err(error) = execute_agent_task(&self.session_manager, job, &execution_chat_id).await {
if let Err(notify_error) = self.notify_silent_agent_task_failure(job, &error).await {
if let Err(error) =
execute_agent_task(&self.session_manager, job, &execution_chat_id).await
{
if let Err(notify_error) =
self.notify_silent_agent_task_failure(job, &error).await
{
tracing::error!(
job_id = %job.id,
error = %notify_error,
@ -208,10 +217,13 @@ impl Scheduler {
let mut metadata = HashMap::new();
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
metadata.insert("scheduler_job_kind".to_string(), "silent_agent_task".to_string());
metadata.insert(
"scheduler_job_kind".to_string(),
"silent_agent_task".to_string(),
);
self.bus
.publish_outbound(OutboundMessage::assistant(
.publish_outbound(OutboundMessage::error_notification(
channel,
chat_id,
format!(
@ -304,14 +316,24 @@ impl RuntimeJob {
}
};
let schedule = deserialize_schedule(&record.schedule, record.interval_secs, record.startup_delay_secs)?;
let schedule = deserialize_schedule(
&record.schedule,
record.interval_secs,
record.startup_delay_secs,
)?;
let now = Utc::now();
let next_fire_at = match (record.enabled, record.state.clone(), record.next_fire_at) {
(false, _, _) => None,
(_, SchedulerJobState::Paused, _) => None,
(_, SchedulerJobState::Completed, _) => None,
(_, _, some_next) if some_next.is_some() => some_next,
_ => compute_initial_next_fire_at(&schedule, now, record.last_fired_at, misfire_policy, timezone)?,
_ => compute_initial_next_fire_at(
&schedule,
now,
record.last_fired_at,
misfire_policy,
timezone,
)?,
};
Ok(Some(Self {
@ -338,7 +360,10 @@ impl RuntimeJob {
fn is_due(&self, now: DateTime<Utc>) -> bool {
self.enabled
&& self.state == SchedulerJobState::Scheduled
&& self.next_fire_at.map(|value| value <= now.timestamp_millis()).unwrap_or(false)
&& self
.next_fire_at
.map(|value| value <= now.timestamp_millis())
.unwrap_or(false)
}
fn after_execution(
@ -371,7 +396,8 @@ impl RuntimeJob {
let reference_ms = self.next_fire_at.or(self.last_fired_at);
self.state = SchedulerJobState::Scheduled;
self.completed_at = None;
self.next_fire_at = compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?;
self.next_fire_at =
compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?;
Ok(())
}
@ -384,7 +410,8 @@ impl RuntimeJob {
SchedulerJobKind::AgentTask => "agent_task".to_string(),
SchedulerJobKind::SilentAgentTask => "silent_agent_task".to_string(),
},
schedule: serde_json::to_value(&self.schedule).unwrap_or_else(|_| serde_json::json!({})),
schedule: serde_json::to_value(&self.schedule)
.unwrap_or_else(|_| serde_json::json!({})),
interval_secs: self.interval_secs,
startup_delay_secs: self.startup_delay_secs,
target: serde_json::to_value(&self.target).unwrap_or_else(|_| serde_json::json!({})),
@ -430,21 +457,36 @@ fn compute_initial_next_fire_at(
timezone: Tz,
) -> anyhow::Result<Option<i64>> {
match last_fired_at {
Some(last_fired_at) => compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone),
Some(last_fired_at) => {
compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone)
}
None => match schedule {
SchedulerSchedule::Delay { seconds } => Ok(Some((now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis())),
SchedulerSchedule::Delay { seconds } => Ok(Some(
(now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis(),
)),
SchedulerSchedule::Interval {
seconds,
startup_delay_secs,
} => {
let delay = if *startup_delay_secs > 0 { *startup_delay_secs } else { *seconds };
Ok(Some((now + ChronoDuration::seconds(delay as i64)).timestamp_millis()))
let delay = if *startup_delay_secs > 0 {
*startup_delay_secs
} else {
*seconds
};
Ok(Some(
(now + ChronoDuration::seconds(delay as i64)).timestamp_millis(),
))
}
SchedulerSchedule::At { timestamp } => {
Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis()))
}
SchedulerSchedule::At { timestamp } => Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis())),
SchedulerSchedule::Cron { expression } => {
let schedule = parse_scheduler_cron(expression)?;
let local_now = now.with_timezone(&timezone);
Ok(schedule.after(&local_now).next().map(|next| next.with_timezone(&Utc).timestamp_millis()))
Ok(schedule
.after(&local_now)
.next()
.map(|next| next.with_timezone(&Utc).timestamp_millis()))
}
},
}
@ -483,7 +525,10 @@ fn compute_next_fire_at(
.map(|value| value.with_timezone(&timezone))
.unwrap_or_else(|| now.with_timezone(&timezone)),
};
Ok(schedule.after(&anchor).next().map(|next| next.with_timezone(&Utc).timestamp_millis()))
Ok(schedule
.after(&anchor)
.next()
.map(|next| next.with_timezone(&Utc).timestamp_millis()))
}
}
}
@ -525,12 +570,14 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
.payload
.get("content")
.and_then(|value| value.as_str())
.ok_or_else(|| anyhow::anyhow!("outbound scheduler job payload.content must be a string"))?;
.ok_or_else(|| {
anyhow::anyhow!("outbound scheduler job payload.content must be a string")
})?;
let mut metadata = HashMap::new();
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
Ok(OutboundMessage::assistant(
Ok(OutboundMessage::scheduler_notification(
channel,
chat_id,
content.to_string(),
@ -539,7 +586,10 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
))
}
async fn execute_internal_event(session_manager: &SessionManager, job: &RuntimeJob) -> anyhow::Result<()> {
async fn execute_internal_event(
session_manager: &SessionManager,
job: &RuntimeJob,
) -> anyhow::Result<()> {
let event = job
.payload
.get("event")
@ -599,7 +649,10 @@ async fn execute_agent_task(
.map_err(|error| anyhow::anyhow!(error.to_string()))
}
fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> anyhow::Result<&'a str> {
fn required_notification_chat_id<'a>(
job: &'a RuntimeJob,
kind_name: &str,
) -> anyhow::Result<&'a str> {
job.target
.chat_id
.as_deref()
@ -608,7 +661,9 @@ fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> an
fn resolve_execution_chat_id(job: &RuntimeJob) -> anyhow::Result<String> {
match job.kind {
SchedulerJobKind::AgentTask => Ok(required_notification_chat_id(job, "agent_task")?.to_string()),
SchedulerJobKind::AgentTask => {
Ok(required_notification_chat_id(job, "agent_task")?.to_string())
}
SchedulerJobKind::SilentAgentTask => Ok(job
.target
.session_chat_id
@ -633,7 +688,9 @@ fn summarize_scheduler_error(error: &anyhow::Error) -> String {
}
}
fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<ScheduledAgentTaskOptions> {
fn parse_scheduled_agent_task_options(
job: &RuntimeJob,
) -> anyhow::Result<ScheduledAgentTaskOptions> {
let sender_id = job
.payload
.get("sender_id")
@ -665,7 +722,9 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
})
}
fn parse_metadata_map(value: Option<&serde_json::Value>) -> anyhow::Result<HashMap<String, String>> {
fn parse_metadata_map(
value: Option<&serde_json::Value>,
) -> anyhow::Result<HashMap<String, String>> {
let Some(value) = value else {
return Ok(HashMap::new());
};
@ -685,7 +744,7 @@ fn parse_metadata_map(value: Option<&serde_json::Value>) -> anyhow::Result<HashM
return Err(anyhow::anyhow!(
"agent_task payload.metadata field '{}' must be a string, number, bool, or null",
key
))
));
}
};
metadata.insert(key.clone(), stringified);
@ -730,12 +789,19 @@ mod agent_task_tests {
updated_at: 1_700_000_000_000,
};
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
.unwrap()
.unwrap();
let job = RuntimeJob::from_record(
&record,
SchedulerMisfirePolicy::Skip,
chrono_tz::Asia::Shanghai,
)
.unwrap()
.unwrap();
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办"));
assert_eq!(
job.payload.get("prompt").and_then(|value| value.as_str()),
Some("请总结今天待办")
);
}
#[test]
@ -771,12 +837,19 @@ mod agent_task_tests {
updated_at: 1_700_000_000_000,
};
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
.unwrap()
.unwrap();
let job = RuntimeJob::from_record(
&record,
SchedulerMisfirePolicy::Skip,
chrono_tz::Asia::Shanghai,
)
.unwrap()
.unwrap();
assert_eq!(job.kind, SchedulerJobKind::SilentAgentTask);
assert_eq!(job.target.session_chat_id.as_deref(), Some("scheduler/agent.daily_summary.background"));
assert_eq!(
job.target.session_chat_id.as_deref(),
Some("scheduler/agent.daily_summary.background")
);
}
#[test]
@ -825,9 +898,18 @@ mod agent_task_tests {
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
assert!(options.fresh_session);
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
assert_eq!(options.metadata.get("job_type").map(String::as_str), Some("daily_summary"));
assert_eq!(options.metadata.get("priority").map(String::as_str), Some("1"));
assert_eq!(options.metadata.get("urgent").map(String::as_str), Some("false"));
assert_eq!(
options.metadata.get("job_type").map(String::as_str),
Some("daily_summary")
);
assert_eq!(
options.metadata.get("priority").map(String::as_str),
Some("1")
);
assert_eq!(
options.metadata.get("urgent").map(String::as_str),
Some("false")
);
}
#[test]
@ -880,12 +962,12 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use crate::bus::MessageBus;
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
use crate::gateway::session::SessionManager;
use crate::skills::SkillRuntime;
use crate::storage::{SchedulerJobUpsert, SessionStore};
use std::collections::HashMap;
fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig {
@ -898,6 +980,7 @@ mod tests {
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: None,
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 4,
tool_result_max_chars: 20_000,
@ -921,7 +1004,10 @@ mod tests {
#[test]
fn runtime_job_skip_policy_advances_from_now() {
let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap();
let now = Utc
.timestamp_millis_opt(1_700_000_000_000)
.single()
.unwrap();
let next = compute_next_fire_at(
&SchedulerSchedule::Interval {
seconds: 60,
@ -940,7 +1026,10 @@ mod tests {
#[test]
fn runtime_job_catch_up_policy_moves_past_now() {
let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap();
let now = Utc
.timestamp_millis_opt(1_700_000_000_000)
.single()
.unwrap();
let next = compute_next_fire_at(
&SchedulerSchedule::Interval {
seconds: 60,
@ -989,14 +1078,21 @@ mod tests {
updated_at: 1_700_000_000_000,
};
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
.unwrap()
.unwrap();
let job = RuntimeJob::from_record(
&record,
SchedulerMisfirePolicy::Skip,
chrono_tz::Asia::Shanghai,
)
.unwrap()
.unwrap();
assert_eq!(job.schedule, SchedulerSchedule::Interval {
seconds: 120,
startup_delay_secs: 10,
});
assert_eq!(
job.schedule,
SchedulerSchedule::Interval {
seconds: 120,
startup_delay_secs: 10,
}
);
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
}
@ -1050,7 +1146,10 @@ mod tests {
scheduler.process_tick().await.unwrap();
let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap();
let saved = store
.get_scheduler_job("massage_reminder")
.unwrap()
.unwrap();
assert!(saved.next_fire_at.is_some());
assert_eq!(saved.run_count, 0);
assert_eq!(saved.state, SchedulerJobState::Scheduled);
@ -1080,7 +1179,10 @@ mod tests {
assert_eq!(saved.kind, "internal_event");
assert!(saved.enabled);
assert_eq!(saved.state, SchedulerJobState::Scheduled);
assert_eq!(saved.payload.get("event").and_then(|value| value.as_str()), Some("memory_maintenance"));
assert_eq!(
saved.payload.get("event").and_then(|value| value.as_str()),
Some("memory_maintenance")
);
assert_eq!(
saved.schedule,
serde_json::json!({
@ -1088,7 +1190,13 @@ mod tests {
"expression": "0 */4 * * *"
})
);
assert_eq!(saved.payload.get("local_time").and_then(|value| value.as_str()), Some("every_4_hours"));
assert_eq!(
saved
.payload
.get("local_time")
.and_then(|value| value.as_str()),
Some("every_4_hours")
);
assert!(saved.next_fire_at.is_some());
}
@ -1155,7 +1263,10 @@ mod tests {
#[test]
fn cron_schedule_uses_configured_timezone() {
let now = Utc.with_ymd_and_hms(2026, 4, 23, 18, 0, 0).single().unwrap();
let now = Utc
.with_ymd_and_hms(2026, 4, 23, 18, 0, 0)
.single()
.unwrap();
let next = compute_next_fire_at(
&SchedulerSchedule::Cron {
expression: "0 3 * * *".to_string(),
@ -1169,6 +1280,11 @@ mod tests {
.unwrap();
let next_utc = ts_millis_to_utc(next).unwrap();
assert_eq!(next_utc, Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0).single().unwrap());
assert_eq!(
next_utc,
Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0)
.single()
.unwrap()
);
}
}

View File

@ -89,7 +89,10 @@ impl SkillRuntime {
}
pub fn is_empty(&self) -> bool {
self.catalog.read().expect("skills rwlock poisoned").is_empty()
self.catalog
.read()
.expect("skills rwlock poisoned")
.is_empty()
}
pub fn len(&self) -> usize {
@ -97,31 +100,53 @@ impl SkillRuntime {
}
pub fn system_index_prompt(&self) -> Option<String> {
self.catalog.read().expect("skills rwlock poisoned").system_index_prompt()
self.catalog
.read()
.expect("skills rwlock poisoned")
.system_index_prompt()
}
pub fn discovery_event_payload(&self) -> serde_json::Value {
self.catalog.read().expect("skills rwlock poisoned").discovery_event_payload()
self.catalog
.read()
.expect("skills rwlock poisoned")
.discovery_event_payload()
}
pub fn offered_event_payload(&self) -> serde_json::Value {
self.catalog.read().expect("skills rwlock poisoned").offered_event_payload()
self.catalog
.read()
.expect("skills rwlock poisoned")
.offered_event_payload()
}
pub fn skill_tool_definition(&self) -> Option<Tool> {
self.catalog.read().expect("skills rwlock poisoned").skill_tool_definition()
self.catalog
.read()
.expect("skills rwlock poisoned")
.skill_tool_definition()
}
pub fn activation_payload(&self, name: &str) -> Result<String, String> {
self.catalog.read().expect("skills rwlock poisoned").activation_payload(name)
self.catalog
.read()
.expect("skills rwlock poisoned")
.activation_payload(name)
}
pub fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String> {
self.catalog.read().expect("skills rwlock poisoned").activation_event_payload(name)
self.catalog
.read()
.expect("skills rwlock poisoned")
.activation_event_payload(name)
}
pub fn list_skills(&self) -> Vec<Skill> {
self.catalog.read().expect("skills rwlock poisoned").skills.clone()
self.catalog
.read()
.expect("skills rwlock poisoned")
.skills
.clone()
}
pub fn get_skill(&self, name: &str) -> Option<Skill> {
@ -143,7 +168,11 @@ impl SkillRuntime {
validate_skill_name(name)?;
let path = skill_file_path(scope, name)?;
if path.exists() {
return Err(format!("skill '{}' already exists at {}", name, path.display()));
return Err(format!(
"skill '{}' already exists at {}",
name,
path.display()
));
}
write_skill_file(&path, name, description, body)?;
@ -180,14 +209,20 @@ impl SkillRuntime {
Ok(skill)
}
pub fn delete_skill(&self, scope: SkillScope, name: &str, reload: bool) -> Result<PathBuf, String> {
pub fn delete_skill(
&self,
scope: SkillScope,
name: &str,
reload: bool,
) -> Result<PathBuf, String> {
validate_skill_name(name)?;
let dir = skill_dir_path(scope, name)?;
if !dir.exists() {
return Err(format!("skill '{}' not found at {}", name, dir.display()));
}
fs::remove_dir_all(&dir).map_err(|err| format!("failed to delete skill directory: {}", err))?;
fs::remove_dir_all(&dir)
.map_err(|err| format!("failed to delete skill directory: {}", err))?;
if reload {
let _ = self.reload()?;
}
@ -439,7 +474,8 @@ fn validate_skill_name(name: &str) -> Result<(), String> {
}
pub fn project_skills_root() -> Result<PathBuf, String> {
let cwd = std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?;
let cwd =
std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?;
Ok(cwd.join(".picobot").join("skills"))
}
@ -466,7 +502,9 @@ fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
fn root_for_scope(scope: SkillScope) -> Result<PathBuf, String> {
match scope {
SkillScope::User => user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string()),
SkillScope::User => {
user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string())
}
SkillScope::Project => project_skills_root(),
}
}
@ -508,7 +546,8 @@ fn render_skill_file(name: &str, description: &str, body: &str) -> Result<String
fn write_skill_file(path: &Path, name: &str, description: &str, body: &str) -> Result<(), String> {
let content = render_skill_file(name, description, body)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|err| format!("failed to create skill directory: {}", err))?;
fs::create_dir_all(parent)
.map_err(|err| format!("failed to create skill directory: {}", err))?;
}
fs::write(path, content).map_err(|err| format!("failed to write skill file: {}", err))
}
@ -556,11 +595,10 @@ struct SkillFrontmatter {
}
fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
let content = fs::read_to_string(path)
.map_err(|e| format!("failed to read file: {}", e))?;
let content = fs::read_to_string(path).map_err(|e| format!("failed to read file: {}", e))?;
let (frontmatter_raw, body) = split_frontmatter(&content)
.ok_or_else(|| "missing YAML frontmatter block".to_string())?;
let (frontmatter_raw, body) =
split_frontmatter(&content).ok_or_else(|| "missing YAML frontmatter block".to_string())?;
let frontmatter: SkillFrontmatter = serde_yaml::from_str(frontmatter_raw)
.map_err(|e| format!("invalid YAML frontmatter: {}", e))?;
@ -576,11 +614,7 @@ fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_else(|| "unknown-skill".to_string());
let name = frontmatter
.name
.unwrap_or(dir_name)
.trim()
.to_string();
let name = frontmatter.name.unwrap_or(dir_name).trim().to_string();
Ok(Skill {
name,
@ -656,7 +690,10 @@ mod tests {
)
.unwrap();
let skills = load_skills_from_root(&dir.path().join(".picobot").join("skills"), SkillSource::Project);
let skills = load_skills_from_root(
&dir.path().join(".picobot").join("skills"),
SkillSource::Project,
);
let catalog = SkillCatalog {
skills,
max_index_chars: 4000,
@ -707,7 +744,13 @@ mod tests {
assert_eq!(runtime.len(), 0);
let created = runtime
.create_skill(SkillScope::Project, "demo-skill", "demo desc", "line 1", true)
.create_skill(
SkillScope::Project,
"demo-skill",
"demo desc",
"line 1",
true,
)
.unwrap();
assert_eq!(created.name, "demo-skill");
assert_eq!(runtime.len(), 1);
@ -722,7 +765,12 @@ mod tests {
)
.unwrap();
assert_eq!(updated.description, "updated desc");
assert!(runtime.activation_payload("demo-skill").unwrap().contains("line 2"));
assert!(
runtime
.activation_payload("demo-skill")
.unwrap()
.contains("line 2")
);
let deleted_path = runtime
.delete_skill(SkillScope::Project, "demo-skill", true)
@ -759,7 +807,11 @@ mod tests {
let temp_dir = tempfile::tempdir().unwrap();
let _guard = CurrentDirGuard::enter(temp_dir.path());
let agent_skill_dir = temp_dir.path().join(".agents").join("skills").join("demo-agent");
let agent_skill_dir = temp_dir
.path()
.join(".agents")
.join("skills")
.join("demo-agent");
fs::create_dir_all(&agent_skill_dir).unwrap();
fs::write(
agent_skill_dir.join("SKILL.md"),

View File

@ -391,7 +391,8 @@ impl SessionStore {
)?;
drop(conn);
self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
self.get_session(&id)?
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn ensure_channel_session(
@ -419,7 +420,8 @@ impl SessionStore {
)?;
drop(conn);
self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
self.get_session(&session_id)?
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
@ -495,7 +497,10 @@ impl SessionStore {
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
conn.execute(
"DELETE FROM messages WHERE session_id = ?1",
params![session_id],
)?;
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
Ok(())
}
@ -503,7 +508,10 @@ impl SessionStore {
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
conn.execute(
"DELETE FROM messages WHERE session_id = ?1",
params![session_id],
)?;
conn.execute(
"
UPDATE sessions
@ -549,7 +557,11 @@ impl SessionStore {
Ok(())
}
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
pub fn append_message(
&self,
session_id: &str,
message: &ChatMessage,
) -> Result<(), StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let tx = conn.unchecked_transaction()?;
@ -560,7 +572,11 @@ impl SessionStore {
)?;
let media_refs_json = serde_json::to_string(&message.media_refs)?;
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
let tool_calls_json = message
.tool_calls
.as_ref()
.map(serde_json::to_string)
.transpose()?;
tx.execute(
"
INSERT INTO messages (
@ -630,7 +646,8 @@ impl SessionStore {
return Ok(false);
}
let delta_messages = load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
let delta_messages =
load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
let mut next_seq = current_max_seq + 1;
let now = current_timestamp();
let mut inserted_count = 0_i64;
@ -782,8 +799,7 @@ impl SessionStore {
)
.optional()?;
let (id, created_at) = existing
.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
let (id, created_at) = existing.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
tx.execute(
"
@ -881,7 +897,10 @@ impl SessionStore {
LIMIT ?4
",
)?;
let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?;
let rows = stmt.query_map(
params![scope_kind, scope_key, namespace, limit],
map_memory_record,
)?;
for row in rows {
memories.push(row?);
}
@ -940,7 +959,9 @@ impl SessionStore {
",
)?;
let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| row.get::<_, String>(0))?;
let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| {
row.get::<_, String>(0)
})?;
let mut scope_keys = Vec::new();
for row in rows {
scope_keys.push(row?);
@ -1010,7 +1031,10 @@ impl SessionStore {
Ok(changed > 0)
}
pub fn upsert_scheduler_job(&self, input: &SchedulerJobUpsert) -> Result<SchedulerJobRecord, StorageError> {
pub fn upsert_scheduler_job(
&self,
input: &SchedulerJobUpsert,
) -> Result<SchedulerJobRecord, StorageError> {
let now = current_timestamp();
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
@ -1067,7 +1091,10 @@ impl SessionStore {
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError> {
pub fn get_scheduler_job(
&self,
job_id: &str,
) -> Result<Option<SchedulerJobRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let mut stmt = conn.prepare(
"
@ -1085,7 +1112,10 @@ impl SessionStore {
.map_err(StorageError::from)
}
pub fn list_scheduler_jobs(&self, enabled_only: bool) -> Result<Vec<SchedulerJobRecord>, StorageError> {
pub fn list_scheduler_jobs(
&self,
enabled_only: bool,
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
let conn = self.conn.lock().expect("session db mutex poisoned");
let sql = if enabled_only {
"
@ -1195,7 +1225,10 @@ impl SessionStore {
LIMIT ?5
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
let rows = stmt.query_map(
params![query, scope_kind, scope_key, namespace, limit],
map_memory_record,
)?;
for row in rows {
memories.push(row?);
}
@ -1214,7 +1247,10 @@ impl SessionStore {
LIMIT ?4
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
let rows = stmt.query_map(
params![query, scope_kind, scope_key, limit],
map_memory_record,
)?;
for row in rows {
memories.push(row?);
}
@ -1256,7 +1292,10 @@ impl SessionStore {
LIMIT ?5
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
let rows = stmt.query_map(
params![query, scope_kind, scope_key, namespace, limit],
map_memory_record,
)?;
for row in rows {
memories.push(row?);
}
@ -1275,7 +1314,10 @@ impl SessionStore {
LIMIT ?4
",
)?;
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
let rows = stmt.query_map(
params![query, scope_kind, scope_key, limit],
map_memory_record,
)?;
for row in rows {
memories.push(row?);
}
@ -1347,11 +1389,7 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEventRecord> {
let payload_json: String = row.get(4)?;
let payload = serde_json::from_str(&payload_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
4,
rusqlite::types::Type::Text,
Box::new(err),
)
rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(err))
})?;
Ok(SkillEventRecord {
@ -1391,25 +1429,13 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<Schedul
let last_status: Option<String> = row.get(9)?;
let schedule = serde_json::from_str(&schedule_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
2,
rusqlite::types::Type::Text,
Box::new(err),
)
rusqlite::Error::FromSqlConversionFailure(2, rusqlite::types::Type::Text, Box::new(err))
})?;
let target = serde_json::from_str(&target_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
5,
rusqlite::types::Type::Text,
Box::new(err),
)
rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(err))
})?;
let payload = serde_json::from_str(&payload_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
6,
rusqlite::types::Type::Text,
Box::new(err),
)
rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(err))
})?;
Ok(SchedulerJobRecord {
@ -1472,7 +1498,10 @@ fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
}
if !has_column(conn, "messages", "reasoning_content")? {
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN reasoning_content TEXT")?;
add_column_if_missing(
conn,
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
)?;
}
Ok(())
@ -1494,17 +1523,11 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
}
if !has_column(conn, "scheduler_jobs", "last_status")? {
conn.execute(
"ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT",
[],
)?;
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT", [])?;
}
if !has_column(conn, "scheduler_jobs", "last_error")? {
conn.execute(
"ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT",
[],
)?;
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT", [])?;
}
if !has_column(conn, "scheduler_jobs", "run_count")? {
@ -1515,10 +1538,7 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
}
if !has_column(conn, "scheduler_jobs", "max_runs")? {
conn.execute(
"ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER",
[],
)?;
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER", [])?;
}
if !has_column(conn, "scheduler_jobs", "paused_at")? {
@ -1538,7 +1558,11 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
Ok(())
}
fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<bool, StorageError> {
fn has_column(
conn: &Connection,
table_name: &str,
column_name: &str,
) -> Result<bool, StorageError> {
let pragma = format!("PRAGMA table_info({})", table_name);
let mut stmt = conn.prepare(&pragma)?;
let mut rows = stmt.query([])?;
@ -1557,7 +1581,10 @@ fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageErro
match conn.execute(sql, []) {
Ok(_) => Ok(()),
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
if message.contains("duplicate column name") => Ok(()),
if message.contains("duplicate column name") =>
{
Ok(())
}
Err(error) => Err(StorageError::Database(error)),
}
}
@ -1581,7 +1608,11 @@ fn insert_message_with_seq(
message: &ChatMessage,
) -> Result<(), StorageError> {
let media_refs_json = serde_json::to_string(&message.media_refs)?;
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
let tool_calls_json = message
.tool_calls
.as_ref()
.map(serde_json::to_string)
.transpose()?;
conn.execute(
"
INSERT INTO messages (
@ -1638,43 +1669,47 @@ fn load_messages_between(
",
)?;
let rows = stmt.query_map(params![session_id, start_seq_exclusive, end_seq_inclusive], |row| {
let media_refs_json: String = row.get(5)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let rows = stmt.query_map(
params![session_id, start_seq_exclusive, end_seq_inclusive],
|row| {
let media_refs_json: String = row.get(5)?;
let media_refs: Vec<String> =
serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(),
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let tool_calls_json: Option<String> = row.get(9)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
9,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
let tool_calls_json: Option<String> = row.get(9)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
9,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
system_context: row.get(3)?,
reasoning_content: row.get(4)?,
media_refs,
timestamp: row.get(6)?,
tool_call_id: row.get(7)?,
tool_name: row.get(8)?,
tool_state: None,
tool_calls,
})
})?;
Ok(ChatMessage {
id: row.get(0)?,
role: row.get(1)?,
content: row.get(2)?,
system_context: row.get(3)?,
reasoning_content: row.get(4)?,
media_refs,
timestamp: row.get(6)?,
tool_call_id: row.get(7)?,
tool_name: row.get(8)?,
tool_state: None,
tool_calls,
})
},
)?;
let mut messages = Vec::new();
for row in rows {
@ -1866,7 +1901,10 @@ mod tests {
assert_eq!(messages[0].role, "assistant");
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
assert_eq!(
messages[0].tool_calls.as_ref().unwrap()[0].name,
"calculator"
);
}
#[test]
@ -1874,17 +1912,17 @@ mod tests {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("reasoning")).unwrap();
let assistant = ChatMessage::assistant_with_reasoning(
"final answer",
"hidden reasoning",
);
let assistant = ChatMessage::assistant_with_reasoning("final answer", "hidden reasoning");
store.append_message(&session.id, &assistant).unwrap();
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].content, "final answer");
assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning"));
assert_eq!(
messages[0].reasoning_content.as_deref(),
Some("hidden reasoning")
);
}
#[test]
@ -1892,8 +1930,12 @@ mod tests {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("reset")).unwrap();
store.append_message(&session.id, &ChatMessage::user("before")).unwrap();
store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap();
store
.append_message(&session.id, &ChatMessage::user("before"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::assistant("context"))
.unwrap();
store.reset_session(&session.id).unwrap();
let stored = store.get_session(&session.id).unwrap().unwrap();
@ -1909,7 +1951,9 @@ mod tests {
assert_eq!(all_messages[0].content, "before");
assert_eq!(all_messages[1].content, "context");
store.append_message(&session.id, &ChatMessage::user("after")).unwrap();
store
.append_message(&session.id, &ChatMessage::user("after"))
.unwrap();
let active_messages = store.load_messages(&session.id).unwrap();
assert_eq!(active_messages.len(), 1);
assert_eq!(active_messages[0].content, "after");
@ -2010,19 +2054,33 @@ mod tests {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("count-users")).unwrap();
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
store
.append_message(&session.id, &ChatMessage::system("agent"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::user("u1"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::assistant("a1"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::user("u2"))
.unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
store.reset_session(&session.id).unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
store
.append_message(&session.id, &ChatMessage::system("agent-again"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::user("u3"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::user("u4"))
.unwrap();
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
}
@ -2052,12 +2110,20 @@ mod tests {
store.append_message(&session.id, message).unwrap();
}
let snapshot_end_seq = store.get_session(&session.id).unwrap().unwrap().message_count;
let snapshot_end_seq = store
.get_session(&session.id)
.unwrap()
.unwrap()
.message_count;
let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec();
let preserved_system_messages = vec![agent_prompt];
store.append_message(&session.id, &ChatMessage::user("u5")).unwrap();
store.append_message(&session.id, &ChatMessage::assistant("a5")).unwrap();
store
.append_message(&session.id, &ChatMessage::user("u5"))
.unwrap();
store
.append_message(&session.id, &ChatMessage::assistant("a5"))
.unwrap();
let summary_message = ChatMessage::system("[Compressed History]\n\nsummary");
let compacted = store
@ -2074,16 +2140,22 @@ mod tests {
assert!(compacted);
let active_messages = store.load_messages(&session.id).unwrap();
assert_eq!(active_messages.len(), 10);
assert_eq!(active_messages[0].role, "system");
assert_eq!(active_messages[0].content, "agent");
assert_eq!(active_messages[0].system_context.as_deref(), Some(SYSTEM_CONTEXT_AGENT_PROMPT));
assert_eq!(active_messages[1].role, "system");
assert_eq!(active_messages[1].content, "[Compressed History]\n\nsummary");
assert_eq!(active_messages[2].content, "u2");
assert_eq!(active_messages[3].content, "a2");
assert_eq!(active_messages[8].content, "u5");
assert_eq!(active_messages[9].content, "a5");
assert_eq!(active_messages.len(), 10);
assert_eq!(active_messages[0].role, "system");
assert_eq!(active_messages[0].content, "agent");
assert_eq!(
active_messages[0].system_context.as_deref(),
Some(SYSTEM_CONTEXT_AGENT_PROMPT)
);
assert_eq!(active_messages[1].role, "system");
assert_eq!(
active_messages[1].content,
"[Compressed History]\n\nsummary"
);
assert_eq!(active_messages[2].content, "u2");
assert_eq!(active_messages[3].content, "a2");
assert_eq!(active_messages[8].content, "u5");
assert_eq!(active_messages[9].content, "a5");
let stored = store.get_session(&session.id).unwrap().unwrap();
assert_eq!(stored.reset_cutoff_seq, 11);
@ -2128,12 +2200,7 @@ mod tests {
let session = store.create_cli_session(Some("skill-events")).unwrap();
store
.append_skill_event(
None,
"discovered",
None,
&serde_json::json!({"count": 2}),
)
.append_skill_event(None, "discovered", None, &serde_json::json!({"count": 2}))
.unwrap();
store
.append_skill_event(
@ -2383,13 +2450,26 @@ mod tests {
.unwrap();
let scope_keys = store.list_memory_scope_keys("user").unwrap();
assert_eq!(scope_keys, vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]);
assert_eq!(
scope_keys,
vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]
);
let full_scope = store.list_memories_for_scope("user", "feishu:user-1").unwrap();
let full_scope = store
.list_memories_for_scope("user", "feishu:user-1")
.unwrap();
assert_eq!(full_scope.len(), 2);
assert!(full_scope.iter().all(|memory| memory.scope_key == "feishu:user-1"));
assert!(
full_scope
.iter()
.all(|memory| memory.scope_key == "feishu:user-1")
);
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
assert!(full_scope.iter().any(|memory| memory.memory_key == "workflow"));
assert!(
full_scope
.iter()
.any(|memory| memory.memory_key == "workflow")
);
}
#[test]

View File

@ -15,7 +15,8 @@ use crate::tools::traits::{Tool, ToolResult};
const MAX_TIMEOUT_SECS: u64 = 600;
const MAX_OUTPUT_CHARS: usize = 50_000;
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
const USER_ACTION_HINT: &str =
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
pub struct BashTool {
timeout_secs: u64,
@ -208,7 +209,10 @@ impl Tool for BashTool {
.map(|d| Path::new(d))
.unwrap_or_else(|| Path::new("."));
match self.run_command(command, cwd, timeout_secs, interactive).await {
match self
.run_command(command, cwd, timeout_secs, interactive)
.await
{
Ok(output) => Ok(ToolResult {
success: true,
output,
@ -366,10 +370,7 @@ mod tests {
#[tokio::test]
async fn test_pwd_command() {
let tool = BashTool::new();
let result = tool
.execute(json!({ "command": "pwd" }))
.await
.unwrap();
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
assert!(result.success);
}
@ -377,7 +378,10 @@ mod tests {
#[tokio::test]
async fn test_ls_command() {
let tool = BashTool::new();
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
let result = tool
.execute(json!({ "command": "ls -la /tmp" }))
.await
.unwrap();
assert!(result.success);
}

View File

@ -659,10 +659,7 @@ mod tests {
#[tokio::test]
async fn test_evaluate_missing_expression() {
let tool = CalculatorTool::new();
let result = tool
.execute(json!({"function": "evaluate"}))
.await
.unwrap();
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
assert!(!result.success);
}

View File

@ -268,8 +268,8 @@ impl Tool for FileEditTool {
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test]
async fn test_edit_simple() {

View File

@ -218,7 +218,7 @@ impl Tool for FileReadTool {
// Try to read as binary and encode as base64
match std::fs::read(&resolved) {
Ok(bytes) => {
use base64::{engine::general_purpose::STANDARD, Engine};
use base64::{Engine, engine::general_purpose::STANDARD};
let encoded = STANDARD.encode(&bytes);
let mime = mime_guess::from_path(&resolved)
.first_or_octet_stream()
@ -248,8 +248,8 @@ impl Tool for FileReadTool {
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
use std::io::Write;
use tempfile::NamedTempFile;
#[tokio::test]
async fn test_read_simple_file() {
@ -308,10 +308,7 @@ mod tests {
#[tokio::test]
async fn test_is_directory() {
let tool = FileReadTool::new();
let result = tool
.execute(json!({ "path": "." }))
.await
.unwrap();
let result = tool.execute(json!({ "path": "." })).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Not a file"));

View File

@ -195,10 +195,7 @@ mod tests {
#[tokio::test]
async fn test_write_missing_path() {
let tool = FileWriteTool::new();
let result = tool
.execute(json!({ "content": "Hello" }))
.await
.unwrap();
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("path"));

View File

@ -50,10 +50,7 @@ impl HttpRequestTool {
}
if !host_matches_allowlist(&host, &self.allowed_domains) {
return Err(format!(
"Host '{}' is not in allowed_domains",
host
));
return Err(format!("Host '{}' is not in allowed_domains", host));
}
Ok(url.to_string())
@ -80,9 +77,7 @@ impl HttpRequestTool {
for (key, value) in obj {
if let Some(str_val) = value.as_str() {
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
if let Ok(val) =
reqwest::header::HeaderValue::from_str(str_val)
{
if let Ok(val) = reqwest::header::HeaderValue::from_str(str_val) {
header_map.insert(name, val);
}
}
@ -193,7 +188,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
allowed_domains.iter().any(|domain| {
host == domain
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|| host
.strip_suffix(domain)
.is_some_and(|prefix| prefix.ends_with('.'))
})
}
@ -204,7 +201,11 @@ fn is_private_host(host: &str) -> bool {
}
// Check .local TLD
if host.rsplit('.').next().is_some_and(|label| label == "local") {
if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true;
}
@ -226,9 +227,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|| v4.is_broadcast()
|| v4.is_multicast()
}
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
}
}
@ -280,10 +279,7 @@ impl Tool for HttpRequestTool {
}
};
let method_str = args
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
let body = args.get("body").and_then(|v| v.as_str());

View File

@ -94,7 +94,7 @@ impl Tool for MemoryManageTool {
return Ok(error_result(&format!(
"memory '{}.{}' not found",
input.namespace, input.memory_key
)))
)));
}
}
}
@ -108,9 +108,14 @@ impl Tool for MemoryManageTool {
None => return Ok(error_result("Missing required parameter: key")),
};
let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?;
let deleted = self
.store
.delete_memory("user", &scope_key, namespace, key)?;
if !deleted {
return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key)));
return Ok(error_result(&format!(
"memory '{}.{}' not found",
namespace, key
)));
}
json!({

View File

@ -90,7 +90,9 @@ impl Tool for MemorySearchTool {
.get("limit")
.and_then(|value| value.as_u64())
.unwrap_or(10) as usize;
let memories = self.store.list_memories("user", &scope_key, namespace, limit)?;
let memories = self
.store
.list_memories("user", &scope_key, namespace, limit)?;
json!({
"count": memories.len(),
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
@ -135,7 +137,12 @@ impl Tool for MemorySearchTool {
match self.store.get_memory("user", &scope_key, namespace, key)? {
Some(memory) => memory_to_json(memory),
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
None => {
return Ok(error_result(&format!(
"memory '{}.{}' not found",
namespace, key
)));
}
}
}
_ => return Ok(error_result("Unsupported action")),

View File

@ -5,9 +5,7 @@ use async_trait::async_trait;
use serde_json::json;
use crate::config::SchedulerSchedule;
use crate::storage::{
SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore,
};
use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore};
use crate::tools::traits::{Tool, ToolResult};
pub struct SchedulerManageTool {
@ -35,11 +33,7 @@ impl Tool for SchedulerManageTool {
}
fn parameters_schema(&self) -> serde_json::Value {
let mut allowed_agents = self
.known_agents
.iter()
.cloned()
.collect::<Vec<_>>();
let mut allowed_agents = self.known_agents.iter().cloned().collect::<Vec<_>>();
allowed_agents.sort();
let agent_hint = if allowed_agents.is_empty() {
"agent_task payload.agent may be omitted or set to 'default'.".to_string()
@ -225,8 +219,15 @@ fn build_upsert(
startup_delay_secs,
target,
payload,
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
enabled: args
.get("enabled")
.and_then(|value| value.as_bool())
.unwrap_or(true),
state: if args
.get("enabled")
.and_then(|value| value.as_bool())
.unwrap_or(true)
{
SchedulerJobState::Scheduled
} else {
SchedulerJobState::Paused
@ -252,14 +253,28 @@ fn enrich_target_from_context(
};
if !has_non_empty_string(&object, "channel") {
if let Some(channel_name) = context.channel_name.as_ref().filter(|value| !value.trim().is_empty()) {
object.insert("channel".to_string(), serde_json::Value::String(channel_name.clone()));
if let Some(channel_name) = context
.channel_name
.as_ref()
.filter(|value| !value.trim().is_empty())
{
object.insert(
"channel".to_string(),
serde_json::Value::String(channel_name.clone()),
);
}
}
if !has_non_empty_string(&object, "chat_id") {
if let Some(chat_id) = context.chat_id.as_ref().filter(|value| !value.trim().is_empty()) {
object.insert("chat_id".to_string(), serde_json::Value::String(chat_id.clone()));
if let Some(chat_id) = context
.chat_id
.as_ref()
.filter(|value| !value.trim().is_empty())
{
object.insert(
"chat_id".to_string(),
serde_json::Value::String(chat_id.clone()),
);
}
}
@ -274,7 +289,10 @@ fn has_non_empty_string(object: &serde_json::Map<String, serde_json::Value>, fie
.unwrap_or(false)
}
fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<()> {
fn validate_agent_task_payload(
payload: &serde_json::Value,
known_agents: &HashSet<String>,
) -> anyhow::Result<()> {
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
anyhow::bail!("agent_task payload.prompt is required and must be a string")
};
@ -299,7 +317,8 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> St
configured_agents.sort();
let configured_hint = if configured_agents.is_empty() {
"No named agents are configured; use payload.agent='default' or omit payload.agent.".to_string()
"No named agents are configured; use payload.agent='default' or omit payload.agent."
.to_string()
} else {
format!(
"payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.",
@ -309,9 +328,7 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> St
format!(
"Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.",
agent_name,
configured_hint,
agent_name,
agent_name, configured_hint, agent_name,
)
}
@ -517,7 +534,10 @@ mod tests {
.unwrap()
.unwrap();
assert_eq!(saved.kind, "silent_agent_task");
assert_eq!(saved.target["session_chat_id"], "scheduler/agent.daily_summary.background");
assert_eq!(
saved.target["session_chat_id"],
"scheduler/agent.daily_summary.background"
);
}
#[tokio::test]
@ -654,7 +674,9 @@ mod tests {
assert!(!result.success);
assert_eq!(
result.error.as_deref(),
Some("Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}")
Some(
"Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}"
)
);
}
@ -668,7 +690,9 @@ mod tests {
.as_str()
.unwrap();
assert!(payload_description.contains("avoid repeating schedule phrases or execution times"));
assert!(
payload_description.contains("avoid repeating schedule phrases or execution times")
);
assert!(payload_description.contains("每天9点"));
assert!(payload_description.contains("每小时"));
}

View File

@ -408,7 +408,10 @@ impl SchemaCleanr {
match non_null.len() {
0 => Value::String("null".to_string()),
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
1 => non_null
.into_iter()
.next()
.unwrap_or(Value::String("null".to_string())),
_ => Value::Array(non_null),
}
} else {

View File

@ -83,7 +83,11 @@ impl Tool for SkillManageTool {
let scope = match args.get("scope").and_then(|v| v.as_str()) {
Some(value) => match SkillScope::parse(value) {
Some(scope) => scope,
None => return Ok(error_result("scope must be 'project' or 'user'; .agents sources are discovery-only")),
None => {
return Ok(error_result(
"scope must be 'project' or 'user'; .agents sources are discovery-only",
));
}
},
None => SkillScope::Project,
};
@ -91,9 +95,7 @@ impl Tool for SkillManageTool {
let name = args.get("name").and_then(|v| v.as_str());
let result = match action {
"list" => {
list_skills_payload(&self.skills)
}
"list" => list_skills_payload(&self.skills),
"get" => {
let name = match name {
Some(name) => name,
@ -127,7 +129,10 @@ impl Tool for SkillManageTool {
};
let body = args.get("body").and_then(|v| v.as_str()).unwrap_or("");
match self.skills.create_skill(scope, name, description, body, reload) {
match self
.skills
.create_skill(scope, name, description, body, reload)
{
Ok(skill) => json!({
"status": "created",
"name": skill.name,
@ -149,7 +154,10 @@ impl Tool for SkillManageTool {
return Ok(error_result("update requires description or body"));
}
match self.skills.update_skill(scope, name, description, body, reload) {
match self
.skills
.update_skill(scope, name, description, body, reload)
{
Ok(skill) => json!({
"status": "updated",
"name": skill.name,

View File

@ -99,9 +99,7 @@ fn execute_time_request(
.and_then(Value::as_str)
.unwrap_or(default_timezone);
let timezone = timezone_name.parse::<chrono_tz::Tz>().map_err(|_| {
format!(
"Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai"
)
format!("Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai")
})?;
let now_local = now_utc.with_timezone(&timezone);
@ -168,13 +166,14 @@ fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
let direction = direction.ok_or_else(|| {
"Missing required parameter: direction when requesting a relative time".to_string()
})?;
let amount = amount
.and_then(Value::as_u64)
.ok_or_else(|| "Missing required parameter: amount when requesting a relative time".to_string())?;
let amount = amount.and_then(Value::as_u64).ok_or_else(|| {
"Missing required parameter: amount when requesting a relative time".to_string()
})?;
let amount = u32::try_from(amount)
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
let unit = unit
.ok_or_else(|| "Missing required parameter: unit when requesting a relative time".to_string())?;
let unit = unit.ok_or_else(|| {
"Missing required parameter: unit when requesting a relative time".to_string()
})?;
Ok(Some(OffsetRequest {
direction: OffsetDirection::parse(direction)?,
@ -188,10 +187,18 @@ fn apply_offset(
offset: &OffsetRequest,
) -> Result<DateTime<chrono_tz::Tz>, String> {
match (offset.direction, offset.unit) {
(OffsetDirection::Future, TimeUnit::Minute) => Ok(now_local + Duration::minutes(i64::from(offset.amount))),
(OffsetDirection::Past, TimeUnit::Minute) => Ok(now_local - Duration::minutes(i64::from(offset.amount))),
(OffsetDirection::Future, TimeUnit::Hour) => Ok(now_local + Duration::hours(i64::from(offset.amount))),
(OffsetDirection::Past, TimeUnit::Hour) => Ok(now_local - Duration::hours(i64::from(offset.amount))),
(OffsetDirection::Future, TimeUnit::Minute) => {
Ok(now_local + Duration::minutes(i64::from(offset.amount)))
}
(OffsetDirection::Past, TimeUnit::Minute) => {
Ok(now_local - Duration::minutes(i64::from(offset.amount)))
}
(OffsetDirection::Future, TimeUnit::Hour) => {
Ok(now_local + Duration::hours(i64::from(offset.amount)))
}
(OffsetDirection::Past, TimeUnit::Hour) => {
Ok(now_local - Duration::hours(i64::from(offset.amount)))
}
(OffsetDirection::Future, TimeUnit::Day) => now_local
.checked_add_days(Days::new(u64::from(offset.amount)))
.ok_or_else(|| "Failed to add days to the current time".to_string()),
@ -439,8 +446,8 @@ mod tests {
assert!(result.success);
let payload: Value = serde_json::from_str(&result.output).unwrap();
let result_time = chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap())
.unwrap();
let result_time =
chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap()).unwrap();
assert_eq!(result_time.hour(), 12);
assert_eq!(result_time.minute(), 30);
}

View File

@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
return true;
}
if host.rsplit('.').next().is_some_and(|label| label == "local") {
if host
.rsplit('.')
.next()
.is_some_and(|label| label == "local")
{
return true;
}
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
}
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
std::net::IpAddr::V6(v6) => {
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
}
};
}

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
use picobot::config::{Config, LLMProviderConfig};
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
use std::collections::HashMap;
fn load_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -23,11 +23,10 @@ fn load_config() -> Option<LLMProviderConfig> {
model_id: openai_model,
temperature: Some(0.0),
max_tokens: Some(100),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 20,
token_limit: 128_000,
tool_result_max_chars: 20_000,
context_summary_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
})
}
@ -44,8 +43,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
#[tokio::test]
#[ignore]
async fn test_openai_simple_completion() {
let config = load_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
@ -59,8 +57,7 @@ async fn test_openai_simple_completion() {
#[tokio::test]
#[ignore]
async fn test_openai_conversation() {
let config = load_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -84,7 +81,9 @@ async fn test_openai_conversation() {
async fn test_config_load() {
// Test that config.json can be loaded and provider config created
let config = Config::load("config.json").expect("Failed to load config.json");
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
let provider_config = config
.get_provider_config("default")
.expect("Failed to get provider config");
assert_eq!(provider_config.provider_type, "openai");
assert_eq!(provider_config.name, "aliyun");

View File

@ -1,5 +1,5 @@
use picobot::providers::{ChatCompletionRequest, Message};
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
/// Test that message with special characters is properly escaped
#[test]
@ -19,7 +19,9 @@ fn test_message_special_characters() {
#[test]
fn test_multiline_system_prompt() {
let messages = vec![
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
Message::system(
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
),
Message::user("Hi"),
];
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
#[test]
fn test_chat_request_serialization() {
let request = ChatCompletionRequest {
messages: vec![
Message::system("You are helpful"),
Message::user("Hello"),
],
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
temperature: Some(0.7),
max_tokens: Some(100),
tools: None,
@ -136,7 +135,12 @@ fn test_tool_call_outbound_serialization() {
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
match decoded {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
WsOutbound::ToolCall {
tool_call_id,
tool_name,
arguments,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
@ -161,7 +165,12 @@ fn test_tool_result_outbound_serialization() {
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
match decoded {
WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => {
WsOutbound::ToolResult {
tool_call_id,
tool_name,
content,
..
} => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert!(content.contains('2'));

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
use picobot::config::LLMProviderConfig;
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
use std::collections::HashMap;
fn load_openai_config() -> Option<LLMProviderConfig> {
dotenv::from_filename("tests/test.env").ok()?;
@ -23,11 +23,10 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
model_id: openai_model,
temperature: Some(0.0),
max_tokens: Some(100),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 20,
token_limit: 128_000,
tool_result_max_chars: 20_000,
context_summary_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
})
}
@ -55,8 +54,7 @@ fn make_weather_tool() -> Tool {
#[tokio::test]
#[ignore]
async fn test_openai_tool_call() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -70,7 +68,11 @@ async fn test_openai_tool_call() {
let response = provider.chat(request).await.unwrap();
// Should have tool calls
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
assert!(
!response.tool_calls.is_empty(),
"Expected tool call, got: {}",
response.content
);
let tool_call = &response.tool_calls[0];
assert_eq!(tool_call.name, "get_weather");
@ -80,8 +82,7 @@ async fn test_openai_tool_call() {
#[tokio::test]
#[ignore]
async fn test_openai_tool_call_with_manual_execution() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");
@ -94,8 +95,7 @@ async fn test_openai_tool_call_with_manual_execution() {
};
let response1 = provider.chat(request1).await.unwrap();
let tool_call = response1.tool_calls.first()
.expect("Expected tool call");
let tool_call = response1.tool_calls.first().expect("Expected tool call");
assert_eq!(tool_call.name, "get_weather");
// Second request with tool result
@ -118,8 +118,7 @@ async fn test_openai_tool_call_with_manual_execution() {
#[tokio::test]
#[ignore]
async fn test_openai_no_tool_when_not_provided() {
let config = load_openai_config()
.expect("Please configure tests/test.env with valid API keys");
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
let provider = create_provider(config).expect("Failed to create provider");