Compare commits
4 Commits
bca86abe67
...
fa3354db9c
| Author | SHA1 | Date | |
|---|---|---|---|
| fa3354db9c | |||
| b2c8d76820 | |||
| 33f5a4cbd2 | |||
| 73dab09bfe |
@ -134,7 +134,7 @@ PicoBot 会在 ~/.picobot/agent/AGENT.md 维护一份持久化 Agent 画像文
|
|||||||
1. 系统先对当前活动历史做一个近似 token 估算。
|
1. 系统先对当前活动历史做一个近似 token 估算。
|
||||||
估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。
|
估算规则不是调用 tokenizer,而是按“约每 4 个字符约等于 1 token,并再乘以 1.2 安全系数”计算。
|
||||||
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
|
2. 当估算结果超过模型上下文窗口的 50% 时,压缩器才认为“需要压缩”。
|
||||||
这里的上下文窗口来自 agent 对应模型配置里的 token_limit。
|
这里的上下文窗口来自 agent 对应模型配置里的 context_window_tokens;未配置时按 128000 估算。
|
||||||
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
|
3. 即使超过阈值,如果当前历史里的 user turn 数量不超过保留阈值,也不会压缩。
|
||||||
当前默认会完整保留最近 3 个 user turn。
|
当前默认会完整保留最近 3 个 user turn。
|
||||||
4. 一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。
|
4. 一旦满足条件,压缩器会先按 user 消息切分 turn,再确定“旧历史”和“最近保留段”的分界点。
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::message::ToolMessageState;
|
use crate::bus::message::ToolMessageState;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::observability::{
|
use crate::observability::{
|
||||||
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState,
|
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
||||||
};
|
};
|
||||||
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, ToolCall, create_provider};
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::SessionStore;
|
use crate::storage::SessionStore;
|
||||||
use crate::tools::{ToolContext, ToolRegistry};
|
|
||||||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||||||
|
use crate::tools::{ToolContext, ToolRegistry};
|
||||||
|
use async_trait::async_trait;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
@ -19,18 +19,13 @@ use std::time::Instant;
|
|||||||
|
|
||||||
/// Minimum characters to keep when truncating
|
/// Minimum characters to keep when truncating
|
||||||
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
const TRUNCATION_SUFFIX_LEN: usize = 200;
|
||||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str =
|
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = include_str!("memory_tool_usage_system_prompt.md");
|
||||||
include_str!("memory_tool_usage_system_prompt.md");
|
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str =
|
||||||
|
"工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||||
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||||
|
|
||||||
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &[
|
const SUPPORTED_IMAGE_MIME_TYPES: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];
|
||||||
"image/jpeg",
|
|
||||||
"image/png",
|
|
||||||
"image/gif",
|
|
||||||
"image/webp",
|
|
||||||
];
|
|
||||||
|
|
||||||
/// Build content blocks from text and media paths
|
/// Build content blocks from text and media paths
|
||||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||||
@ -115,14 +110,15 @@ fn truncate_tool_result(output: &str, max_tool_result_chars: usize) -> String {
|
|||||||
let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len));
|
let tail = take_suffix_chars(output, total_chars.saturating_sub(truncated_start_len));
|
||||||
format!(
|
format!(
|
||||||
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
"...\n\n[Output truncated - {} characters removed]\n\n{}",
|
||||||
truncated_start_len,
|
truncated_start_len, tail
|
||||||
tail
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_pending_tool_output(output: &str) -> Option<String> {
|
fn parse_pending_tool_output(output: &str) -> Option<String> {
|
||||||
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
output
|
||||||
|
.strip_prefix(PENDING_USER_ACTION_MARKER)
|
||||||
|
.map(|rest| rest.trim().to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
|
fn normalize_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
|
||||||
@ -341,7 +337,10 @@ impl AgentLoop {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc<ToolRegistry>) -> Result<Self, AgentError> {
|
pub fn with_tools(
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
let provider = create_provider(provider_config.clone())
|
let provider = create_provider(provider_config.clone())
|
||||||
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
.map_err(|e| AgentError::ProviderCreation(e.to_string()))?;
|
||||||
@ -416,9 +415,16 @@ impl AgentLoop {
|
|||||||
/// it loops back to the LLM with the tool results until either:
|
/// it loops back to the LLM with the tool results until either:
|
||||||
/// - The LLM returns no more tool calls (final response)
|
/// - The LLM returns no more tool calls (final response)
|
||||||
/// - Maximum iterations are reached
|
/// - Maximum iterations are reached
|
||||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
pub async fn process(
|
||||||
|
&self,
|
||||||
|
mut messages: Vec<ChatMessage>,
|
||||||
|
) -> Result<AgentProcessResult, AgentError> {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
tracing::debug!(
|
||||||
|
history_len = messages.len(),
|
||||||
|
max_iterations = self.max_iterations,
|
||||||
|
"Starting agent process"
|
||||||
|
);
|
||||||
|
|
||||||
// Track tool calls for loop detection
|
// Track tool calls for loop detection
|
||||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||||
@ -441,7 +447,11 @@ impl AgentLoop {
|
|||||||
if let Some(skill_tool) = self.skills.skill_tool_definition() {
|
if let Some(skill_tool) = self.skills.skill_tool_definition() {
|
||||||
tool_defs.push(skill_tool);
|
tool_defs.push(skill_tool);
|
||||||
}
|
}
|
||||||
let tools = if tool_defs.is_empty() { None } else { Some(tool_defs) };
|
let tools = if tool_defs.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tool_defs)
|
||||||
|
};
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: messages_for_llm,
|
messages: messages_for_llm,
|
||||||
@ -461,7 +471,8 @@ impl AgentLoop {
|
|||||||
error_details = %format_error_chain(e.as_ref()),
|
error_details = %format_error_chain(e.as_ref()),
|
||||||
"LLM request failed"
|
"LLM request failed"
|
||||||
);
|
);
|
||||||
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
let assistant_message =
|
||||||
|
ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||||
emitted_messages.push(assistant_message.clone());
|
emitted_messages.push(assistant_message.clone());
|
||||||
return Ok(AgentProcessResult {
|
return Ok(AgentProcessResult {
|
||||||
final_response: assistant_message,
|
final_response: assistant_message,
|
||||||
@ -480,7 +491,8 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// If no tool calls, this is the final response
|
// If no tool calls, this is the final response
|
||||||
if response.tool_calls.is_empty() {
|
if response.tool_calls.is_empty() {
|
||||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||||
|
{
|
||||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||||
} else {
|
} else {
|
||||||
ChatMessage::assistant(response.content)
|
ChatMessage::assistant(response.content)
|
||||||
@ -493,24 +505,35 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute tool calls
|
// Execute tool calls
|
||||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
tracing::info!(
|
||||||
|
iteration,
|
||||||
|
count = response.tool_calls.len(),
|
||||||
|
"Tool calls detected, executing tools"
|
||||||
|
);
|
||||||
|
|
||||||
// Add assistant message with tool calls
|
// Add assistant message with tool calls
|
||||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() {
|
let assistant_message =
|
||||||
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||||||
response.content.clone(),
|
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
||||||
response.tool_calls.clone(),
|
response.content.clone(),
|
||||||
reasoning_content,
|
response.tool_calls.clone(),
|
||||||
)
|
reasoning_content,
|
||||||
} else {
|
)
|
||||||
ChatMessage::assistant_with_tool_calls(
|
} else {
|
||||||
response.content.clone(),
|
ChatMessage::assistant_with_tool_calls(
|
||||||
response.tool_calls.clone(),
|
response.content.clone(),
|
||||||
)
|
response.tool_calls.clone(),
|
||||||
};
|
)
|
||||||
|
};
|
||||||
messages.push(assistant_message.clone());
|
messages.push(assistant_message.clone());
|
||||||
emitted_messages.push(assistant_message);
|
emitted_messages.push(assistant_message);
|
||||||
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
|
self.emit_live_tool_call_message(
|
||||||
|
emitted_messages
|
||||||
|
.last()
|
||||||
|
.expect("assistant message just pushed")
|
||||||
|
.clone(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
// Execute tools and add results to messages
|
// Execute tools and add results to messages
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
@ -519,7 +542,9 @@ impl AgentLoop {
|
|||||||
// Log function call with name and arguments
|
// Log function call with name and arguments
|
||||||
let args_str = match &tool_call.arguments {
|
let args_str = match &tool_call.arguments {
|
||||||
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
serde_json::Value::Object(obj) if obj.is_empty() => "{}".to_string(),
|
||||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
other => {
|
||||||
|
serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string())
|
||||||
|
}
|
||||||
};
|
};
|
||||||
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
tracing::info!(tool = %tool_call.name, args = %args_str, "Calling tool");
|
||||||
|
|
||||||
@ -595,7 +620,11 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// Loop continues to next iteration with updated messages
|
// Loop continues to next iteration with updated messages
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
|
tracing::debug!(
|
||||||
|
iteration,
|
||||||
|
message_count = messages.len(),
|
||||||
|
"Tool execution complete, continuing to next iteration"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max iterations reached - ask LLM for a summary based on completed work
|
// Max iterations reached - ask LLM for a summary based on completed work
|
||||||
@ -604,7 +633,7 @@ impl AgentLoop {
|
|||||||
// Add a message asking for summary
|
// Add a message asking for summary
|
||||||
let summary_request = ChatMessage::user(
|
let summary_request = ChatMessage::user(
|
||||||
"You have reached the maximum number of tool call iterations. \
|
"You have reached the maximum number of tool call iterations. \
|
||||||
Please provide your best answer based on the work completed so far."
|
Please provide your best answer based on the work completed so far.",
|
||||||
);
|
);
|
||||||
messages.push(summary_request);
|
messages.push(summary_request);
|
||||||
|
|
||||||
@ -624,7 +653,8 @@ impl AgentLoop {
|
|||||||
|
|
||||||
match (*self.provider).chat(request).await {
|
match (*self.provider).chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||||
|
{
|
||||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||||
} else {
|
} else {
|
||||||
ChatMessage::assistant(response.content)
|
ChatMessage::assistant(response.content)
|
||||||
@ -745,10 +775,7 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply duration
|
// Apply duration
|
||||||
ToolExecutionOutcome {
|
ToolExecutionOutcome { duration, ..result }
|
||||||
duration,
|
|
||||||
..result
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal tool execution without event tracking.
|
/// Internal tool execution without event tracking.
|
||||||
@ -790,10 +817,7 @@ impl AgentLoop {
|
|||||||
"arguments": normalized_arguments,
|
"arguments": normalized_arguments,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
ToolExecutionOutcome::failure(
|
ToolExecutionOutcome::failure(format!("Error: {}", err), Some(err))
|
||||||
format!("Error: {}", err),
|
|
||||||
Some(err),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -809,7 +833,10 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match tool.execute_with_context(&self.tool_context, normalized_arguments.clone()).await {
|
match tool
|
||||||
|
.execute_with_context(&self.tool_context, normalized_arguments.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
if result.success {
|
if result.success {
|
||||||
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
|
if let Some(pending_output) = parse_pending_tool_output(&result.output) {
|
||||||
@ -827,10 +854,7 @@ impl AgentLoop {
|
|||||||
output = %result.output,
|
output = %result.output,
|
||||||
"Tool returned an error result"
|
"Tool returned an error result"
|
||||||
);
|
);
|
||||||
ToolExecutionOutcome::failure(
|
ToolExecutionOutcome::failure(format!("Error: {}", error), Some(error))
|
||||||
format!("Error: {}", error),
|
|
||||||
Some(error),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@ -842,10 +866,7 @@ impl AgentLoop {
|
|||||||
error_details = %format!("{:#}", e),
|
error_details = %format!("{:#}", e),
|
||||||
"Tool execution failed"
|
"Tool execution failed"
|
||||||
);
|
);
|
||||||
ToolExecutionOutcome::failure(
|
ToolExecutionOutcome::failure(format!("Error: {}", e), Some(e.to_string()))
|
||||||
format!("Error: {}", e),
|
|
||||||
Some(e.to_string()),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -863,7 +884,9 @@ impl AgentLoop {
|
|||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(err) = store.append_skill_event(Some(session_id), event_type, skill_name, &payload) {
|
if let Err(err) =
|
||||||
|
store.append_skill_event(Some(session_id), event_type, skill_name, &payload)
|
||||||
|
{
|
||||||
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
|
tracing::warn!(error = %err, event_type = %event_type, "Failed to record skill event");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -942,28 +965,37 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(provider_message.role, "assistant");
|
assert_eq!(provider_message.role, "assistant");
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
assert_eq!(
|
||||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
provider_message.tool_calls.as_ref().unwrap()[0].id,
|
||||||
|
"call_1"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
provider_message.tool_calls.as_ref().unwrap()[0].name,
|
||||||
|
"calculator"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
|
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
|
||||||
let chat_message = ChatMessage::assistant_with_reasoning(
|
let chat_message =
|
||||||
"final answer",
|
ChatMessage::assistant_with_reasoning("final answer", "hidden chain of thought");
|
||||||
"hidden chain of thought",
|
|
||||||
);
|
|
||||||
|
|
||||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||||
|
|
||||||
assert_eq!(provider_message.role, "assistant");
|
assert_eq!(provider_message.role, "assistant");
|
||||||
assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought"));
|
assert_eq!(
|
||||||
|
provider_message.reasoning_content.as_deref(),
|
||||||
|
Some("hidden chain of thought")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_memory_prompt_requires_proactive_memory_search() {
|
fn test_memory_prompt_requires_proactive_memory_search() {
|
||||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));
|
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));
|
||||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search"));
|
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("先使用长期记忆检索工具 memory_search"));
|
||||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索"));
|
assert!(
|
||||||
|
MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("不要因为你自认为已经能直接回答就省略检索")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1001,9 +1033,13 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_normalize_tool_arguments_keeps_plain_string() {
|
fn test_normalize_tool_arguments_keeps_plain_string() {
|
||||||
let normalized = normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
|
let normalized =
|
||||||
|
normalize_tool_arguments(&serde_json::Value::String("plain text".to_string()));
|
||||||
|
|
||||||
assert_eq!(normalized, serde_json::Value::String("plain text".to_string()));
|
assert_eq!(
|
||||||
|
normalized,
|
||||||
|
serde_json::Value::String("plain text".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1028,7 +1064,9 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(blocks.len(), 2);
|
assert_eq!(blocks.len(), 2);
|
||||||
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "hello"));
|
||||||
assert!(matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,")));
|
assert!(
|
||||||
|
matches!(&blocks[1], ContentBlock::ImageUrl { image_url } if image_url.url.starts_with("data:image/jpeg;base64,"))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
use crate::bus::{
|
use crate::bus::{
|
||||||
ChatMessage,
|
ChatMessage, SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
|
||||||
SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
|
||||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||||
};
|
};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::providers::{create_provider, ChatCompletionRequest, LLMProvider, Message};
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
|
||||||
use crate::text::{char_count, take_prefix_chars};
|
use crate::text::{char_count, take_prefix_chars};
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
@ -62,6 +60,7 @@ pub struct ContextCompressor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ContextCompressor {
|
impl ContextCompressor {
|
||||||
|
#[cfg(test)]
|
||||||
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
|
fn summary_char_budget_for_context_window(context_window: usize) -> usize {
|
||||||
const SUMMARY_RATIO: f64 = 0.1;
|
const SUMMARY_RATIO: f64 = 0.1;
|
||||||
const CHARS_PER_TOKEN: f64 = 2.5;
|
const CHARS_PER_TOKEN: f64 = 2.5;
|
||||||
@ -221,7 +220,9 @@ Be concise, aim for {} characters or less.
|
|||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
let per_chunk_target = (target_chars / layer.len().max(1)).max(500).min(target_chars);
|
let per_chunk_target = (target_chars / layer.len().max(1))
|
||||||
|
.max(500)
|
||||||
|
.min(target_chars);
|
||||||
let mut summaries = Vec::with_capacity(layer.len());
|
let mut summaries = Vec::with_capacity(layer.len());
|
||||||
for chunk in &layer {
|
for chunk in &layer {
|
||||||
summaries.push(
|
summaries.push(
|
||||||
@ -241,7 +242,9 @@ Be concise, aim for {} characters or less.
|
|||||||
|
|
||||||
let merged = summaries.join("\n\n");
|
let merged = summaries.join("\n\n");
|
||||||
if char_count(&merged) <= target_chars {
|
if char_count(&merged) <= target_chars {
|
||||||
return self.summarize_transcript(provider, &merged, target_chars).await;
|
return self
|
||||||
|
.summarize_transcript(provider, &merged, target_chars)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
layer = Self::split_text_chunks(&merged, target_chars);
|
layer = Self::split_text_chunks(&merged, target_chars);
|
||||||
@ -314,7 +317,10 @@ Be concise, aim for {} characters or less.
|
|||||||
|| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
|
|| message.has_system_context(SYSTEM_CONTEXT_SCHEDULED_PROMPT))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split_prefix_messages(&self, history: &[ChatMessage]) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
|
fn split_prefix_messages(
|
||||||
|
&self,
|
||||||
|
history: &[ChatMessage],
|
||||||
|
) -> (Vec<ChatMessage>, Vec<ChatMessage>) {
|
||||||
let preserved_system_messages = history
|
let preserved_system_messages = history
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|message| self.should_preserve_system_message(message))
|
.filter(|message| self.should_preserve_system_message(message))
|
||||||
@ -343,7 +349,8 @@ Be concise, aim for {} characters or less.
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
let preserved_turn_start = turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
|
let preserved_turn_start =
|
||||||
|
turn_ranges[turn_ranges.len() - self.config.retain_last_user_turns].start;
|
||||||
if preserved_turn_start == 0 {
|
if preserved_turn_start == 0 {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
@ -357,10 +364,10 @@ Be concise, aim for {} characters or less.
|
|||||||
|
|
||||||
Ok(Some(HistoryCompactionPlan {
|
Ok(Some(HistoryCompactionPlan {
|
||||||
preserved_system_messages,
|
preserved_system_messages,
|
||||||
summary_message: ChatMessage::system_with_context(format!(
|
summary_message: ChatMessage::system_with_context(
|
||||||
"[Compressed History]\n\n{}",
|
format!("[Compressed History]\n\n{}", summary),
|
||||||
summary
|
Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string()),
|
||||||
), Some(SYSTEM_CONTEXT_HISTORY_COMPACTION.to_string())),
|
),
|
||||||
preserved_messages: history[preserved_turn_start..].to_vec(),
|
preserved_messages: history[preserved_turn_start..].to_vec(),
|
||||||
compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns,
|
compressed_turns: turn_ranges.len() - self.config.retain_last_user_turns,
|
||||||
preserved_turns: self.config.retain_last_user_turns,
|
preserved_turns: self.config.retain_last_user_turns,
|
||||||
@ -392,7 +399,10 @@ Be concise, aim for {} characters or less.
|
|||||||
"Starting context compression"
|
"Starting context compression"
|
||||||
);
|
);
|
||||||
|
|
||||||
let current_history = match self.build_compaction_plan(&history, provider_config).await? {
|
let current_history = match self
|
||||||
|
.build_compaction_plan(&history, provider_config)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
Some(plan) => {
|
Some(plan) => {
|
||||||
let mut compressed = Vec::with_capacity(
|
let mut compressed = Vec::with_capacity(
|
||||||
plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1,
|
plan.preserved_system_messages.len() + plan.preserved_messages.len() + 1,
|
||||||
@ -429,8 +439,12 @@ Be concise, aim for {} characters or less.
|
|||||||
let transcript = Self::build_transcript(messages);
|
let transcript = Self::build_transcript(messages);
|
||||||
|
|
||||||
let result = if char_count(&transcript) <= self.config.summary_max_chars {
|
let result = if char_count(&transcript) <= self.config.summary_max_chars {
|
||||||
self.summarize_transcript(provider.as_ref(), &transcript, self.config.summary_max_chars)
|
self.summarize_transcript(
|
||||||
.await
|
provider.as_ref(),
|
||||||
|
&transcript,
|
||||||
|
self.config.summary_max_chars,
|
||||||
|
)
|
||||||
|
.await
|
||||||
} else {
|
} else {
|
||||||
self.summarize_chunked_transcript(provider.as_ref(), messages, &transcript)
|
self.summarize_chunked_transcript(provider.as_ref(), messages, &transcript)
|
||||||
.await
|
.await
|
||||||
@ -440,7 +454,10 @@ Be concise, aim for {} characters or less.
|
|||||||
Ok(summary) => Ok(summary),
|
Ok(summary) => Ok(summary),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
tracing::warn!(error = %e, "LLM summarization failed, using truncated transcript");
|
||||||
Ok(take_prefix_chars(&transcript, self.config.summary_max_chars))
|
Ok(take_prefix_chars(
|
||||||
|
&transcript,
|
||||||
|
self.config.summary_max_chars,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -463,7 +480,11 @@ mod tests {
|
|||||||
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
// "Hi there!" (8) -> ceil(8/4)+4 = 2+4 = 6
|
||||||
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
// "How are you?" (11) -> ceil(11/4)+4 = 3+4 = 7
|
||||||
// raw = 19, with 1.2x = ~23
|
// raw = 19, with 1.2x = ~23
|
||||||
assert!(tokens > 18 && tokens < 30, "Expected ~23 tokens, got {}", tokens);
|
assert!(
|
||||||
|
tokens > 18 && tokens < 30,
|
||||||
|
"Expected ~23 tokens, got {}",
|
||||||
|
tokens
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -487,21 +508,39 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let turns = compressor.user_turn_ranges(&history);
|
let turns = compressor.user_turn_ranges(&history);
|
||||||
assert_eq!(turns, vec![
|
assert_eq!(
|
||||||
UserTurnRange { start: 1, end_exclusive: 4 },
|
turns,
|
||||||
UserTurnRange { start: 4, end_exclusive: 6 },
|
vec![
|
||||||
UserTurnRange { start: 6, end_exclusive: 7 },
|
UserTurnRange {
|
||||||
]);
|
start: 1,
|
||||||
|
end_exclusive: 4
|
||||||
|
},
|
||||||
|
UserTurnRange {
|
||||||
|
start: 4,
|
||||||
|
end_exclusive: 6
|
||||||
|
},
|
||||||
|
UserTurnRange {
|
||||||
|
start: 6,
|
||||||
|
end_exclusive: 7
|
||||||
|
},
|
||||||
|
]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_split_prefix_messages_preserves_key_system_messages() {
|
fn test_split_prefix_messages_preserves_key_system_messages() {
|
||||||
let compressor = ContextCompressor::new(50);
|
let compressor = ContextCompressor::new(50);
|
||||||
let prefix = vec![
|
let prefix = vec![
|
||||||
ChatMessage::system_with_context("agent prompt", Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string())),
|
ChatMessage::system_with_context(
|
||||||
|
"agent prompt",
|
||||||
|
Some(SYSTEM_CONTEXT_AGENT_PROMPT.to_string()),
|
||||||
|
),
|
||||||
ChatMessage::user("u1"),
|
ChatMessage::user("u1"),
|
||||||
ChatMessage::assistant("a1"),
|
ChatMessage::assistant("a1"),
|
||||||
ChatMessage::system_with_context("scheduled prompt", Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string())),
|
ChatMessage::system_with_context(
|
||||||
|
"scheduled prompt",
|
||||||
|
Some(SYSTEM_CONTEXT_SCHEDULED_PROMPT.to_string()),
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix);
|
let (preserved_system_messages, summary_source) = compressor.split_prefix_messages(&prefix);
|
||||||
@ -519,10 +558,22 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_summary_char_budget_for_context_window_scales_and_clamps() {
|
fn test_summary_char_budget_for_context_window_scales_and_clamps() {
|
||||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(4_096), 1_500);
|
assert_eq!(
|
||||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(65_536), 16_384);
|
ContextCompressor::summary_char_budget_for_context_window(4_096),
|
||||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(128_000), 32_000);
|
1_500
|
||||||
assert_eq!(ContextCompressor::summary_char_budget_for_context_window(400_000), 50_000);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::summary_char_budget_for_context_window(65_536),
|
||||||
|
16_384
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::summary_char_budget_for_context_window(128_000),
|
||||||
|
32_000
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ContextCompressor::summary_char_budget_for_context_window(400_000),
|
||||||
|
50_000
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
pub mod agent_loop;
|
pub mod agent_loop;
|
||||||
pub mod context_compressor;
|
pub mod context_compressor;
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
|
pub use agent_loop::{AgentError, AgentLoop, AgentProcessResult, EmittedMessageHandler};
|
||||||
pub use context_compressor::ContextCompressor;
|
pub use context_compressor::ContextCompressor;
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundMessage};
|
use crate::bus::{MessageBus, OutboundMessage};
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
@ -22,7 +22,10 @@ impl OutboundDispatcher {
|
|||||||
|
|
||||||
/// Register a channel with the dispatcher
|
/// Register a channel with the dispatcher
|
||||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||||
self.channels.write().await.insert(name.to_string(), channel);
|
self.channels
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(name.to_string(), channel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run the dispatcher loop - consumes from bus and dispatches to channels
|
/// Run the dispatcher loop - consumes from bus and dispatches to channels
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::providers::ToolCall;
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
@ -34,7 +34,9 @@ pub struct ImageUrlBlock {
|
|||||||
|
|
||||||
impl ContentBlock {
|
impl ContentBlock {
|
||||||
pub fn text(content: impl Into<String>) -> Self {
|
pub fn text(content: impl Into<String>) -> Self {
|
||||||
Self::Text { text: content.into() }
|
Self::Text {
|
||||||
|
text: content.into(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn image_url(url: impl Into<String>) -> Self {
|
pub fn image_url(url: impl Into<String>) -> Self {
|
||||||
@ -50,10 +52,10 @@ impl ContentBlock {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MediaItem {
|
pub struct MediaItem {
|
||||||
pub path: String, // Local file path
|
pub path: String, // Local file path
|
||||||
pub media_type: String, // "image", "audio", "file", "video"
|
pub media_type: String, // "image", "audio", "file", "video"
|
||||||
pub mime_type: Option<String>,
|
pub mime_type: Option<String>,
|
||||||
pub original_key: Option<String>, // Feishu file_key for download
|
pub original_key: Option<String>, // Feishu file_key for download
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MediaItem {
|
impl MediaItem {
|
||||||
@ -76,7 +78,7 @@ pub struct ChatMessage {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub media_refs: Vec<String>, // Paths to media files for context
|
pub media_refs: Vec<String>, // Paths to media files for context
|
||||||
pub timestamp: i64,
|
pub timestamp: i64,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub system_context: Option<String>,
|
pub system_context: Option<String>,
|
||||||
@ -150,7 +152,10 @@ impl ChatMessage {
|
|||||||
message
|
message
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
pub fn assistant_with_tool_calls(
|
||||||
|
content: impl Into<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -199,8 +204,17 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn tool(
|
||||||
Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed)
|
tool_call_id: impl Into<String>,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
content: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
|
Self::tool_with_state(
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
content,
|
||||||
|
ToolMessageState::Completed,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_with_state(
|
pub fn tool_with_state(
|
||||||
@ -287,6 +301,8 @@ pub enum OutboundEventKind {
|
|||||||
ToolCall,
|
ToolCall,
|
||||||
ToolResult,
|
ToolResult,
|
||||||
ToolPending,
|
ToolPending,
|
||||||
|
SchedulerNotification,
|
||||||
|
ErrorNotification,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OutboundMessage {
|
impl OutboundMessage {
|
||||||
@ -316,6 +332,30 @@ impl OutboundMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn scheduler_notification(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
content: impl Into<String>,
|
||||||
|
reply_to: Option<String>,
|
||||||
|
metadata: HashMap<String, String>,
|
||||||
|
) -> Self {
|
||||||
|
let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata);
|
||||||
|
message.event_kind = OutboundEventKind::SchedulerNotification;
|
||||||
|
message
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn error_notification(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
content: impl Into<String>,
|
||||||
|
reply_to: Option<String>,
|
||||||
|
metadata: HashMap<String, String>,
|
||||||
|
) -> Self {
|
||||||
|
let mut message = Self::assistant(channel, chat_id, content, reply_to, metadata);
|
||||||
|
message.event_kind = OutboundEventKind::ErrorNotification;
|
||||||
|
message
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tool_call(
|
pub fn tool_call(
|
||||||
channel: impl Into<String>,
|
channel: impl Into<String>,
|
||||||
chat_id: impl Into<String>,
|
chat_id: impl Into<String>,
|
||||||
@ -417,20 +457,17 @@ impl OutboundMessage {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
outbound.extend(tool_calls
|
outbound.extend(tool_calls.iter().map(|tool_call| {
|
||||||
.iter()
|
Self::tool_call(
|
||||||
.map(|tool_call| {
|
channel.to_string(),
|
||||||
Self::tool_call(
|
chat_id.to_string(),
|
||||||
channel.to_string(),
|
tool_call.id.clone(),
|
||||||
chat_id.to_string(),
|
tool_call.name.clone(),
|
||||||
tool_call.id.clone(),
|
tool_call.arguments.clone(),
|
||||||
tool_call.name.clone(),
|
reply_to.clone(),
|
||||||
tool_call.arguments.clone(),
|
metadata.clone(),
|
||||||
reply_to.clone(),
|
)
|
||||||
metadata.clone(),
|
}));
|
||||||
)
|
|
||||||
})
|
|
||||||
);
|
|
||||||
outbound
|
outbound
|
||||||
} else {
|
} else {
|
||||||
vec![Self::assistant(
|
vec![Self::assistant(
|
||||||
@ -442,7 +479,11 @@ impl OutboundMessage {
|
|||||||
)]
|
)]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
|
"tool" => match message
|
||||||
|
.tool_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&ToolMessageState::Completed)
|
||||||
|
{
|
||||||
ToolMessageState::Completed => vec![Self::tool_result(
|
ToolMessageState::Completed => vec![Self::tool_result(
|
||||||
channel.to_string(),
|
channel.to_string(),
|
||||||
chat_id.to_string(),
|
chat_id.to_string(),
|
||||||
@ -467,7 +508,10 @@ impl OutboundMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
pub(crate) fn format_tool_call_content(
|
||||||
|
tool_name: &str,
|
||||||
|
tool_arguments: &serde_json::Value,
|
||||||
|
) -> String {
|
||||||
match tool_arguments {
|
match tool_arguments {
|
||||||
serde_json::Value::Object(map) if map.is_empty() => tool_name.to_string(),
|
serde_json::Value::Object(map) if map.is_empty() => tool_name.to_string(),
|
||||||
other => format!("{}\nargs: {}", tool_name, format_tool_arguments_json(other)),
|
other => format!("{}\nargs: {}", tool_name, format_tool_arguments_json(other)),
|
||||||
@ -544,21 +588,25 @@ mod tests {
|
|||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
let outbound = OutboundMessage::from_chat_message(
|
let outbound =
|
||||||
"feishu",
|
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||||
"chat-1",
|
|
||||||
None,
|
|
||||||
&HashMap::new(),
|
|
||||||
&message,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 2);
|
assert_eq!(outbound.len(), 2);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
||||||
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
||||||
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
|
assert_eq!(
|
||||||
assert_eq!(outbound[0].content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
outbound[0].tool_arguments.as_ref().unwrap()["expression"],
|
||||||
|
"1 + 1"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
outbound[0].content,
|
||||||
|
"calculator\nargs: {\"expression\":\"1 + 1\"}"
|
||||||
|
);
|
||||||
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
||||||
assert_eq!(outbound[1].content, "file_read\nargs: {\"path\":\"README.md\"}");
|
assert_eq!(
|
||||||
|
outbound[1].content,
|
||||||
|
"file_read\nargs: {\"path\":\"README.md\"}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -572,13 +620,8 @@ mod tests {
|
|||||||
}],
|
}],
|
||||||
);
|
);
|
||||||
|
|
||||||
let outbound = OutboundMessage::from_chat_message(
|
let outbound =
|
||||||
"feishu",
|
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||||
"chat-1",
|
|
||||||
None,
|
|
||||||
&HashMap::new(),
|
|
||||||
&message,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 2);
|
assert_eq!(outbound.len(), 2);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::AssistantResponse);
|
||||||
@ -591,13 +634,8 @@ mod tests {
|
|||||||
fn test_from_chat_message_includes_tool_result() {
|
fn test_from_chat_message_includes_tool_result() {
|
||||||
let message = ChatMessage::tool("call-9", "calculator", "2");
|
let message = ChatMessage::tool("call-9", "calculator", "2");
|
||||||
|
|
||||||
let outbound = OutboundMessage::from_chat_message(
|
let outbound =
|
||||||
"feishu",
|
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||||
"chat-1",
|
|
||||||
None,
|
|
||||||
&HashMap::new(),
|
|
||||||
&message,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
|
||||||
@ -612,13 +650,8 @@ mod tests {
|
|||||||
ToolMessageState::PendingUserAction,
|
ToolMessageState::PendingUserAction,
|
||||||
);
|
);
|
||||||
|
|
||||||
let outbound = OutboundMessage::from_chat_message(
|
let outbound =
|
||||||
"feishu",
|
OutboundMessage::from_chat_message("feishu", "chat-1", None, &HashMap::new(), &message);
|
||||||
"chat-1",
|
|
||||||
None,
|
|
||||||
&HashMap::new(),
|
|
||||||
&message,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
|
||||||
|
|||||||
@ -3,18 +3,13 @@ pub mod message;
|
|||||||
|
|
||||||
pub use dispatcher::OutboundDispatcher;
|
pub use dispatcher::OutboundDispatcher;
|
||||||
pub use message::{
|
pub use message::{
|
||||||
ChatMessage,
|
ChatMessage, ContentBlock, InboundMessage, MediaItem, OutboundMessage,
|
||||||
ContentBlock,
|
SYSTEM_CONTEXT_AGENT_PROMPT, SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
||||||
InboundMessage,
|
|
||||||
MediaItem,
|
|
||||||
OutboundMessage,
|
|
||||||
SYSTEM_CONTEXT_AGENT_PROMPT,
|
|
||||||
SYSTEM_CONTEXT_HISTORY_COMPACTION,
|
|
||||||
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
SYSTEM_CONTEXT_SCHEDULED_PROMPT,
|
||||||
};
|
};
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||||
@ -52,7 +47,8 @@ impl MessageBus {
|
|||||||
|
|
||||||
/// Consume an inbound message (Agent -> Bus)
|
/// Consume an inbound message (Agent -> Bus)
|
||||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||||
let msg = self.inbound_rx
|
let msg = self
|
||||||
|
.inbound_rx
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.recv()
|
.recv()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundMessage};
|
use crate::bus::MessageBus;
|
||||||
use crate::channels::base::{Channel, ChannelError};
|
use crate::channels::base::{Channel, ChannelError};
|
||||||
use crate::channels::feishu::FeishuChannel;
|
use crate::channels::feishu::FeishuChannel;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
@ -28,12 +28,18 @@ impl ChannelManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize all Channel instances from config
|
/// Initialize all Channel instances from config
|
||||||
pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
|
pub async fn init(
|
||||||
|
&self,
|
||||||
|
config: &Config,
|
||||||
|
_provider_config: crate::config::LLMProviderConfig,
|
||||||
|
) -> Result<(), ChannelError> {
|
||||||
// Initialize Feishu channel if enabled
|
// Initialize Feishu channel if enabled
|
||||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||||
if feishu_config.enabled {
|
if feishu_config.enabled {
|
||||||
let channel = FeishuChannel::new(feishu_config.clone(), _provider_config)
|
let channel =
|
||||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
FeishuChannel::new(feishu_config.clone(), _provider_config).map_err(|e| {
|
||||||
|
ChannelError::Other(format!("Failed to create Feishu channel: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
self.channels
|
self.channels
|
||||||
.write()
|
.write()
|
||||||
@ -75,13 +81,12 @@ impl ChannelManager {
|
|||||||
self.channels.read().await.get(name).cloned()
|
self.channels.read().await.get(name).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dispatch an outbound message to the appropriate channel
|
pub async fn channels(&self) -> Vec<(String, Arc<dyn Channel + Send + Sync>)> {
|
||||||
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
self.channels
|
||||||
let channel_name = &msg.channel;
|
.read()
|
||||||
if let Some(channel) = self.get_channel(channel_name).await {
|
.await
|
||||||
channel.send(msg).await
|
.iter()
|
||||||
} else {
|
.map(|(name, channel)| (name.clone(), channel.clone()))
|
||||||
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
.collect()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,5 +3,5 @@ pub mod feishu;
|
|||||||
pub mod manager;
|
pub mod manager;
|
||||||
|
|
||||||
pub use base::{Channel, ChannelError};
|
pub use base::{Channel, ChannelError};
|
||||||
pub use manager::ChannelManager;
|
|
||||||
pub use feishu::FeishuChannel;
|
pub use feishu::FeishuChannel;
|
||||||
|
pub use manager::ChannelManager;
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use tokio::io::{AsyncBufReadExt, BufReader, AsyncWriteExt};
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
|
||||||
pub struct CliChannel {
|
pub struct CliChannel {
|
||||||
read: BufReader<tokio::io::Stdin>,
|
read: BufReader<tokio::io::Stdin>,
|
||||||
|
|||||||
@ -49,18 +49,27 @@ impl InputHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
|
pub async fn write_output(&mut self, content: &str) -> Result<(), InputError> {
|
||||||
self.channel.write_line(content).await.map_err(InputError::IoError)
|
self.channel
|
||||||
|
.write_line(content)
|
||||||
|
.await
|
||||||
|
.map_err(InputError::IoError)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
|
pub async fn write_response(&mut self, content: &str) -> Result<(), InputError> {
|
||||||
self.channel.write_response(content).await.map_err(InputError::IoError)
|
self.channel
|
||||||
|
.write_response(content)
|
||||||
|
.await
|
||||||
|
.map_err(InputError::IoError)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_special_commands(&self, line: &str) -> Option<InputCommand> {
|
fn handle_special_commands(&self, line: &str) -> Option<InputCommand> {
|
||||||
let trimmed = line.trim();
|
let trimmed = line.trim();
|
||||||
let mut parts = trimmed.splitn(2, char::is_whitespace);
|
let mut parts = trimmed.splitn(2, char::is_whitespace);
|
||||||
let command = parts.next()?;
|
let command = parts.next()?;
|
||||||
let arg = parts.next().map(str::trim).filter(|value| !value.is_empty());
|
let arg = parts
|
||||||
|
.next()
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty());
|
||||||
|
|
||||||
match command {
|
match command {
|
||||||
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
||||||
@ -105,14 +114,26 @@ mod tests {
|
|||||||
fn test_special_command_parsing() {
|
fn test_special_command_parsing() {
|
||||||
let handler = InputHandler::new();
|
let handler = InputHandler::new();
|
||||||
|
|
||||||
assert_eq!(handler.handle_special_commands("/quit"), Some(InputCommand::Exit));
|
assert_eq!(
|
||||||
assert_eq!(handler.handle_special_commands("/clear"), Some(InputCommand::Clear));
|
handler.handle_special_commands("/quit"),
|
||||||
assert_eq!(handler.handle_special_commands("/new"), Some(InputCommand::New(None)));
|
Some(InputCommand::Exit)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/clear"),
|
||||||
|
Some(InputCommand::Clear)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/new"),
|
||||||
|
Some(InputCommand::New(None))
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
handler.handle_special_commands("/new planning"),
|
handler.handle_special_commands("/new planning"),
|
||||||
Some(InputCommand::New(Some("planning".to_string())))
|
Some(InputCommand::New(Some("planning".to_string())))
|
||||||
);
|
);
|
||||||
assert_eq!(handler.handle_special_commands("/sessions"), Some(InputCommand::Sessions));
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/sessions"),
|
||||||
|
Some(InputCommand::Sessions)
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
handler.handle_special_commands("/use abc123"),
|
handler.handle_special_commands("/use abc123"),
|
||||||
Some(InputCommand::Use("abc123".to_string()))
|
Some(InputCommand::Use("abc123".to_string()))
|
||||||
@ -121,8 +142,14 @@ mod tests {
|
|||||||
handler.handle_special_commands("/rename project alpha"),
|
handler.handle_special_commands("/rename project alpha"),
|
||||||
Some(InputCommand::Rename("project alpha".to_string()))
|
Some(InputCommand::Rename("project alpha".to_string()))
|
||||||
);
|
);
|
||||||
assert_eq!(handler.handle_special_commands("/archive"), Some(InputCommand::Archive));
|
assert_eq!(
|
||||||
assert_eq!(handler.handle_special_commands("/delete"), Some(InputCommand::Delete));
|
handler.handle_special_commands("/archive"),
|
||||||
|
Some(InputCommand::Archive)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/delete"),
|
||||||
|
Some(InputCommand::Delete)
|
||||||
|
);
|
||||||
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
||||||
assert_eq!(handler.handle_special_commands("/use"), None);
|
assert_eq!(handler.handle_special_commands("/use"), None);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,7 +5,10 @@ use tokio_tungstenite::{connect_async, tungstenite::Message};
|
|||||||
|
|
||||||
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
||||||
|
|
||||||
fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_session_id: Option<&str>) -> String {
|
fn format_session_list(
|
||||||
|
sessions: &[crate::protocol::SessionSummary],
|
||||||
|
current_session_id: Option<&str>,
|
||||||
|
) -> String {
|
||||||
if sessions.is_empty() {
|
if sessions.is_empty() {
|
||||||
return "No sessions found.".to_string();
|
return "No sessions found.".to_string();
|
||||||
}
|
}
|
||||||
@ -25,11 +28,7 @@ fn format_session_list(sessions: &[crate::protocol::SessionSummary], current_ses
|
|||||||
};
|
};
|
||||||
lines.push(format!(
|
lines.push(format!(
|
||||||
"{} {} | {} | {} messages{}",
|
"{} {} | {} | {} messages{}",
|
||||||
marker,
|
marker, session.session_id, session.title, session.message_count, archived,
|
||||||
session.session_id,
|
|
||||||
session.title,
|
|
||||||
session.message_count,
|
|
||||||
archived,
|
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -123,7 +123,9 @@ fn default_allow_from() -> Vec<String> {
|
|||||||
|
|
||||||
fn default_media_dir() -> String {
|
fn default_media_dir() -> String {
|
||||||
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
|
let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||||
home.join(".picobot/media/feishu").to_string_lossy().to_string()
|
home.join(".picobot/media/feishu")
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_reaction_emoji() -> String {
|
fn default_reaction_emoji() -> String {
|
||||||
@ -157,6 +159,8 @@ pub struct ModelConfig {
|
|||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub context_window_tokens: Option<u32>,
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub extra: HashMap<String, serde_json::Value>,
|
pub extra: HashMap<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
@ -199,7 +203,10 @@ pub struct GatewayConfig {
|
|||||||
pub show_tool_results: bool,
|
pub show_tool_results: bool,
|
||||||
#[serde(default, rename = "session_ttl_hours")]
|
#[serde(default, rename = "session_ttl_hours")]
|
||||||
pub session_ttl_hours: Option<u64>,
|
pub session_ttl_hours: Option<u64>,
|
||||||
#[serde(default = "default_agent_prompt_reinject_every", rename = "agent_prompt_reinject_every")]
|
#[serde(
|
||||||
|
default = "default_agent_prompt_reinject_every",
|
||||||
|
rename = "agent_prompt_reinject_every"
|
||||||
|
)]
|
||||||
pub agent_prompt_reinject_every: u64,
|
pub agent_prompt_reinject_every: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -388,7 +395,10 @@ impl SchedulerSchedule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_one_shot(&self) -> bool {
|
pub fn is_one_shot(&self) -> bool {
|
||||||
matches!(self, SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. })
|
matches!(
|
||||||
|
self,
|
||||||
|
SchedulerSchedule::Delay { .. } | SchedulerSchedule::At { .. }
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn normalized_for_storage(&self) -> Self {
|
pub fn normalized_for_storage(&self) -> Self {
|
||||||
@ -518,6 +528,7 @@ pub struct LLMProviderConfig {
|
|||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
pub context_window_tokens: Option<u32>,
|
||||||
pub model_extra: HashMap<String, serde_json::Value>,
|
pub model_extra: HashMap<String, serde_json::Value>,
|
||||||
pub max_tool_iterations: usize,
|
pub max_tool_iterations: usize,
|
||||||
pub tool_result_max_chars: usize,
|
pub tool_result_max_chars: usize,
|
||||||
@ -526,7 +537,7 @@ pub struct LLMProviderConfig {
|
|||||||
|
|
||||||
impl LLMProviderConfig {
|
impl LLMProviderConfig {
|
||||||
pub fn context_window_tokens(&self) -> usize {
|
pub fn context_window_tokens(&self) -> usize {
|
||||||
self.max_tokens
|
self.context_window_tokens
|
||||||
.map(|value| value as usize)
|
.map(|value| value as usize)
|
||||||
.unwrap_or(128_000)
|
.unwrap_or(128_000)
|
||||||
}
|
}
|
||||||
@ -581,13 +592,19 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
pub fn get_provider_config(&self, agent_name: &str) -> Result<LLMProviderConfig, ConfigError> {
|
||||||
let agent = self.agents.get(agent_name)
|
let agent = self
|
||||||
|
.agents
|
||||||
|
.get(agent_name)
|
||||||
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
.ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?;
|
||||||
|
|
||||||
let provider = self.providers.get(&agent.provider)
|
let provider = self
|
||||||
|
.providers
|
||||||
|
.get(&agent.provider)
|
||||||
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
.ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?;
|
||||||
|
|
||||||
let model = self.models.get(&agent.model)
|
let model = self
|
||||||
|
.models
|
||||||
|
.get(&agent.model)
|
||||||
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
.ok_or(ConfigError::ModelNotFound(agent.model.clone()))?;
|
||||||
|
|
||||||
Ok(LLMProviderConfig {
|
Ok(LLMProviderConfig {
|
||||||
@ -600,6 +617,7 @@ impl Config {
|
|||||||
model_id: model.model_id.clone(),
|
model_id: model.model_id.clone(),
|
||||||
temperature: model.temperature,
|
temperature: model.temperature,
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
|
context_window_tokens: model.context_window_tokens,
|
||||||
model_extra: model.extra.clone(),
|
model_extra: model.extra.clone(),
|
||||||
max_tool_iterations: agent.max_tool_iterations,
|
max_tool_iterations: agent.max_tool_iterations,
|
||||||
tool_result_max_chars: agent.tool_result_max_chars,
|
tool_result_max_chars: agent.tool_result_max_chars,
|
||||||
@ -621,11 +639,17 @@ pub enum ConfigError {
|
|||||||
impl std::fmt::Display for ConfigError {
|
impl std::fmt::Display for ConfigError {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ConfigError::ConfigNotFound(path) => write!(f, "Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json", path),
|
ConfigError::ConfigNotFound(path) => write!(
|
||||||
|
f,
|
||||||
|
"Config file not found: {}. Use CONFIG_PATH env var or place config in ~/.picobot/config.json",
|
||||||
|
path
|
||||||
|
),
|
||||||
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
ConfigError::AgentNotFound(name) => write!(f, "Agent not found: {}", name),
|
||||||
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
ConfigError::ProviderNotFound(name) => write!(f, "Provider not found: {}", name),
|
||||||
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
ConfigError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
|
||||||
ConfigError::InvalidSchedulerJob(message) => write!(f, "Invalid scheduler job: {}", message),
|
ConfigError::InvalidSchedulerJob(message) => {
|
||||||
|
write!(f, "Invalid scheduler job: {}", message)
|
||||||
|
}
|
||||||
ConfigError::InvalidTimezone(message) => write!(f, "Invalid timezone: {}", message),
|
ConfigError::InvalidTimezone(message) => write!(f, "Invalid timezone: {}", message),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -661,18 +685,19 @@ fn resolve_env_placeholders(content: &str) -> String {
|
|||||||
re.replace_all(content, |caps: ®ex::Captures| {
|
re.replace_all(content, |caps: ®ex::Captures| {
|
||||||
let var_name = &caps[1];
|
let var_name = &caps[1];
|
||||||
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
env::var(var_name).unwrap_or_else(|_| caps[0].to_string())
|
||||||
}).to_string()
|
})
|
||||||
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn write_test_config() -> tempfile::NamedTempFile {
|
fn write_test_config() -> tempfile::NamedTempFile {
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
std::fs::write(
|
std::fs::write(
|
||||||
file.path(),
|
file.path(),
|
||||||
r#"{
|
r#"{
|
||||||
"providers": {
|
"providers": {
|
||||||
"aliyun": {
|
"aliyun": {
|
||||||
"type": "openai",
|
"type": "openai",
|
||||||
@ -708,15 +733,15 @@ mod tests {
|
|||||||
"agent_prompt_reinject_every": 120
|
"agent_prompt_reinject_every": 120
|
||||||
}
|
}
|
||||||
}"#,
|
}"#,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
file
|
file
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_config_load() {
|
fn test_config_load() {
|
||||||
let file = write_test_config();
|
let file = write_test_config();
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
|
||||||
// Check providers
|
// Check providers
|
||||||
assert!(config.providers.contains_key("volcengine"));
|
assert!(config.providers.contains_key("volcengine"));
|
||||||
@ -876,7 +901,10 @@ mod tests {
|
|||||||
|
|
||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
assert_eq!(config.time.timezone, "Asia/Shanghai");
|
assert_eq!(config.time.timezone, "Asia/Shanghai");
|
||||||
assert_eq!(config.time.parse_timezone().unwrap(), chrono_tz::Asia::Shanghai);
|
assert_eq!(
|
||||||
|
config.time.parse_timezone().unwrap(),
|
||||||
|
chrono_tz::Asia::Shanghai
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -983,7 +1011,10 @@ mod tests {
|
|||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
assert_eq!(config.agents["default"].max_tool_iterations, 100);
|
assert_eq!(config.agents["default"].max_tool_iterations, 100);
|
||||||
assert_eq!(config.agents["default"].tool_result_max_chars, 20_000);
|
assert_eq!(config.agents["default"].tool_result_max_chars, 20_000);
|
||||||
assert_eq!(config.agents["default"].context_tool_result_trim_chars, 2_000);
|
assert_eq!(
|
||||||
|
config.agents["default"].context_tool_result_trim_chars,
|
||||||
|
2_000
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1029,7 +1060,44 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_provider_config_summary_budget_scales_with_model_max_tokens() {
|
fn test_provider_config_summary_budget_scales_with_context_window_tokens() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus",
|
||||||
|
"context_window_tokens": 4096
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(provider_config.context_window_tokens(), 4096);
|
||||||
|
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_config_max_tokens_does_not_change_context_window() {
|
||||||
let file = tempfile::NamedTempFile::new().unwrap();
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
std::fs::write(
|
std::fs::write(
|
||||||
file.path(),
|
file.path(),
|
||||||
@ -1061,8 +1129,9 @@ mod tests {
|
|||||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
let provider_config = config.get_provider_config("default").unwrap();
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
assert_eq!(provider_config.context_window_tokens(), 4096);
|
assert_eq!(provider_config.max_tokens, Some(4096));
|
||||||
assert_eq!(provider_config.context_summary_char_budget(), 1_500);
|
assert_eq!(provider_config.context_window_tokens(), 128_000);
|
||||||
|
assert_eq!(provider_config.context_summary_char_budget(), 32_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1159,7 +1228,10 @@ mod tests {
|
|||||||
assert!(config.scheduler.enabled);
|
assert!(config.scheduler.enabled);
|
||||||
assert_eq!(config.scheduler.tick_resolution_ms, 1_000);
|
assert_eq!(config.scheduler.tick_resolution_ms, 1_000);
|
||||||
assert_eq!(config.scheduler.worker_queue_capacity, 64);
|
assert_eq!(config.scheduler.worker_queue_capacity, 64);
|
||||||
assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::Skip);
|
assert_eq!(
|
||||||
|
config.scheduler.misfire_policy,
|
||||||
|
SchedulerMisfirePolicy::Skip
|
||||||
|
);
|
||||||
assert!(config.scheduler.jobs.is_empty());
|
assert!(config.scheduler.jobs.is_empty());
|
||||||
|
|
||||||
let effective_jobs = config.scheduler.effective_jobs(&config.time);
|
let effective_jobs = config.scheduler.effective_jobs(&config.time);
|
||||||
@ -1273,7 +1345,10 @@ mod tests {
|
|||||||
assert!(config.scheduler.enabled);
|
assert!(config.scheduler.enabled);
|
||||||
assert_eq!(config.scheduler.tick_resolution_ms, 500);
|
assert_eq!(config.scheduler.tick_resolution_ms, 500);
|
||||||
assert_eq!(config.scheduler.worker_queue_capacity, 8);
|
assert_eq!(config.scheduler.worker_queue_capacity, 8);
|
||||||
assert_eq!(config.scheduler.misfire_policy, SchedulerMisfirePolicy::CatchUp);
|
assert_eq!(
|
||||||
|
config.scheduler.misfire_policy,
|
||||||
|
SchedulerMisfirePolicy::CatchUp
|
||||||
|
);
|
||||||
assert_eq!(config.scheduler.jobs.len(), 1);
|
assert_eq!(config.scheduler.jobs.len(), 1);
|
||||||
|
|
||||||
let job = &config.scheduler.jobs[0];
|
let job = &config.scheduler.jobs[0];
|
||||||
@ -1284,11 +1359,17 @@ mod tests {
|
|||||||
assert_eq!(job.startup_delay_secs, 5);
|
assert_eq!(job.startup_delay_secs, 5);
|
||||||
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
|
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
|
||||||
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
|
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
|
||||||
assert_eq!(job.payload.get("content").and_then(|value| value.as_str()), Some("heartbeat"));
|
assert_eq!(
|
||||||
assert_eq!(job.resolved_schedule().unwrap(), SchedulerSchedule::Interval {
|
job.payload.get("content").and_then(|value| value.as_str()),
|
||||||
seconds: 60,
|
Some("heartbeat")
|
||||||
startup_delay_secs: 5,
|
);
|
||||||
});
|
assert_eq!(
|
||||||
|
job.resolved_schedule().unwrap(),
|
||||||
|
SchedulerSchedule::Interval {
|
||||||
|
seconds: 60,
|
||||||
|
startup_delay_secs: 5,
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1362,21 +1443,30 @@ mod tests {
|
|||||||
config.scheduler.jobs[0].resolved_schedule().unwrap(),
|
config.scheduler.jobs[0].resolved_schedule().unwrap(),
|
||||||
SchedulerSchedule::Delay { seconds: 30 }
|
SchedulerSchedule::Delay { seconds: 30 }
|
||||||
);
|
);
|
||||||
assert_eq!(config.scheduler.jobs[0].kind, SchedulerJobKind::InternalEvent);
|
assert_eq!(
|
||||||
|
config.scheduler.jobs[0].kind,
|
||||||
|
SchedulerJobKind::InternalEvent
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
config.scheduler.jobs[1].resolved_schedule().unwrap(),
|
config.scheduler.jobs[1].resolved_schedule().unwrap(),
|
||||||
SchedulerSchedule::At {
|
SchedulerSchedule::At {
|
||||||
timestamp: "2026-04-23T09:00:00+00:00".to_string(),
|
timestamp: "2026-04-23T09:00:00+00:00".to_string(),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
assert_eq!(config.scheduler.jobs[1].kind, SchedulerJobKind::OutboundMessage);
|
assert_eq!(
|
||||||
|
config.scheduler.jobs[1].kind,
|
||||||
|
SchedulerJobKind::OutboundMessage
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
config.scheduler.jobs[2].resolved_schedule().unwrap(),
|
config.scheduler.jobs[2].resolved_schedule().unwrap(),
|
||||||
SchedulerSchedule::Cron {
|
SchedulerSchedule::Cron {
|
||||||
expression: "0 9 * * *".to_string(),
|
expression: "0 9 * * *".to_string(),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
assert_eq!(config.scheduler.jobs[2].kind, SchedulerJobKind::InternalEvent);
|
assert_eq!(
|
||||||
|
config.scheduler.jobs[2].kind,
|
||||||
|
SchedulerJobKind::InternalEvent
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1433,7 +1523,10 @@ mod tests {
|
|||||||
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
|
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
|
||||||
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
|
assert_eq!(job.target.channel.as_deref(), Some("feishu"));
|
||||||
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
|
assert_eq!(job.target.chat_id.as_deref(), Some("oc_demo"));
|
||||||
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办"));
|
assert_eq!(
|
||||||
|
job.payload.get("prompt").and_then(|value| value.as_str()),
|
||||||
|
Some("请总结今天待办")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1495,29 +1588,40 @@ mod tests {
|
|||||||
job.target.session_chat_id.as_deref(),
|
job.target.session_chat_id.as_deref(),
|
||||||
Some("scheduler/agent.daily_summary.background")
|
Some("scheduler/agent.daily_summary.background")
|
||||||
);
|
);
|
||||||
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请后台总结今天待办"));
|
assert_eq!(
|
||||||
|
job.payload.get("prompt").and_then(|value| value.as_str()),
|
||||||
|
Some("请后台总结今天待办")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_scheduler_schedule_validation_rejects_invalid_values() {
|
fn test_scheduler_schedule_validation_rejects_invalid_values() {
|
||||||
assert!(SchedulerSchedule::Delay { seconds: 0 }
|
assert!(
|
||||||
.validate("delay.job")
|
SchedulerSchedule::Delay { seconds: 0 }
|
||||||
.is_err());
|
.validate("delay.job")
|
||||||
assert!(SchedulerSchedule::Interval {
|
.is_err()
|
||||||
seconds: 0,
|
);
|
||||||
startup_delay_secs: 0,
|
assert!(
|
||||||
}
|
SchedulerSchedule::Interval {
|
||||||
.validate("interval.job")
|
seconds: 0,
|
||||||
.is_err());
|
startup_delay_secs: 0,
|
||||||
assert!(SchedulerSchedule::At {
|
}
|
||||||
timestamp: "bad timestamp".to_string(),
|
.validate("interval.job")
|
||||||
}
|
.is_err()
|
||||||
.validate("at.job")
|
);
|
||||||
.is_err());
|
assert!(
|
||||||
assert!(SchedulerSchedule::Cron {
|
SchedulerSchedule::At {
|
||||||
expression: "bad cron".to_string(),
|
timestamp: "bad timestamp".to_string(),
|
||||||
}
|
}
|
||||||
.validate("cron.job")
|
.validate("at.job")
|
||||||
.is_err());
|
.is_err()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
SchedulerSchedule::Cron {
|
||||||
|
expression: "bad cron".to_string(),
|
||||||
|
}
|
||||||
|
.validate("cron.job")
|
||||||
|
.is_err()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
252
src/gateway/execution.rs
Normal file
252
src/gateway/execution.rs
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::agent::{AgentError, AgentProcessResult};
|
||||||
|
use crate::bus::message::ToolMessageState;
|
||||||
|
use crate::bus::{ChatMessage, OutboundMessage};
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use super::session::{Session, schedule_background_history_compaction};
|
||||||
|
|
||||||
|
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||||
|
|
||||||
|
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
|
||||||
|
match system_prompt
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
{
|
||||||
|
Some(system_prompt) => format!(
|
||||||
|
"{}\n\n任务专属要求:{}",
|
||||||
|
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
|
||||||
|
),
|
||||||
|
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn select_provider_config(
|
||||||
|
default_provider_config: &LLMProviderConfig,
|
||||||
|
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||||
|
agent_name: Option<&str>,
|
||||||
|
) -> Result<LLMProviderConfig, AgentError> {
|
||||||
|
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||||
|
None | Some("default") => Ok(default_provider_config.clone()),
|
||||||
|
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
||||||
|
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct AgentExecutionService {
|
||||||
|
show_tool_results: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct FinalizeAgentResultRequest<'a> {
|
||||||
|
pub(crate) channel_name: &'a str,
|
||||||
|
pub(crate) chat_id: &'a str,
|
||||||
|
pub(crate) user_message: &'a ChatMessage,
|
||||||
|
pub(crate) result: AgentProcessResult,
|
||||||
|
pub(crate) metadata: &'a HashMap<String, String>,
|
||||||
|
pub(crate) suppress_live_tool_calls: bool,
|
||||||
|
pub(crate) execution_kind: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct FinalizedAgentResult {
|
||||||
|
pub(crate) outbound_messages: Vec<OutboundMessage>,
|
||||||
|
pub(crate) should_schedule_compaction: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentExecutionService {
|
||||||
|
pub(crate) fn new(show_tool_results: bool) -> Self {
|
||||||
|
Self { show_tool_results }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn finalize_result(
|
||||||
|
&self,
|
||||||
|
session: &mut Session,
|
||||||
|
request: FinalizeAgentResultRequest<'_>,
|
||||||
|
) -> Result<FinalizedAgentResult, AgentError> {
|
||||||
|
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
|
||||||
|
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||||
|
session.stale_result_diagnostics(request.chat_id);
|
||||||
|
tracing::warn!(
|
||||||
|
channel = %request.channel_name,
|
||||||
|
chat_id = %request.chat_id,
|
||||||
|
user_message_id = %request.user_message.id,
|
||||||
|
latest_user_id,
|
||||||
|
latest_user_preview,
|
||||||
|
compression_in_flight,
|
||||||
|
history_len,
|
||||||
|
execution_kind = %request.execution_kind,
|
||||||
|
"Skipping stale agent result because a newer user message is already present"
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok(FinalizedAgentResult {
|
||||||
|
outbound_messages: Vec::new(),
|
||||||
|
should_schedule_compaction: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
session
|
||||||
|
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
|
||||||
|
|
||||||
|
let outbound_messages = request
|
||||||
|
.result
|
||||||
|
.emitted_messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| {
|
||||||
|
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
|
||||||
|
&& should_display_message_to_user(self.show_tool_results, message)
|
||||||
|
})
|
||||||
|
.flat_map(|message| {
|
||||||
|
OutboundMessage::from_chat_message(
|
||||||
|
request.channel_name,
|
||||||
|
request.chat_id,
|
||||||
|
None,
|
||||||
|
request.metadata,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(FinalizedAgentResult {
|
||||||
|
outbound_messages,
|
||||||
|
should_schedule_compaction: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn finalize_result_and_schedule_compaction(
|
||||||
|
&self,
|
||||||
|
session: Arc<Mutex<Session>>,
|
||||||
|
request: FinalizeAgentResultRequest<'_>,
|
||||||
|
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||||
|
let channel_name = request.channel_name.to_string();
|
||||||
|
let chat_id = request.chat_id.to_string();
|
||||||
|
let execution_kind = request.execution_kind.to_string();
|
||||||
|
|
||||||
|
let finalized_result = {
|
||||||
|
let mut session_guard = session.lock().await;
|
||||||
|
self.finalize_result(&mut session_guard, request)?
|
||||||
|
};
|
||||||
|
|
||||||
|
if finalized_result.should_schedule_compaction {
|
||||||
|
if let Err(error) =
|
||||||
|
schedule_background_history_compaction(session.clone(), chat_id.clone()).await
|
||||||
|
{
|
||||||
|
tracing::warn!(
|
||||||
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
execution_kind = %execution_kind,
|
||||||
|
error = %error,
|
||||||
|
"Failed to schedule background history compaction"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(finalized_result.outbound_messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn should_display_message_to_user(
|
||||||
|
show_tool_results: bool,
|
||||||
|
message: &ChatMessage,
|
||||||
|
) -> bool {
|
||||||
|
if message.role != "tool" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
show_tool_results
|
||||||
|
|| matches!(
|
||||||
|
message
|
||||||
|
.tool_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&ToolMessageState::Completed),
|
||||||
|
ToolMessageState::PendingUserAction
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
|
||||||
|
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||||
|
LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
tool_result_max_chars: 20_000,
|
||||||
|
context_tool_result_trim_chars: 20_000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_provider_config_uses_named_agent_override() {
|
||||||
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
|
let provider_configs = HashMap::from([(
|
||||||
|
"planner".to_string(),
|
||||||
|
test_provider_config_named("planner-provider", "planner-model"),
|
||||||
|
)]);
|
||||||
|
|
||||||
|
let selected =
|
||||||
|
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||||
|
assert_eq!(selected.name, "planner-provider");
|
||||||
|
assert_eq!(selected.model_id, "planner-model");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_provider_config_falls_back_to_default() {
|
||||||
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
|
let provider_configs = HashMap::new();
|
||||||
|
|
||||||
|
let selected =
|
||||||
|
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||||
|
assert_eq!(selected.name, "default-provider");
|
||||||
|
assert_eq!(selected.model_id, "default-model");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
|
||||||
|
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
|
||||||
|
|
||||||
|
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||||
|
assert!(prompt.contains("任务专属要求:只汇报异常"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
|
||||||
|
let prompt = compose_scheduled_task_system_prompt(Some(" "));
|
||||||
|
|
||||||
|
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||||
|
assert!(!prompt.contains("任务专属要求"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
|
||||||
|
let message = ChatMessage::tool_with_state(
|
||||||
|
"call-1",
|
||||||
|
"approval",
|
||||||
|
"需要用户确认",
|
||||||
|
ToolMessageState::PendingUserAction,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(should_display_message_to_user(false, &message));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
|
||||||
|
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
|
|
||||||
|
assert!(!should_display_message_to_user(false, &message));
|
||||||
|
assert!(should_display_message_to_user(true, &message));
|
||||||
|
}
|
||||||
|
}
|
||||||
517
src/gateway/memory_maintenance.rs
Normal file
517
src/gateway/memory_maintenance.rs
Normal file
@ -0,0 +1,517 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::agent::AgentError;
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use crate::providers::{ChatCompletionRequest, Message, create_provider};
|
||||||
|
use crate::storage::{MemoryRecord, SessionStore};
|
||||||
|
|
||||||
|
use super::prompt::upsert_managed_agent_memory_summary;
|
||||||
|
|
||||||
|
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||||
|
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum MemoryMaintenanceCategory {
|
||||||
|
UserFacts,
|
||||||
|
Preferences,
|
||||||
|
BehaviorPatterns,
|
||||||
|
Other,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceCandidate {
|
||||||
|
pub(crate) id: String,
|
||||||
|
pub(crate) namespace: String,
|
||||||
|
pub(crate) key: String,
|
||||||
|
pub(crate) content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenancePlan {
|
||||||
|
pub(crate) user_facts: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
pub(crate) preferences: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
pub(crate) behavior_patterns: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
pub(crate) others: Vec<MemoryMaintenanceCandidate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceMerge {
|
||||||
|
pub(crate) source_ids: Vec<String>,
|
||||||
|
pub(crate) namespace: String,
|
||||||
|
pub(crate) memory_key: String,
|
||||||
|
pub(crate) content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceConflict {
|
||||||
|
pub(crate) source_ids: Vec<String>,
|
||||||
|
pub(crate) note: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceModelOutput {
|
||||||
|
pub(crate) user_facts: Vec<String>,
|
||||||
|
pub(crate) preferences: Vec<String>,
|
||||||
|
pub(crate) behavior_patterns: Vec<String>,
|
||||||
|
pub(crate) merges: Vec<MemoryMaintenanceMerge>,
|
||||||
|
pub(crate) conflicts: Vec<MemoryMaintenanceConflict>,
|
||||||
|
pub(crate) low_value_ids: Vec<String>,
|
||||||
|
pub(crate) managed_markdown: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) struct MemoryMaintenanceScopeResult {
|
||||||
|
pub(crate) scope_key: String,
|
||||||
|
pub(crate) output: MemoryMaintenanceModelOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct MemoryMaintenanceService {
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryMaintenanceService {
|
||||||
|
pub(crate) fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
provider_config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_plan_for_scope(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<MemoryMaintenancePlan>, AgentError> {
|
||||||
|
let memories = self
|
||||||
|
.store
|
||||||
|
.list_memories_for_scope("user", scope_key)
|
||||||
|
.map_err(|err| AgentError::Other(format!("list memories for scope error: {}", err)))?;
|
||||||
|
|
||||||
|
if memories.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(build_memory_maintenance_plan(&memories)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn summarize_for_scope(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||||
|
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
self.summarize_plan(scope_key, &plan).await.map(Some)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn summarize_plan(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
plan: &MemoryMaintenancePlan,
|
||||||
|
) -> Result<MemoryMaintenanceModelOutput, AgentError> {
|
||||||
|
let provider = create_provider(self.provider_config.clone()).map_err(|err| {
|
||||||
|
AgentError::Other(format!("create maintenance provider error: {}", err))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![
|
||||||
|
Message::system(MEMORY_MAINTENANCE_SYSTEM_PROMPT),
|
||||||
|
Message::user(
|
||||||
|
serde_json::to_string_pretty(&serde_json::json!({
|
||||||
|
"scope_key": scope_key,
|
||||||
|
"candidates": plan,
|
||||||
|
}))
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(1200),
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut last_error = None;
|
||||||
|
let mut response = None;
|
||||||
|
|
||||||
|
for (attempt, delay_ms) in MEMORY_MAINTENANCE_RETRY_DELAYS_MS
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.map(Some)
|
||||||
|
.chain(std::iter::once(None))
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
match provider.chat(request.clone()).await {
|
||||||
|
Ok(success) => {
|
||||||
|
response = Some(success);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
let error_text = err.to_string();
|
||||||
|
let should_retry =
|
||||||
|
delay_ms.is_some() && is_recoverable_maintenance_llm_error(&error_text);
|
||||||
|
last_error = Some(error_text.clone());
|
||||||
|
|
||||||
|
if should_retry {
|
||||||
|
tracing::warn!(
|
||||||
|
scope_key = %scope_key,
|
||||||
|
attempt = attempt + 1,
|
||||||
|
retry_in_ms = delay_ms.unwrap_or_default(),
|
||||||
|
error = %error_text,
|
||||||
|
"Memory maintenance model request failed, retrying"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_millis(delay_ms.unwrap_or_default()))
|
||||||
|
.await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(AgentError::Other(format!(
|
||||||
|
"memory maintenance model error: {}",
|
||||||
|
error_text
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = response.ok_or_else(|| {
|
||||||
|
AgentError::Other(format!(
|
||||||
|
"memory maintenance model error: {}",
|
||||||
|
last_error.unwrap_or_else(|| "unknown provider error".to_string())
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let raw_content = strip_json_code_fence(&response.content);
|
||||||
|
let json_candidate = extract_json_object(raw_content).unwrap_or(raw_content);
|
||||||
|
|
||||||
|
let output: MemoryMaintenanceModelOutput =
|
||||||
|
serde_json::from_str(json_candidate).map_err(|err| {
|
||||||
|
tracing::error!(
|
||||||
|
scope_key = %scope_key,
|
||||||
|
error = %err,
|
||||||
|
raw_len = raw_content.len(),
|
||||||
|
raw_preview = %preview_text(raw_content, 400),
|
||||||
|
json_candidate_len = json_candidate.len(),
|
||||||
|
json_candidate_preview = %preview_text(json_candidate, 400),
|
||||||
|
"Memory maintenance JSON decode failed"
|
||||||
|
);
|
||||||
|
AgentError::Other(format!("memory maintenance JSON decode error: {}", err))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn run_for_scope(
|
||||||
|
&self,
|
||||||
|
scope_key: &str,
|
||||||
|
) -> Result<Option<MemoryMaintenanceModelOutput>, AgentError> {
|
||||||
|
let Some(plan) = self.build_plan_for_scope(scope_key)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = self.summarize_plan(scope_key, &plan).await?;
|
||||||
|
apply_memory_maintenance_output(self.store.as_ref(), scope_key, &plan, &output)?;
|
||||||
|
|
||||||
|
Ok(Some(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn run_for_all_scopes(
|
||||||
|
&self,
|
||||||
|
updated_since: Option<i64>,
|
||||||
|
) -> Result<Vec<MemoryMaintenanceScopeResult>, AgentError> {
|
||||||
|
let scope_keys = if let Some(cutoff) = updated_since {
|
||||||
|
self.store
|
||||||
|
.list_memory_scope_keys_updated_since("user", cutoff)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!(
|
||||||
|
"list memory scope keys updated since error: {}",
|
||||||
|
err
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
} else {
|
||||||
|
self.store.list_memory_scope_keys("user").map_err(|err| {
|
||||||
|
AgentError::Other(format!("list memory scope keys error: {}", err))
|
||||||
|
})?
|
||||||
|
};
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
for scope_key in scope_keys {
|
||||||
|
let Some(output) = self.run_for_scope(&scope_key).await? else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
results.push(MemoryMaintenanceScopeResult { scope_key, output });
|
||||||
|
}
|
||||||
|
|
||||||
|
let combined_markdown = combine_managed_memory_markdown(
|
||||||
|
&results
|
||||||
|
.iter()
|
||||||
|
.map(|result| result.output.managed_markdown.clone())
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if !combined_markdown.is_empty() {
|
||||||
|
upsert_managed_agent_memory_summary(&combined_markdown)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn build_memory_maintenance_plan(memories: &[MemoryRecord]) -> MemoryMaintenancePlan {
|
||||||
|
let mut plan = MemoryMaintenancePlan::default();
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
|
for memory in memories {
|
||||||
|
let normalized_content = memory.content.trim();
|
||||||
|
if normalized_content.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let dedupe_key = format!(
|
||||||
|
"{}\u{1f}{}\u{1f}{}",
|
||||||
|
memory.namespace.trim().to_ascii_lowercase(),
|
||||||
|
memory.memory_key.trim().to_ascii_lowercase(),
|
||||||
|
normalized_content
|
||||||
|
);
|
||||||
|
if !seen.insert(dedupe_key) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let candidate = MemoryMaintenanceCandidate {
|
||||||
|
id: memory.id.clone(),
|
||||||
|
namespace: memory.namespace.clone(),
|
||||||
|
key: memory.memory_key.clone(),
|
||||||
|
content: normalized_content.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match memory_maintenance_category(&memory.namespace) {
|
||||||
|
MemoryMaintenanceCategory::UserFacts => plan.user_facts.push(candidate),
|
||||||
|
MemoryMaintenanceCategory::Preferences => plan.preferences.push(candidate),
|
||||||
|
MemoryMaintenanceCategory::BehaviorPatterns => plan.behavior_patterns.push(candidate),
|
||||||
|
MemoryMaintenanceCategory::Other => plan.others.push(candidate),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
plan
|
||||||
|
}
|
||||||
|
|
||||||
|
fn memory_maintenance_category(namespace: &str) -> MemoryMaintenanceCategory {
|
||||||
|
match namespace.trim().to_ascii_lowercase().as_str() {
|
||||||
|
"profile" | "facts" | "identity" => MemoryMaintenanceCategory::UserFacts,
|
||||||
|
"preferences" | "style" | "likes" => MemoryMaintenanceCategory::Preferences,
|
||||||
|
"patterns" | "behavior" | "habits" | "workflow" => {
|
||||||
|
MemoryMaintenanceCategory::BehaviorPatterns
|
||||||
|
}
|
||||||
|
_ => MemoryMaintenanceCategory::Other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|
||||||
|
let normalized = error.to_ascii_lowercase();
|
||||||
|
normalized.contains("error sending request for url")
|
||||||
|
|| normalized.contains("504")
|
||||||
|
|| normalized.contains("gateway timeout")
|
||||||
|
|| normalized.contains("stream timeout")
|
||||||
|
|| normalized.contains("timed out")
|
||||||
|
|| normalized.contains("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn strip_json_code_fence(content: &str) -> &str {
|
||||||
|
let trimmed = content.trim();
|
||||||
|
if let Some(rest) = trimmed.strip_prefix("```json") {
|
||||||
|
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
|
||||||
|
}
|
||||||
|
if let Some(rest) = trimmed.strip_prefix("```") {
|
||||||
|
return rest.strip_suffix("```").map(str::trim).unwrap_or(trimmed);
|
||||||
|
}
|
||||||
|
trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn extract_json_object(content: &str) -> Option<&str> {
|
||||||
|
let mut start = None;
|
||||||
|
let mut depth = 0usize;
|
||||||
|
let mut in_string = false;
|
||||||
|
let mut escaped = false;
|
||||||
|
|
||||||
|
for (index, ch) in content.char_indices() {
|
||||||
|
if in_string {
|
||||||
|
if escaped {
|
||||||
|
escaped = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match ch {
|
||||||
|
'\\' => escaped = true,
|
||||||
|
'"' => in_string = false,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match ch {
|
||||||
|
'"' => in_string = true,
|
||||||
|
'{' => {
|
||||||
|
if start.is_none() {
|
||||||
|
start = Some(index);
|
||||||
|
}
|
||||||
|
depth += 1;
|
||||||
|
}
|
||||||
|
'}' => {
|
||||||
|
if depth == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
let start = start?;
|
||||||
|
let end = index + ch.len_utf8();
|
||||||
|
return Some(content[start..end].trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn combine_managed_memory_markdown(chunks: &[String]) -> String {
|
||||||
|
let normalized_chunks = chunks
|
||||||
|
.iter()
|
||||||
|
.map(|chunk| chunk.trim())
|
||||||
|
.filter(|chunk| !chunk.is_empty())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let mut combined = Vec::new();
|
||||||
|
for (index, chunk) in normalized_chunks.iter().enumerate() {
|
||||||
|
let chunk_lines = chunk
|
||||||
|
.lines()
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|line| !line.is_empty())
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
let is_subset_of_other =
|
||||||
|
normalized_chunks
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.any(|(other_index, other)| {
|
||||||
|
if index == other_index {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let other_lines = other
|
||||||
|
.lines()
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|line| !line.is_empty())
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
chunk_lines.len() < other_lines.len() && chunk_lines.is_subset(&other_lines)
|
||||||
|
});
|
||||||
|
|
||||||
|
if !is_subset_of_other && !combined.iter().any(|existing: &String| existing == chunk) {
|
||||||
|
combined.push((*chunk).to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
combined.join("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn apply_memory_maintenance_output(
|
||||||
|
store: &SessionStore,
|
||||||
|
scope_key: &str,
|
||||||
|
plan: &MemoryMaintenancePlan,
|
||||||
|
output: &MemoryMaintenanceModelOutput,
|
||||||
|
) -> Result<(), AgentError> {
|
||||||
|
let all_candidates = plan
|
||||||
|
.user_facts
|
||||||
|
.iter()
|
||||||
|
.chain(plan.preferences.iter())
|
||||||
|
.chain(plan.behavior_patterns.iter())
|
||||||
|
.chain(plan.others.iter())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let candidates_by_id = all_candidates
|
||||||
|
.iter()
|
||||||
|
.map(|candidate| (candidate.id.as_str(), candidate))
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
|
||||||
|
let mut deleted_ids = HashSet::new();
|
||||||
|
|
||||||
|
for merge in &output.merges {
|
||||||
|
if merge.source_ids.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let source_candidates = merge
|
||||||
|
.source_ids
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| candidates_by_id.get(id.as_str()).copied())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
if source_candidates.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let existing_target_id = source_candidates
|
||||||
|
.iter()
|
||||||
|
.find(|candidate| {
|
||||||
|
candidate.namespace == merge.namespace && candidate.key == merge.memory_key
|
||||||
|
})
|
||||||
|
.map(|candidate| candidate.id.clone());
|
||||||
|
|
||||||
|
store
|
||||||
|
.put_memory(&crate::storage::MemoryUpsert {
|
||||||
|
scope_kind: "user".to_string(),
|
||||||
|
scope_key: scope_key.to_string(),
|
||||||
|
namespace: merge.namespace.trim().to_string(),
|
||||||
|
memory_key: merge.memory_key.trim().to_string(),
|
||||||
|
content: merge.content.trim().to_string(),
|
||||||
|
source_type: "memory_maintenance".to_string(),
|
||||||
|
source_session_id: None,
|
||||||
|
source_message_id: None,
|
||||||
|
source_message_seq: None,
|
||||||
|
source_channel_name: None,
|
||||||
|
source_chat_id: None,
|
||||||
|
})
|
||||||
|
.map_err(|err| AgentError::Other(format!("upsert merged memory error: {}", err)))?;
|
||||||
|
|
||||||
|
for candidate in source_candidates {
|
||||||
|
if existing_target_id
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|target_id| target_id == &candidate.id)
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if deleted_ids.insert(candidate.id.clone()) {
|
||||||
|
store
|
||||||
|
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("delete merged source memory error: {}", err))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for memory_id in &output.low_value_ids {
|
||||||
|
if let Some(candidate) = candidates_by_id.get(memory_id.as_str()) {
|
||||||
|
if deleted_ids.insert(candidate.id.clone()) {
|
||||||
|
store
|
||||||
|
.delete_memory("user", scope_key, &candidate.namespace, &candidate.key)
|
||||||
|
.map_err(|err| {
|
||||||
|
AgentError::Other(format!("delete low value memory error: {}", err))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn preview_text(content: &str, max_chars: usize) -> String {
|
||||||
|
let mut preview = content.chars().take(max_chars).collect::<String>();
|
||||||
|
if content.chars().count() > max_chars {
|
||||||
|
preview.push_str("...");
|
||||||
|
}
|
||||||
|
preview.replace('\n', "\\n")
|
||||||
|
}
|
||||||
@ -1,10 +1,14 @@
|
|||||||
|
pub mod execution;
|
||||||
pub mod http;
|
pub mod http;
|
||||||
|
pub mod memory_maintenance;
|
||||||
|
pub mod processor;
|
||||||
|
pub mod prompt;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
|
use axum::{Router, routing};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use axum::{routing, Router};
|
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundDispatcher};
|
use crate::bus::{MessageBus, OutboundDispatcher};
|
||||||
@ -14,7 +18,8 @@ use crate::config::LLMProviderConfig;
|
|||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::scheduler::Scheduler;
|
use crate::scheduler::Scheduler;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use session::{BusToolCallEmitter, SessionManager};
|
use processor::InboundProcessor;
|
||||||
|
use session::SessionManager;
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
@ -61,74 +66,17 @@ impl GatewayState {
|
|||||||
|
|
||||||
/// Start the message processing loops
|
/// Start the message processing loops
|
||||||
pub async fn start_message_processing(&self) {
|
pub async fn start_message_processing(&self) {
|
||||||
let bus_for_inbound = self.bus.clone();
|
|
||||||
let bus_for_outbound = self.bus.clone();
|
let bus_for_outbound = self.bus.clone();
|
||||||
let session_manager = self.session_manager.clone();
|
let inbound_processor =
|
||||||
|
InboundProcessor::new(self.bus.clone(), self.session_manager.clone());
|
||||||
// Spawn inbound message processor
|
tokio::spawn(inbound_processor.run());
|
||||||
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tracing::info!("Inbound processor started");
|
|
||||||
loop {
|
|
||||||
let inbound = bus_for_inbound.consume_inbound().await;
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
tracing::debug!(
|
|
||||||
channel = %inbound.channel,
|
|
||||||
chat_id = %inbound.chat_id,
|
|
||||||
sender = %inbound.sender_id,
|
|
||||||
content = %inbound.content,
|
|
||||||
media_count = %inbound.media.len(),
|
|
||||||
"Processing inbound message"
|
|
||||||
);
|
|
||||||
if !inbound.media.is_empty() {
|
|
||||||
for (i, m) in inbound.media.iter().enumerate() {
|
|
||||||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media item");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process via session manager
|
|
||||||
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
|
||||||
bus_for_inbound.clone(),
|
|
||||||
inbound.channel.clone(),
|
|
||||||
inbound.chat_id.clone(),
|
|
||||||
inbound.forwarded_metadata.clone(),
|
|
||||||
session_manager.show_tool_results(),
|
|
||||||
));
|
|
||||||
match session_manager.handle_message(
|
|
||||||
&inbound.channel,
|
|
||||||
&inbound.sender_id,
|
|
||||||
&inbound.chat_id,
|
|
||||||
&inbound.content,
|
|
||||||
inbound.media,
|
|
||||||
Some(live_emitter),
|
|
||||||
).await {
|
|
||||||
Ok(outbound_messages) => {
|
|
||||||
// Forward channel-specific metadata from inbound to outbound.
|
|
||||||
// This allows channels to propagate context (e.g. feishu message_id for reaction cleanup)
|
|
||||||
// without gateway needing channel-specific code.
|
|
||||||
for mut outbound in outbound_messages {
|
|
||||||
outbound.metadata.extend(inbound.forwarded_metadata.clone());
|
|
||||||
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
|
||||||
tracing::error!(error = %e, "Failed to publish outbound");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(error = %e, "Failed to handle message");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Spawn outbound dispatcher
|
// Spawn outbound dispatcher
|
||||||
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
|
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
|
||||||
let channel_manager = self.channel_manager.clone();
|
let channel_manager = self.channel_manager.clone();
|
||||||
|
|
||||||
// Register channels with dispatcher
|
for (name, channel) in channel_manager.channels().await {
|
||||||
if let Some(channel) = channel_manager.get_channel("feishu").await {
|
dispatcher.register_channel(&name, channel).await;
|
||||||
dispatcher.register_channel("feishu", channel).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
@ -138,7 +86,10 @@ impl GatewayState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(
|
||||||
|
host: Option<String>,
|
||||||
|
port: Option<u16>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let config = Config::load_default()?;
|
let config = Config::load_default()?;
|
||||||
let timezone = config.time.parse_timezone()?;
|
let timezone = config.time.parse_timezone()?;
|
||||||
|
|
||||||
@ -152,7 +103,10 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
|
|||||||
let provider_config = state.config.get_provider_config("default")?;
|
let provider_config = state.config.get_provider_config("default")?;
|
||||||
|
|
||||||
// Initialize and start channels
|
// Initialize and start channels
|
||||||
state.channel_manager.init(&state.config, provider_config.clone()).await?;
|
state
|
||||||
|
.channel_manager
|
||||||
|
.init(&state.config, provider_config.clone())
|
||||||
|
.await?;
|
||||||
state.channel_manager.start_all().await?;
|
state.channel_manager.start_all().await?;
|
||||||
|
|
||||||
// Start message processing (inbound processor + outbound dispatcher)
|
// Start message processing (inbound processor + outbound dispatcher)
|
||||||
|
|||||||
77
src/gateway/processor.rs
Normal file
77
src/gateway/processor.rs
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::bus::MessageBus;
|
||||||
|
|
||||||
|
use super::session::{BusToolCallEmitter, SessionManager};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct InboundProcessor {
|
||||||
|
bus: Arc<MessageBus>,
|
||||||
|
session_manager: SessionManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InboundProcessor {
|
||||||
|
pub fn new(bus: Arc<MessageBus>, session_manager: SessionManager) -> Self {
|
||||||
|
Self {
|
||||||
|
bus,
|
||||||
|
session_manager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(self) {
|
||||||
|
tracing::info!("Inbound processor started");
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let inbound = self.bus.consume_inbound().await;
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
tracing::debug!(
|
||||||
|
channel = %inbound.channel,
|
||||||
|
chat_id = %inbound.chat_id,
|
||||||
|
sender = %inbound.sender_id,
|
||||||
|
content = %inbound.content,
|
||||||
|
media_count = %inbound.media.len(),
|
||||||
|
"Processing inbound message"
|
||||||
|
);
|
||||||
|
if !inbound.media.is_empty() {
|
||||||
|
for (i, media) in inbound.media.iter().enumerate() {
|
||||||
|
tracing::debug!(media_index = i, media_type = %media.media_type, path = %media.path, "Media item");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
||||||
|
self.bus.clone(),
|
||||||
|
inbound.channel.clone(),
|
||||||
|
inbound.chat_id.clone(),
|
||||||
|
inbound.forwarded_metadata.clone(),
|
||||||
|
self.session_manager.show_tool_results(),
|
||||||
|
));
|
||||||
|
|
||||||
|
match self
|
||||||
|
.session_manager
|
||||||
|
.handle_message(
|
||||||
|
&inbound.channel,
|
||||||
|
&inbound.sender_id,
|
||||||
|
&inbound.chat_id,
|
||||||
|
&inbound.content,
|
||||||
|
inbound.media,
|
||||||
|
Some(live_emitter),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(outbound_messages) => {
|
||||||
|
for mut outbound in outbound_messages {
|
||||||
|
outbound.metadata.extend(inbound.forwarded_metadata.clone());
|
||||||
|
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||||||
|
tracing::error!(error = %error, "Failed to publish outbound");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
tracing::error!(error = %error, "Failed to handle message");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
149
src/gateway/prompt.rs
Normal file
149
src/gateway/prompt.rs
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use crate::agent::AgentError;
|
||||||
|
|
||||||
|
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||||
|
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
|
||||||
|
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
|
||||||
|
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
|
||||||
|
|
||||||
|
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
||||||
|
let path = agent_prompt_path()?;
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)
|
||||||
|
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !path.exists() {
|
||||||
|
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = fs::read_to_string(&path)
|
||||||
|
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
||||||
|
let trimmed = content.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(trimmed.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
|
||||||
|
let path = agent_prompt_path()?;
|
||||||
|
let existing = if path.exists() {
|
||||||
|
fs::read_to_string(&path)
|
||||||
|
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
|
||||||
|
} else {
|
||||||
|
DEFAULT_AGENT_PROMPT.to_string()
|
||||||
|
};
|
||||||
|
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
|
||||||
|
write_agent_prompt(&path, &updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
|
||||||
|
let managed_block = render_managed_agent_memory_block(markdown_body);
|
||||||
|
|
||||||
|
if let (Some(start), Some(end)) = (
|
||||||
|
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
|
||||||
|
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
|
||||||
|
) {
|
||||||
|
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
|
||||||
|
let mut updated = String::new();
|
||||||
|
updated.push_str(existing[..start].trim_end());
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(&managed_block);
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(existing[end..].trim_start());
|
||||||
|
return updated.trim().to_string() + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(reply_rules_index) = existing.find("## 回复规则") {
|
||||||
|
let mut updated = String::new();
|
||||||
|
updated.push_str(existing[..reply_rules_index].trim_end());
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(&managed_block);
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(existing[reply_rules_index..].trim_start());
|
||||||
|
return updated.trim().to_string() + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut updated = existing.trim_end().to_string();
|
||||||
|
if !updated.is_empty() {
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
}
|
||||||
|
updated.push_str(&managed_block);
|
||||||
|
updated.push('\n');
|
||||||
|
updated
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
|
||||||
|
markdown_body.trim()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)
|
||||||
|
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let temp_path = path.with_extension("md.tmp");
|
||||||
|
fs::write(&temp_path, content)
|
||||||
|
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
|
||||||
|
fs::rename(&temp_path, path)
|
||||||
|
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
|
||||||
|
let home = dirs::home_dir()
|
||||||
|
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
||||||
|
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
|
||||||
|
let original =
|
||||||
|
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
|
||||||
|
let updated = upsert_managed_agent_memory_block(
|
||||||
|
original,
|
||||||
|
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
|
||||||
|
);
|
||||||
|
|
||||||
|
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
|
||||||
|
let reply_rules_pos = updated.find("## 回复规则").unwrap();
|
||||||
|
assert!(managed_pos < reply_rules_pos);
|
||||||
|
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
|
||||||
|
assert!(updated.contains("用户在做AI产品"));
|
||||||
|
assert!(updated.contains("偏好简洁表达"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upsert_managed_agent_memory_block_replaces_existing_block() {
|
||||||
|
let original = format!(
|
||||||
|
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n"
|
||||||
|
);
|
||||||
|
|
||||||
|
let updated = upsert_managed_agent_memory_block(&original, "new");
|
||||||
|
|
||||||
|
assert!(updated.contains("new"));
|
||||||
|
assert!(!updated.contains("old"));
|
||||||
|
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1);
|
||||||
|
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upsert_managed_agent_memory_block_trims_summary_body() {
|
||||||
|
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n");
|
||||||
|
|
||||||
|
assert!(updated.contains("\n\nsummary\n"));
|
||||||
|
assert!(!updated.contains("\n\nsummary\n\n\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,15 +1,18 @@
|
|||||||
use std::sync::Arc;
|
use super::{
|
||||||
|
GatewayState,
|
||||||
|
session::{Session, handle_in_chat_command, schedule_background_history_compaction},
|
||||||
|
};
|
||||||
|
use crate::agent::EmittedMessageHandler;
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
use crate::bus::message::{ToolMessageState, format_tool_call_content};
|
||||||
|
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
|
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use std::sync::Arc;
|
||||||
use crate::agent::EmittedMessageHandler;
|
use tokio::sync::{Mutex, mpsc};
|
||||||
use crate::bus::message::{format_tool_call_content, ToolMessageState};
|
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
|
||||||
use super::{GatewayState, session::{Session, handle_in_chat_command, schedule_background_history_compaction}};
|
|
||||||
|
|
||||||
struct WsToolCallEmitter {
|
struct WsToolCallEmitter {
|
||||||
sender: mpsc::Sender<WsOutbound>,
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
@ -120,7 +123,9 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
&runtime_session_id,
|
&runtime_session_id,
|
||||||
&mut current_session_id,
|
&mut current_session_id,
|
||||||
inbound,
|
inbound,
|
||||||
).await {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
tracing::warn!(error = %e, session_id = %current_session_id, "Failed to handle inbound message");
|
||||||
let _ = session
|
let _ = session
|
||||||
.lock()
|
.lock()
|
||||||
@ -182,17 +187,14 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
outbound.extend(tool_calls
|
outbound.extend(tool_calls.iter().map(|tool_call| WsOutbound::ToolCall {
|
||||||
.iter()
|
id: message.id.clone(),
|
||||||
.map(|tool_call| WsOutbound::ToolCall {
|
tool_call_id: tool_call.id.clone(),
|
||||||
id: message.id.clone(),
|
tool_name: tool_call.name.clone(),
|
||||||
tool_call_id: tool_call.id.clone(),
|
arguments: tool_call.arguments.clone(),
|
||||||
tool_name: tool_call.name.clone(),
|
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||||
arguments: tool_call.arguments.clone(),
|
role: message.role.clone(),
|
||||||
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
}));
|
||||||
role: message.role.clone(),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
outbound
|
outbound
|
||||||
} else {
|
} else {
|
||||||
vec![WsOutbound::AssistantResponse {
|
vec![WsOutbound::AssistantResponse {
|
||||||
@ -202,7 +204,11 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
|
"tool" => match message
|
||||||
|
.tool_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&ToolMessageState::Completed)
|
||||||
|
{
|
||||||
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
|
||||||
id: message.id.clone(),
|
id: message.id.clone(),
|
||||||
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
@ -230,7 +236,10 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage
|
|||||||
|
|
||||||
show_tool_results
|
show_tool_results
|
||||||
|| matches!(
|
|| matches!(
|
||||||
message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed),
|
message
|
||||||
|
.tool_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&ToolMessageState::Completed),
|
||||||
ToolMessageState::PendingUserAction
|
ToolMessageState::PendingUserAction
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -243,7 +252,12 @@ async fn handle_inbound(
|
|||||||
inbound: WsInbound,
|
inbound: WsInbound,
|
||||||
) -> Result<(), crate::agent::AgentError> {
|
) -> Result<(), crate::agent::AgentError> {
|
||||||
match inbound {
|
match inbound {
|
||||||
WsInbound::UserInput { content, chat_id, sender_id, .. } => {
|
WsInbound::UserInput {
|
||||||
|
content,
|
||||||
|
chat_id,
|
||||||
|
sender_id,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
let chat_id = chat_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
let sender_id = resolve_ws_sender_id(sender_id.as_deref(), runtime_session_id);
|
||||||
let (history, agent, user_tx) = {
|
let (history, agent, user_tx) = {
|
||||||
@ -252,7 +266,9 @@ async fn handle_inbound(
|
|||||||
session_guard.ensure_persistent_session(&chat_id)?;
|
session_guard.ensure_persistent_session(&chat_id)?;
|
||||||
session_guard.ensure_chat_loaded(&chat_id)?;
|
session_guard.ensure_chat_loaded(&chat_id)?;
|
||||||
|
|
||||||
if let Some(command_response) = handle_in_chat_command(&mut session_guard, &chat_id, &content)? {
|
if let Some(command_response) =
|
||||||
|
handle_in_chat_command(&mut session_guard, &chat_id, &content)?
|
||||||
|
{
|
||||||
let _ = session_guard
|
let _ = session_guard
|
||||||
.send(WsOutbound::AssistantResponse {
|
.send(WsOutbound::AssistantResponse {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
@ -286,13 +302,17 @@ async fn handle_inbound(
|
|||||||
match agent.process(history).await {
|
match agent.process(history).await {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
session_guard
|
||||||
|
.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||||
for outbound in result
|
for outbound in result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|message| {
|
.filter(|message| {
|
||||||
!message.is_assistant_tool_call_message()
|
!message.is_assistant_tool_call_message()
|
||||||
&& should_display_message_to_user(state.config.gateway.show_tool_results, message)
|
&& should_display_message_to_user(
|
||||||
|
state.config.gateway.show_tool_results,
|
||||||
|
message,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.flat_map(ws_outbound_from_chat_message)
|
.flat_map(ws_outbound_from_chat_message)
|
||||||
{
|
{
|
||||||
@ -301,7 +321,10 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
drop(session_guard);
|
drop(session_guard);
|
||||||
|
|
||||||
if let Err(error) = schedule_background_history_compaction(session.clone(), chat_id.clone()).await {
|
if let Err(error) =
|
||||||
|
schedule_background_history_compaction(session.clone(), chat_id.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
|
tracing::warn!(chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for CLI session");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -318,16 +341,19 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
WsInbound::ClearHistory { session_id, chat_id } => {
|
WsInbound::ClearHistory {
|
||||||
let target = session_id.or(chat_id).unwrap_or_else(|| current_session_id.clone());
|
session_id,
|
||||||
|
chat_id,
|
||||||
|
} => {
|
||||||
|
let target = session_id
|
||||||
|
.or(chat_id)
|
||||||
|
.unwrap_or_else(|| current_session_id.clone());
|
||||||
state.session_manager.clear_session_messages(&target)?;
|
state.session_manager.clear_session_messages(&target)?;
|
||||||
|
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
session_guard.remove_history(&target);
|
session_guard.remove_history(&target);
|
||||||
let _ = session_guard
|
let _ = session_guard
|
||||||
.send(WsOutbound::HistoryCleared {
|
.send(WsOutbound::HistoryCleared { session_id: target })
|
||||||
session_id: target,
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -452,17 +478,15 @@ fn resolve_ws_sender_id(sender_id: Option<&str>, runtime_session_id: &str) -> St
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::agent::EmittedMessageHandler;
|
|
||||||
use super::{
|
use super::{
|
||||||
WsToolCallEmitter,
|
WsToolCallEmitter, resolve_ws_sender_id, should_display_message_to_user,
|
||||||
resolve_ws_sender_id,
|
|
||||||
should_display_message_to_user,
|
|
||||||
ws_outbound_from_chat_message,
|
ws_outbound_from_chat_message,
|
||||||
};
|
};
|
||||||
|
use crate::agent::EmittedMessageHandler;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::bus::message::ToolMessageState;
|
use crate::bus::message::ToolMessageState;
|
||||||
use crate::providers::ToolCall;
|
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
|
use crate::providers::ToolCall;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
@ -481,11 +505,17 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
match &outbound[0] {
|
match &outbound[0] {
|
||||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
|
WsOutbound::ToolCall {
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
assert_eq!(tool_call_id, "call-1");
|
assert_eq!(tool_call_id, "call-1");
|
||||||
assert_eq!(tool_name, "calculator");
|
assert_eq!(tool_name, "calculator");
|
||||||
assert_eq!(arguments["expression"], "1 + 1");
|
assert_eq!(arguments["expression"], "1 + 1");
|
||||||
assert_eq!(content, "### calculator\n- expression: 1 + 1");
|
assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}");
|
||||||
}
|
}
|
||||||
other => panic!("unexpected outbound variant: {:?}", other),
|
other => panic!("unexpected outbound variant: {:?}", other),
|
||||||
}
|
}
|
||||||
@ -551,8 +581,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
fn test_resolve_ws_sender_id_prefers_inbound_sender() {
|
||||||
assert_eq!(resolve_ws_sender_id(Some("user-42"), "runtime-1"), "user-42");
|
assert_eq!(
|
||||||
assert_eq!(resolve_ws_sender_id(Some(" user-42 "), "runtime-1"), "user-42");
|
resolve_ws_sender_id(Some("user-42"), "runtime-1"),
|
||||||
|
"user-42"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
resolve_ws_sender_id(Some(" user-42 "), "runtime-1"),
|
||||||
|
"user-42"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -573,8 +609,10 @@ mod tests {
|
|||||||
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
.handle(ChatMessage::tool("call-1", "calculator", "2"))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
assert!(tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
|
assert!(
|
||||||
.await
|
tokio::time::timeout(std::time::Duration::from_millis(50), receiver.recv())
|
||||||
.is_err());
|
.await
|
||||||
|
.is_err()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
20
src/lib.rs
20
src/lib.rs
@ -1,16 +1,16 @@
|
|||||||
pub mod config;
|
|
||||||
pub mod text;
|
|
||||||
pub mod providers;
|
|
||||||
pub mod bus;
|
|
||||||
pub mod cli;
|
|
||||||
pub mod agent;
|
pub mod agent;
|
||||||
pub mod gateway;
|
pub mod bus;
|
||||||
pub mod client;
|
|
||||||
pub mod protocol;
|
|
||||||
pub mod channels;
|
pub mod channels;
|
||||||
|
pub mod cli;
|
||||||
|
pub mod client;
|
||||||
|
pub mod config;
|
||||||
|
pub mod gateway;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod observability;
|
pub mod observability;
|
||||||
|
pub mod protocol;
|
||||||
|
pub mod providers;
|
||||||
pub mod scheduler;
|
pub mod scheduler;
|
||||||
pub mod storage;
|
|
||||||
pub mod tools;
|
|
||||||
pub mod skills;
|
pub mod skills;
|
||||||
|
pub mod storage;
|
||||||
|
pub mod text;
|
||||||
|
pub mod tools;
|
||||||
|
|||||||
@ -1,13 +1,9 @@
|
|||||||
use std::path::PathBuf;
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use chrono_tz::Tz;
|
use chrono_tz::Tz;
|
||||||
|
use std::path::PathBuf;
|
||||||
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
use tracing_appender::rolling::{RollingFileAppender, Rotation};
|
||||||
use tracing_subscriber::{
|
use tracing_subscriber::{
|
||||||
fmt,
|
EnvFilter, fmt, fmt::time::FormatTime, layer::SubscriberExt, util::SubscriberInitExt,
|
||||||
fmt::time::FormatTime,
|
|
||||||
layer::SubscriberExt,
|
|
||||||
util::SubscriberInitExt,
|
|
||||||
EnvFilter,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
@ -16,8 +12,17 @@ struct ConfiguredTimestamp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FormatTime for ConfiguredTimestamp {
|
impl FormatTime for ConfiguredTimestamp {
|
||||||
fn format_time(&self, writer: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result {
|
fn format_time(
|
||||||
write!(writer, "{}", Utc::now().with_timezone(&self.timezone).to_rfc3339_opts(chrono::SecondsFormat::Millis, true))
|
&self,
|
||||||
|
writer: &mut tracing_subscriber::fmt::format::Writer<'_>,
|
||||||
|
) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
writer,
|
||||||
|
"{}",
|
||||||
|
Utc::now()
|
||||||
|
.with_timezone(&self.timezone)
|
||||||
|
.to_rfc3339_opts(chrono::SecondsFormat::Millis, true)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,20 +46,19 @@ pub fn init_logging(timezone: Tz) {
|
|||||||
// Create log directory if it doesn't exist
|
// Create log directory if it doesn't exist
|
||||||
if !log_dir.exists() {
|
if !log_dir.exists() {
|
||||||
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
if let Err(e) = std::fs::create_dir_all(&log_dir) {
|
||||||
eprintln!("Warning: Failed to create log directory {}: {}", log_dir.display(), e);
|
eprintln!(
|
||||||
|
"Warning: Failed to create log directory {}: {}",
|
||||||
|
log_dir.display(),
|
||||||
|
e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create file appender with daily rotation
|
// Create file appender with daily rotation
|
||||||
let file_appender = RollingFileAppender::new(
|
let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, "picobot.log");
|
||||||
Rotation::DAILY,
|
|
||||||
&log_dir,
|
|
||||||
"picobot.log",
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build subscriber with both console and file output
|
// Build subscriber with both console and file output
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let file_layer = fmt::layer()
|
let file_layer = fmt::layer()
|
||||||
.with_writer(file_appender)
|
.with_writer(file_appender)
|
||||||
@ -80,8 +84,7 @@ pub fn init_logging(timezone: Tz) {
|
|||||||
|
|
||||||
/// Initialize logging without file output (console only)
|
/// Initialize logging without file output (console only)
|
||||||
pub fn init_logging_console_only(timezone: Tz) {
|
pub fn init_logging_console_only(timezone: Tz) {
|
||||||
let env_filter = EnvFilter::try_from_default_env()
|
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
|
||||||
|
|
||||||
let console_layer = fmt::layer()
|
let console_layer = fmt::layer()
|
||||||
.with_timer(ConfiguredTimestamp { timezone })
|
.with_timer(ConfiguredTimestamp { timezone })
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use clap::{Parser, CommandFactory};
|
use clap::{CommandFactory, Parser};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "picobot")]
|
#[command(name = "picobot")]
|
||||||
|
|||||||
@ -26,10 +26,7 @@ pub enum ObserverEvent {
|
|||||||
success: bool,
|
success: bool,
|
||||||
},
|
},
|
||||||
/// Emitted when the agent starts processing.
|
/// Emitted when the agent starts processing.
|
||||||
AgentStart {
|
AgentStart { provider: String, model: String },
|
||||||
provider: String,
|
|
||||||
model: String,
|
|
||||||
},
|
|
||||||
/// Emitted when the agent finishes processing.
|
/// Emitted when the agent finishes processing.
|
||||||
AgentEnd {
|
AgentEnd {
|
||||||
provider: String,
|
provider: String,
|
||||||
@ -116,7 +113,11 @@ impl ToolExecutionOutcome {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a failed outcome with duration.
|
/// Create a failed outcome with duration.
|
||||||
pub fn failure_with_duration(output: String, error_reason: Option<String>, duration: Duration) -> Self {
|
pub fn failure_with_duration(
|
||||||
|
output: String,
|
||||||
|
error_reason: Option<String>,
|
||||||
|
duration: Duration,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
output,
|
output,
|
||||||
success: false,
|
success: false,
|
||||||
|
|||||||
@ -43,9 +43,7 @@ pub enum WsInbound {
|
|||||||
include_archived: bool,
|
include_archived: bool,
|
||||||
},
|
},
|
||||||
#[serde(rename = "load_session")]
|
#[serde(rename = "load_session")]
|
||||||
LoadSession {
|
LoadSession { session_id: String },
|
||||||
session_id: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "rename_session")]
|
#[serde(rename = "rename_session")]
|
||||||
RenameSession {
|
RenameSession {
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
@ -70,7 +68,11 @@ pub enum WsInbound {
|
|||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum WsOutbound {
|
pub enum WsOutbound {
|
||||||
#[serde(rename = "assistant_response")]
|
#[serde(rename = "assistant_response")]
|
||||||
AssistantResponse { id: String, content: String, role: String },
|
AssistantResponse {
|
||||||
|
id: String,
|
||||||
|
content: String,
|
||||||
|
role: String,
|
||||||
|
},
|
||||||
#[serde(rename = "tool_call")]
|
#[serde(rename = "tool_call")]
|
||||||
ToolCall {
|
ToolCall {
|
||||||
id: String,
|
id: String,
|
||||||
|
|||||||
@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
|
|
||||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||||
let mut details = vec![error.to_string()];
|
let mut details = vec![error.to_string()];
|
||||||
@ -20,7 +20,10 @@ fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
|||||||
details.join("\ncaused by: ")
|
details.join("\ncaused by: ")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize_content_blocks<S>(
|
||||||
|
blocks: &[serde_json::Value],
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
S: serde::Serializer,
|
S: serde::Serializer,
|
||||||
{
|
{
|
||||||
@ -28,14 +31,15 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||||
blocks.iter().map(|b| match b {
|
blocks
|
||||||
ContentBlock::Text { text } => {
|
.iter()
|
||||||
serde_json::json!({ "type": "text", "text": text })
|
.map(|b| match b {
|
||||||
}
|
ContentBlock::Text { text } => {
|
||||||
ContentBlock::ImageUrl { image_url } => {
|
serde_json::json!({ "type": "text", "text": text })
|
||||||
convert_image_url_to_anthropic(&image_url.url)
|
}
|
||||||
}
|
ContentBlock::ImageUrl { image_url } => convert_image_url_to_anthropic(&image_url.url),
|
||||||
}).collect()
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
fn convert_image_url_to_anthropic(url: &str) -> serde_json::Value {
|
||||||
@ -147,9 +151,13 @@ struct AnthropicResponse {
|
|||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
enum AnthropicContent {
|
enum AnthropicContent {
|
||||||
Text { text: String },
|
Text {
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
Thinking { thinking: String },
|
Thinking {
|
||||||
|
thinking: String,
|
||||||
|
},
|
||||||
#[serde(rename = "tool_use")]
|
#[serde(rename = "tool_use")]
|
||||||
ToolUse {
|
ToolUse {
|
||||||
id: String,
|
id: String,
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
pub mod traits;
|
|
||||||
pub mod openai;
|
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
pub mod openai;
|
||||||
|
pub mod traits;
|
||||||
|
|
||||||
pub use self::openai::OpenAIProvider;
|
|
||||||
pub use self::anthropic::AnthropicProvider;
|
pub use self::anthropic::AnthropicProvider;
|
||||||
|
pub use self::openai::OpenAIProvider;
|
||||||
|
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall, ToolFunction, Usage};
|
pub use traits::{
|
||||||
|
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolCall,
|
||||||
|
ToolFunction, Usage,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||||
match config.provider_type.as_str() {
|
match config.provider_type.as_str() {
|
||||||
|
|||||||
@ -1,18 +1,15 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{Value, json};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
|
||||||
use super::traits::Usage;
|
use super::traits::Usage;
|
||||||
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
|
use crate::bus::message::ContentBlock;
|
||||||
|
|
||||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &[
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||||
"tool_call_arguments_json",
|
|
||||||
"mock_response_content",
|
|
||||||
];
|
|
||||||
|
|
||||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||||
let mut details = vec![error.to_string()];
|
let mut details = vec![error.to_string()];
|
||||||
@ -32,12 +29,17 @@ fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
|||||||
return Value::String(text.clone());
|
return Value::String(text.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Value::Array(blocks.iter().map(|b| match b {
|
Value::Array(
|
||||||
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
blocks
|
||||||
ContentBlock::ImageUrl { image_url } => {
|
.iter()
|
||||||
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
.map(|b| match b {
|
||||||
}
|
ContentBlock::Text { text } => json!({ "type": "text", "text": text }),
|
||||||
}).collect())
|
ContentBlock::ImageUrl { image_url } => {
|
||||||
|
json!({ "type": "image_url", "image_url": { "url": image_url.url } })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
pub struct OpenAIProvider {
|
||||||
@ -122,7 +124,9 @@ impl OpenAIProvider {
|
|||||||
|
|
||||||
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
|
fn request_model_extra(&self) -> impl Iterator<Item = (&String, &Value)> {
|
||||||
self.model_extra.iter().filter(|(key, _)| {
|
self.model_extra.iter().filter(|(key, _)| {
|
||||||
!INTERNAL_MODEL_EXTRA_KEYS.iter().any(|internal| internal == &key.as_str())
|
!INTERNAL_MODEL_EXTRA_KEYS
|
||||||
|
.iter()
|
||||||
|
.any(|internal| internal == &key.as_str())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -265,7 +269,11 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
|
||||||
for (j, item) in content.iter().enumerate() {
|
for (j, item) in content.iter().enumerate() {
|
||||||
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
if item.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||||||
if let Some(url_str) = item.get("image_url").and_then(|u| u.get("url")).and_then(|v| v.as_str()) {
|
if let Some(url_str) = item
|
||||||
|
.get("image_url")
|
||||||
|
.and_then(|u| u.get("url"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
{
|
||||||
let prefix: String = url_str.chars().take(20).collect();
|
let prefix: String = url_str.chars().take(20).collect();
|
||||||
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
tracing::debug!(msg_idx = i, item_idx = j, image_prefix = %prefix, image_url_len = %url_str.len(), "Image in LLM request (first 20 bytes shown)");
|
||||||
}
|
}
|
||||||
@ -419,7 +427,10 @@ mod tests {
|
|||||||
assert_eq!(tool_calls[0]["id"], "call_1");
|
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||||
assert_eq!(tool_calls[0]["type"], "function");
|
assert_eq!(tool_calls[0]["type"], "function");
|
||||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
assert_eq!(
|
||||||
|
tool_calls[0]["function"]["arguments"],
|
||||||
|
"{\"expression\":\"1+1\"}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -433,10 +444,7 @@ mod tests {
|
|||||||
"gpt-test".to_string(),
|
"gpt-test".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
HashMap::from([(
|
HashMap::from([("tool_call_arguments_json".to_string(), Value::Bool(true))]),
|
||||||
"tool_call_arguments_json".to_string(),
|
|
||||||
Value::Bool(true),
|
|
||||||
)]),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
@ -461,7 +469,10 @@ mod tests {
|
|||||||
let messages = body["messages"].as_array().unwrap();
|
let messages = body["messages"].as_array().unwrap();
|
||||||
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
|
assert_eq!(
|
||||||
|
tool_calls[0]["function"]["arguments"],
|
||||||
|
json!({"expression": "1+1"})
|
||||||
|
);
|
||||||
assert!(body.get("tool_call_arguments_json").is_none());
|
assert!(body.get("tool_call_arguments_json").is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -501,7 +512,10 @@ mod tests {
|
|||||||
let messages = body["messages"].as_array().unwrap();
|
let messages = body["messages"].as_array().unwrap();
|
||||||
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
assert_eq!(
|
||||||
|
tool_calls[0]["function"]["arguments"],
|
||||||
|
"{\"expression\":\"1+1\"}"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -517,7 +531,10 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
HashMap::from([
|
HashMap::from([
|
||||||
("tool_call_arguments_json".to_string(), Value::Bool(true)),
|
("tool_call_arguments_json".to_string(), Value::Bool(true)),
|
||||||
("mock_response_content".to_string(), Value::String("stub".to_string())),
|
(
|
||||||
|
"mock_response_content".to_string(),
|
||||||
|
Value::String("stub".to_string()),
|
||||||
|
),
|
||||||
("parallel_tool_calls".to_string(), Value::Bool(true)),
|
("parallel_tool_calls".to_string(), Value::Bool(true)),
|
||||||
]),
|
]),
|
||||||
);
|
);
|
||||||
@ -590,7 +607,10 @@ mod tests {
|
|||||||
}))
|
}))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
|
assert_eq!(
|
||||||
|
response.choices[0].message.reasoning_content.as_deref(),
|
||||||
|
Some("hidden reasoning")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
|
use crate::bus::message::ContentBlock;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use crate::bus::message::ContentBlock;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
@ -61,7 +61,11 @@ impl Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
|
pub fn tool(
|
||||||
|
tool_call_id: impl Into<String>,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
content: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
role: "tool".to_string(),
|
role: "tool".to_string(),
|
||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
|
|||||||
@ -8,11 +8,11 @@ use tokio::sync::watch;
|
|||||||
|
|
||||||
use crate::bus::{MessageBus, OutboundMessage};
|
use crate::bus::{MessageBus, OutboundMessage};
|
||||||
use crate::config::{
|
use crate::config::{
|
||||||
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget, SchedulerMisfirePolicy,
|
SchedulerConfig, SchedulerJobConfig, SchedulerJobKind, SchedulerJobTarget,
|
||||||
SchedulerSchedule,
|
SchedulerMisfirePolicy, SchedulerSchedule,
|
||||||
};
|
};
|
||||||
use crate::gateway::session::SessionManager;
|
|
||||||
use crate::gateway::session::ScheduledAgentTaskOptions;
|
use crate::gateway::session::ScheduledAgentTaskOptions;
|
||||||
|
use crate::gateway::session::SessionManager;
|
||||||
use crate::storage::{
|
use crate::storage::{
|
||||||
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
|
SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus, SchedulerJobUpsert, SessionStore,
|
||||||
};
|
};
|
||||||
@ -76,8 +76,11 @@ impl Scheduler {
|
|||||||
|
|
||||||
fn sync_config_jobs(&self) -> anyhow::Result<()> {
|
fn sync_config_jobs(&self) -> anyhow::Result<()> {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
for job in self.config.effective_jobs(&crate::config::TimeConfig { timezone: self.timezone.name().to_string() }) {
|
for job in self.config.effective_jobs(&crate::config::TimeConfig {
|
||||||
let runtime = RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
|
timezone: self.timezone.name().to_string(),
|
||||||
|
}) {
|
||||||
|
let runtime =
|
||||||
|
RuntimeJob::from_config(&job, now, self.config.misfire_policy, self.timezone)?;
|
||||||
self.store.upsert_scheduler_job(&runtime.to_upsert())?;
|
self.store.upsert_scheduler_job(&runtime.to_upsert())?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -88,7 +91,9 @@ impl Scheduler {
|
|||||||
let jobs = self.store.list_scheduler_jobs(true)?;
|
let jobs = self.store.list_scheduler_jobs(true)?;
|
||||||
|
|
||||||
for record in jobs {
|
for record in jobs {
|
||||||
let Some(mut job) = RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)? else {
|
let Some(mut job) =
|
||||||
|
RuntimeJob::from_record(&record, self.config.misfire_policy, self.timezone)?
|
||||||
|
else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -178,8 +183,12 @@ impl Scheduler {
|
|||||||
}
|
}
|
||||||
SchedulerJobKind::SilentAgentTask => {
|
SchedulerJobKind::SilentAgentTask => {
|
||||||
let execution_chat_id = resolve_execution_chat_id(job)?;
|
let execution_chat_id = resolve_execution_chat_id(job)?;
|
||||||
if let Err(error) = execute_agent_task(&self.session_manager, job, &execution_chat_id).await {
|
if let Err(error) =
|
||||||
if let Err(notify_error) = self.notify_silent_agent_task_failure(job, &error).await {
|
execute_agent_task(&self.session_manager, job, &execution_chat_id).await
|
||||||
|
{
|
||||||
|
if let Err(notify_error) =
|
||||||
|
self.notify_silent_agent_task_failure(job, &error).await
|
||||||
|
{
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
job_id = %job.id,
|
job_id = %job.id,
|
||||||
error = %notify_error,
|
error = %notify_error,
|
||||||
@ -208,10 +217,13 @@ impl Scheduler {
|
|||||||
|
|
||||||
let mut metadata = HashMap::new();
|
let mut metadata = HashMap::new();
|
||||||
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
|
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
|
||||||
metadata.insert("scheduler_job_kind".to_string(), "silent_agent_task".to_string());
|
metadata.insert(
|
||||||
|
"scheduler_job_kind".to_string(),
|
||||||
|
"silent_agent_task".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
self.bus
|
self.bus
|
||||||
.publish_outbound(OutboundMessage::assistant(
|
.publish_outbound(OutboundMessage::error_notification(
|
||||||
channel,
|
channel,
|
||||||
chat_id,
|
chat_id,
|
||||||
format!(
|
format!(
|
||||||
@ -304,14 +316,24 @@ impl RuntimeJob {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let schedule = deserialize_schedule(&record.schedule, record.interval_secs, record.startup_delay_secs)?;
|
let schedule = deserialize_schedule(
|
||||||
|
&record.schedule,
|
||||||
|
record.interval_secs,
|
||||||
|
record.startup_delay_secs,
|
||||||
|
)?;
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let next_fire_at = match (record.enabled, record.state.clone(), record.next_fire_at) {
|
let next_fire_at = match (record.enabled, record.state.clone(), record.next_fire_at) {
|
||||||
(false, _, _) => None,
|
(false, _, _) => None,
|
||||||
(_, SchedulerJobState::Paused, _) => None,
|
(_, SchedulerJobState::Paused, _) => None,
|
||||||
(_, SchedulerJobState::Completed, _) => None,
|
(_, SchedulerJobState::Completed, _) => None,
|
||||||
(_, _, some_next) if some_next.is_some() => some_next,
|
(_, _, some_next) if some_next.is_some() => some_next,
|
||||||
_ => compute_initial_next_fire_at(&schedule, now, record.last_fired_at, misfire_policy, timezone)?,
|
_ => compute_initial_next_fire_at(
|
||||||
|
&schedule,
|
||||||
|
now,
|
||||||
|
record.last_fired_at,
|
||||||
|
misfire_policy,
|
||||||
|
timezone,
|
||||||
|
)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(Self {
|
Ok(Some(Self {
|
||||||
@ -338,7 +360,10 @@ impl RuntimeJob {
|
|||||||
fn is_due(&self, now: DateTime<Utc>) -> bool {
|
fn is_due(&self, now: DateTime<Utc>) -> bool {
|
||||||
self.enabled
|
self.enabled
|
||||||
&& self.state == SchedulerJobState::Scheduled
|
&& self.state == SchedulerJobState::Scheduled
|
||||||
&& self.next_fire_at.map(|value| value <= now.timestamp_millis()).unwrap_or(false)
|
&& self
|
||||||
|
.next_fire_at
|
||||||
|
.map(|value| value <= now.timestamp_millis())
|
||||||
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn after_execution(
|
fn after_execution(
|
||||||
@ -371,7 +396,8 @@ impl RuntimeJob {
|
|||||||
let reference_ms = self.next_fire_at.or(self.last_fired_at);
|
let reference_ms = self.next_fire_at.or(self.last_fired_at);
|
||||||
self.state = SchedulerJobState::Scheduled;
|
self.state = SchedulerJobState::Scheduled;
|
||||||
self.completed_at = None;
|
self.completed_at = None;
|
||||||
self.next_fire_at = compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?;
|
self.next_fire_at =
|
||||||
|
compute_next_fire_at(&self.schedule, now, reference_ms, misfire_policy, timezone)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -384,7 +410,8 @@ impl RuntimeJob {
|
|||||||
SchedulerJobKind::AgentTask => "agent_task".to_string(),
|
SchedulerJobKind::AgentTask => "agent_task".to_string(),
|
||||||
SchedulerJobKind::SilentAgentTask => "silent_agent_task".to_string(),
|
SchedulerJobKind::SilentAgentTask => "silent_agent_task".to_string(),
|
||||||
},
|
},
|
||||||
schedule: serde_json::to_value(&self.schedule).unwrap_or_else(|_| serde_json::json!({})),
|
schedule: serde_json::to_value(&self.schedule)
|
||||||
|
.unwrap_or_else(|_| serde_json::json!({})),
|
||||||
interval_secs: self.interval_secs,
|
interval_secs: self.interval_secs,
|
||||||
startup_delay_secs: self.startup_delay_secs,
|
startup_delay_secs: self.startup_delay_secs,
|
||||||
target: serde_json::to_value(&self.target).unwrap_or_else(|_| serde_json::json!({})),
|
target: serde_json::to_value(&self.target).unwrap_or_else(|_| serde_json::json!({})),
|
||||||
@ -430,21 +457,36 @@ fn compute_initial_next_fire_at(
|
|||||||
timezone: Tz,
|
timezone: Tz,
|
||||||
) -> anyhow::Result<Option<i64>> {
|
) -> anyhow::Result<Option<i64>> {
|
||||||
match last_fired_at {
|
match last_fired_at {
|
||||||
Some(last_fired_at) => compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone),
|
Some(last_fired_at) => {
|
||||||
|
compute_next_fire_at(schedule, now, Some(last_fired_at), misfire_policy, timezone)
|
||||||
|
}
|
||||||
None => match schedule {
|
None => match schedule {
|
||||||
SchedulerSchedule::Delay { seconds } => Ok(Some((now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis())),
|
SchedulerSchedule::Delay { seconds } => Ok(Some(
|
||||||
|
(now + ChronoDuration::seconds(*seconds as i64)).timestamp_millis(),
|
||||||
|
)),
|
||||||
SchedulerSchedule::Interval {
|
SchedulerSchedule::Interval {
|
||||||
seconds,
|
seconds,
|
||||||
startup_delay_secs,
|
startup_delay_secs,
|
||||||
} => {
|
} => {
|
||||||
let delay = if *startup_delay_secs > 0 { *startup_delay_secs } else { *seconds };
|
let delay = if *startup_delay_secs > 0 {
|
||||||
Ok(Some((now + ChronoDuration::seconds(delay as i64)).timestamp_millis()))
|
*startup_delay_secs
|
||||||
|
} else {
|
||||||
|
*seconds
|
||||||
|
};
|
||||||
|
Ok(Some(
|
||||||
|
(now + ChronoDuration::seconds(delay as i64)).timestamp_millis(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
SchedulerSchedule::At { timestamp } => {
|
||||||
|
Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis()))
|
||||||
}
|
}
|
||||||
SchedulerSchedule::At { timestamp } => Ok(Some(parse_rfc3339_to_utc(timestamp)?.timestamp_millis())),
|
|
||||||
SchedulerSchedule::Cron { expression } => {
|
SchedulerSchedule::Cron { expression } => {
|
||||||
let schedule = parse_scheduler_cron(expression)?;
|
let schedule = parse_scheduler_cron(expression)?;
|
||||||
let local_now = now.with_timezone(&timezone);
|
let local_now = now.with_timezone(&timezone);
|
||||||
Ok(schedule.after(&local_now).next().map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
Ok(schedule
|
||||||
|
.after(&local_now)
|
||||||
|
.next()
|
||||||
|
.map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -483,7 +525,10 @@ fn compute_next_fire_at(
|
|||||||
.map(|value| value.with_timezone(&timezone))
|
.map(|value| value.with_timezone(&timezone))
|
||||||
.unwrap_or_else(|| now.with_timezone(&timezone)),
|
.unwrap_or_else(|| now.with_timezone(&timezone)),
|
||||||
};
|
};
|
||||||
Ok(schedule.after(&anchor).next().map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
Ok(schedule
|
||||||
|
.after(&anchor)
|
||||||
|
.next()
|
||||||
|
.map(|next| next.with_timezone(&Utc).timestamp_millis()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -525,12 +570,14 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
|
|||||||
.payload
|
.payload
|
||||||
.get("content")
|
.get("content")
|
||||||
.and_then(|value| value.as_str())
|
.and_then(|value| value.as_str())
|
||||||
.ok_or_else(|| anyhow::anyhow!("outbound scheduler job payload.content must be a string"))?;
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("outbound scheduler job payload.content must be a string")
|
||||||
|
})?;
|
||||||
|
|
||||||
let mut metadata = HashMap::new();
|
let mut metadata = HashMap::new();
|
||||||
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
|
metadata.insert("scheduler_job_id".to_string(), job.id.clone());
|
||||||
|
|
||||||
Ok(OutboundMessage::assistant(
|
Ok(OutboundMessage::scheduler_notification(
|
||||||
channel,
|
channel,
|
||||||
chat_id,
|
chat_id,
|
||||||
content.to_string(),
|
content.to_string(),
|
||||||
@ -539,7 +586,10 @@ fn build_outbound_message(job: &RuntimeJob) -> anyhow::Result<OutboundMessage> {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_internal_event(session_manager: &SessionManager, job: &RuntimeJob) -> anyhow::Result<()> {
|
async fn execute_internal_event(
|
||||||
|
session_manager: &SessionManager,
|
||||||
|
job: &RuntimeJob,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
let event = job
|
let event = job
|
||||||
.payload
|
.payload
|
||||||
.get("event")
|
.get("event")
|
||||||
@ -599,7 +649,10 @@ async fn execute_agent_task(
|
|||||||
.map_err(|error| anyhow::anyhow!(error.to_string()))
|
.map_err(|error| anyhow::anyhow!(error.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> anyhow::Result<&'a str> {
|
fn required_notification_chat_id<'a>(
|
||||||
|
job: &'a RuntimeJob,
|
||||||
|
kind_name: &str,
|
||||||
|
) -> anyhow::Result<&'a str> {
|
||||||
job.target
|
job.target
|
||||||
.chat_id
|
.chat_id
|
||||||
.as_deref()
|
.as_deref()
|
||||||
@ -608,7 +661,9 @@ fn required_notification_chat_id<'a>(job: &'a RuntimeJob, kind_name: &str) -> an
|
|||||||
|
|
||||||
fn resolve_execution_chat_id(job: &RuntimeJob) -> anyhow::Result<String> {
|
fn resolve_execution_chat_id(job: &RuntimeJob) -> anyhow::Result<String> {
|
||||||
match job.kind {
|
match job.kind {
|
||||||
SchedulerJobKind::AgentTask => Ok(required_notification_chat_id(job, "agent_task")?.to_string()),
|
SchedulerJobKind::AgentTask => {
|
||||||
|
Ok(required_notification_chat_id(job, "agent_task")?.to_string())
|
||||||
|
}
|
||||||
SchedulerJobKind::SilentAgentTask => Ok(job
|
SchedulerJobKind::SilentAgentTask => Ok(job
|
||||||
.target
|
.target
|
||||||
.session_chat_id
|
.session_chat_id
|
||||||
@ -633,7 +688,9 @@ fn summarize_scheduler_error(error: &anyhow::Error) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<ScheduledAgentTaskOptions> {
|
fn parse_scheduled_agent_task_options(
|
||||||
|
job: &RuntimeJob,
|
||||||
|
) -> anyhow::Result<ScheduledAgentTaskOptions> {
|
||||||
let sender_id = job
|
let sender_id = job
|
||||||
.payload
|
.payload
|
||||||
.get("sender_id")
|
.get("sender_id")
|
||||||
@ -665,7 +722,9 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_metadata_map(value: Option<&serde_json::Value>) -> anyhow::Result<HashMap<String, String>> {
|
fn parse_metadata_map(
|
||||||
|
value: Option<&serde_json::Value>,
|
||||||
|
) -> anyhow::Result<HashMap<String, String>> {
|
||||||
let Some(value) = value else {
|
let Some(value) = value else {
|
||||||
return Ok(HashMap::new());
|
return Ok(HashMap::new());
|
||||||
};
|
};
|
||||||
@ -685,7 +744,7 @@ fn parse_metadata_map(value: Option<&serde_json::Value>) -> anyhow::Result<HashM
|
|||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"agent_task payload.metadata field '{}' must be a string, number, bool, or null",
|
"agent_task payload.metadata field '{}' must be a string, number, bool, or null",
|
||||||
key
|
key
|
||||||
))
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
metadata.insert(key.clone(), stringified);
|
metadata.insert(key.clone(), stringified);
|
||||||
@ -730,12 +789,19 @@ mod agent_task_tests {
|
|||||||
updated_at: 1_700_000_000_000,
|
updated_at: 1_700_000_000_000,
|
||||||
};
|
};
|
||||||
|
|
||||||
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
|
let job = RuntimeJob::from_record(
|
||||||
.unwrap()
|
&record,
|
||||||
.unwrap();
|
SchedulerMisfirePolicy::Skip,
|
||||||
|
chrono_tz::Asia::Shanghai,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
|
assert_eq!(job.kind, SchedulerJobKind::AgentTask);
|
||||||
assert_eq!(job.payload.get("prompt").and_then(|value| value.as_str()), Some("请总结今天待办"));
|
assert_eq!(
|
||||||
|
job.payload.get("prompt").and_then(|value| value.as_str()),
|
||||||
|
Some("请总结今天待办")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -771,12 +837,19 @@ mod agent_task_tests {
|
|||||||
updated_at: 1_700_000_000_000,
|
updated_at: 1_700_000_000_000,
|
||||||
};
|
};
|
||||||
|
|
||||||
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
|
let job = RuntimeJob::from_record(
|
||||||
.unwrap()
|
&record,
|
||||||
.unwrap();
|
SchedulerMisfirePolicy::Skip,
|
||||||
|
chrono_tz::Asia::Shanghai,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(job.kind, SchedulerJobKind::SilentAgentTask);
|
assert_eq!(job.kind, SchedulerJobKind::SilentAgentTask);
|
||||||
assert_eq!(job.target.session_chat_id.as_deref(), Some("scheduler/agent.daily_summary.background"));
|
assert_eq!(
|
||||||
|
job.target.session_chat_id.as_deref(),
|
||||||
|
Some("scheduler/agent.daily_summary.background")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -825,9 +898,18 @@ mod agent_task_tests {
|
|||||||
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
||||||
assert!(options.fresh_session);
|
assert!(options.fresh_session);
|
||||||
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
||||||
assert_eq!(options.metadata.get("job_type").map(String::as_str), Some("daily_summary"));
|
assert_eq!(
|
||||||
assert_eq!(options.metadata.get("priority").map(String::as_str), Some("1"));
|
options.metadata.get("job_type").map(String::as_str),
|
||||||
assert_eq!(options.metadata.get("urgent").map(String::as_str), Some("false"));
|
Some("daily_summary")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
options.metadata.get("priority").map(String::as_str),
|
||||||
|
Some("1")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
options.metadata.get("urgent").map(String::as_str),
|
||||||
|
Some("false")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -880,12 +962,12 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::collections::HashMap;
|
|
||||||
use crate::bus::MessageBus;
|
use crate::bus::MessageBus;
|
||||||
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
|
use crate::config::{BUILTIN_MEMORY_MAINTENANCE_JOB_ID, LLMProviderConfig};
|
||||||
use crate::gateway::session::SessionManager;
|
use crate::gateway::session::SessionManager;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
fn test_provider_config() -> LLMProviderConfig {
|
fn test_provider_config() -> LLMProviderConfig {
|
||||||
LLMProviderConfig {
|
LLMProviderConfig {
|
||||||
@ -898,6 +980,7 @@ mod tests {
|
|||||||
model_id: "test-model".to_string(),
|
model_id: "test-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 4,
|
max_tool_iterations: 4,
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
@ -921,7 +1004,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn runtime_job_skip_policy_advances_from_now() {
|
fn runtime_job_skip_policy_advances_from_now() {
|
||||||
let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap();
|
let now = Utc
|
||||||
|
.timestamp_millis_opt(1_700_000_000_000)
|
||||||
|
.single()
|
||||||
|
.unwrap();
|
||||||
let next = compute_next_fire_at(
|
let next = compute_next_fire_at(
|
||||||
&SchedulerSchedule::Interval {
|
&SchedulerSchedule::Interval {
|
||||||
seconds: 60,
|
seconds: 60,
|
||||||
@ -940,7 +1026,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn runtime_job_catch_up_policy_moves_past_now() {
|
fn runtime_job_catch_up_policy_moves_past_now() {
|
||||||
let now = Utc.timestamp_millis_opt(1_700_000_000_000).single().unwrap();
|
let now = Utc
|
||||||
|
.timestamp_millis_opt(1_700_000_000_000)
|
||||||
|
.single()
|
||||||
|
.unwrap();
|
||||||
let next = compute_next_fire_at(
|
let next = compute_next_fire_at(
|
||||||
&SchedulerSchedule::Interval {
|
&SchedulerSchedule::Interval {
|
||||||
seconds: 60,
|
seconds: 60,
|
||||||
@ -989,14 +1078,21 @@ mod tests {
|
|||||||
updated_at: 1_700_000_000_000,
|
updated_at: 1_700_000_000_000,
|
||||||
};
|
};
|
||||||
|
|
||||||
let job = RuntimeJob::from_record(&record, SchedulerMisfirePolicy::Skip, chrono_tz::Asia::Shanghai)
|
let job = RuntimeJob::from_record(
|
||||||
.unwrap()
|
&record,
|
||||||
.unwrap();
|
SchedulerMisfirePolicy::Skip,
|
||||||
|
chrono_tz::Asia::Shanghai,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(job.schedule, SchedulerSchedule::Interval {
|
assert_eq!(
|
||||||
seconds: 120,
|
job.schedule,
|
||||||
startup_delay_secs: 10,
|
SchedulerSchedule::Interval {
|
||||||
});
|
seconds: 120,
|
||||||
|
startup_delay_secs: 10,
|
||||||
|
}
|
||||||
|
);
|
||||||
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
|
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1050,7 +1146,10 @@ mod tests {
|
|||||||
|
|
||||||
scheduler.process_tick().await.unwrap();
|
scheduler.process_tick().await.unwrap();
|
||||||
|
|
||||||
let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap();
|
let saved = store
|
||||||
|
.get_scheduler_job("massage_reminder")
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
assert!(saved.next_fire_at.is_some());
|
assert!(saved.next_fire_at.is_some());
|
||||||
assert_eq!(saved.run_count, 0);
|
assert_eq!(saved.run_count, 0);
|
||||||
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
||||||
@ -1080,7 +1179,10 @@ mod tests {
|
|||||||
assert_eq!(saved.kind, "internal_event");
|
assert_eq!(saved.kind, "internal_event");
|
||||||
assert!(saved.enabled);
|
assert!(saved.enabled);
|
||||||
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
||||||
assert_eq!(saved.payload.get("event").and_then(|value| value.as_str()), Some("memory_maintenance"));
|
assert_eq!(
|
||||||
|
saved.payload.get("event").and_then(|value| value.as_str()),
|
||||||
|
Some("memory_maintenance")
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
saved.schedule,
|
saved.schedule,
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
@ -1088,7 +1190,13 @@ mod tests {
|
|||||||
"expression": "0 */4 * * *"
|
"expression": "0 */4 * * *"
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
assert_eq!(saved.payload.get("local_time").and_then(|value| value.as_str()), Some("every_4_hours"));
|
assert_eq!(
|
||||||
|
saved
|
||||||
|
.payload
|
||||||
|
.get("local_time")
|
||||||
|
.and_then(|value| value.as_str()),
|
||||||
|
Some("every_4_hours")
|
||||||
|
);
|
||||||
assert!(saved.next_fire_at.is_some());
|
assert!(saved.next_fire_at.is_some());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1155,7 +1263,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cron_schedule_uses_configured_timezone() {
|
fn cron_schedule_uses_configured_timezone() {
|
||||||
let now = Utc.with_ymd_and_hms(2026, 4, 23, 18, 0, 0).single().unwrap();
|
let now = Utc
|
||||||
|
.with_ymd_and_hms(2026, 4, 23, 18, 0, 0)
|
||||||
|
.single()
|
||||||
|
.unwrap();
|
||||||
let next = compute_next_fire_at(
|
let next = compute_next_fire_at(
|
||||||
&SchedulerSchedule::Cron {
|
&SchedulerSchedule::Cron {
|
||||||
expression: "0 3 * * *".to_string(),
|
expression: "0 3 * * *".to_string(),
|
||||||
@ -1169,6 +1280,11 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let next_utc = ts_millis_to_utc(next).unwrap();
|
let next_utc = ts_millis_to_utc(next).unwrap();
|
||||||
assert_eq!(next_utc, Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0).single().unwrap());
|
assert_eq!(
|
||||||
|
next_utc,
|
||||||
|
Utc.with_ymd_and_hms(2026, 4, 23, 19, 0, 0)
|
||||||
|
.single()
|
||||||
|
.unwrap()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -89,7 +89,10 @@ impl SkillRuntime {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_empty(&self) -> bool {
|
pub fn is_empty(&self) -> bool {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").is_empty()
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
@ -97,31 +100,53 @@ impl SkillRuntime {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn system_index_prompt(&self) -> Option<String> {
|
pub fn system_index_prompt(&self) -> Option<String> {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").system_index_prompt()
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.system_index_prompt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn discovery_event_payload(&self) -> serde_json::Value {
|
pub fn discovery_event_payload(&self) -> serde_json::Value {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").discovery_event_payload()
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.discovery_event_payload()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn offered_event_payload(&self) -> serde_json::Value {
|
pub fn offered_event_payload(&self) -> serde_json::Value {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").offered_event_payload()
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.offered_event_payload()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn skill_tool_definition(&self) -> Option<Tool> {
|
pub fn skill_tool_definition(&self) -> Option<Tool> {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").skill_tool_definition()
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.skill_tool_definition()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn activation_payload(&self, name: &str) -> Result<String, String> {
|
pub fn activation_payload(&self, name: &str) -> Result<String, String> {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").activation_payload(name)
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.activation_payload(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String> {
|
pub fn activation_event_payload(&self, name: &str) -> Result<serde_json::Value, String> {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").activation_event_payload(name)
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.activation_event_payload(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn list_skills(&self) -> Vec<Skill> {
|
pub fn list_skills(&self) -> Vec<Skill> {
|
||||||
self.catalog.read().expect("skills rwlock poisoned").skills.clone()
|
self.catalog
|
||||||
|
.read()
|
||||||
|
.expect("skills rwlock poisoned")
|
||||||
|
.skills
|
||||||
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_skill(&self, name: &str) -> Option<Skill> {
|
pub fn get_skill(&self, name: &str) -> Option<Skill> {
|
||||||
@ -143,7 +168,11 @@ impl SkillRuntime {
|
|||||||
validate_skill_name(name)?;
|
validate_skill_name(name)?;
|
||||||
let path = skill_file_path(scope, name)?;
|
let path = skill_file_path(scope, name)?;
|
||||||
if path.exists() {
|
if path.exists() {
|
||||||
return Err(format!("skill '{}' already exists at {}", name, path.display()));
|
return Err(format!(
|
||||||
|
"skill '{}' already exists at {}",
|
||||||
|
name,
|
||||||
|
path.display()
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
write_skill_file(&path, name, description, body)?;
|
write_skill_file(&path, name, description, body)?;
|
||||||
@ -180,14 +209,20 @@ impl SkillRuntime {
|
|||||||
Ok(skill)
|
Ok(skill)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn delete_skill(&self, scope: SkillScope, name: &str, reload: bool) -> Result<PathBuf, String> {
|
pub fn delete_skill(
|
||||||
|
&self,
|
||||||
|
scope: SkillScope,
|
||||||
|
name: &str,
|
||||||
|
reload: bool,
|
||||||
|
) -> Result<PathBuf, String> {
|
||||||
validate_skill_name(name)?;
|
validate_skill_name(name)?;
|
||||||
let dir = skill_dir_path(scope, name)?;
|
let dir = skill_dir_path(scope, name)?;
|
||||||
if !dir.exists() {
|
if !dir.exists() {
|
||||||
return Err(format!("skill '{}' not found at {}", name, dir.display()));
|
return Err(format!("skill '{}' not found at {}", name, dir.display()));
|
||||||
}
|
}
|
||||||
|
|
||||||
fs::remove_dir_all(&dir).map_err(|err| format!("failed to delete skill directory: {}", err))?;
|
fs::remove_dir_all(&dir)
|
||||||
|
.map_err(|err| format!("failed to delete skill directory: {}", err))?;
|
||||||
if reload {
|
if reload {
|
||||||
let _ = self.reload()?;
|
let _ = self.reload()?;
|
||||||
}
|
}
|
||||||
@ -439,7 +474,8 @@ fn validate_skill_name(name: &str) -> Result<(), String> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn project_skills_root() -> Result<PathBuf, String> {
|
pub fn project_skills_root() -> Result<PathBuf, String> {
|
||||||
let cwd = std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?;
|
let cwd =
|
||||||
|
std::env::current_dir().map_err(|err| format!("failed to get current dir: {}", err))?;
|
||||||
Ok(cwd.join(".picobot").join("skills"))
|
Ok(cwd.join(".picobot").join("skills"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -466,7 +502,9 @@ fn source_root(source: SkillSource, cwd: &Path) -> Option<PathBuf> {
|
|||||||
|
|
||||||
fn root_for_scope(scope: SkillScope) -> Result<PathBuf, String> {
|
fn root_for_scope(scope: SkillScope) -> Result<PathBuf, String> {
|
||||||
match scope {
|
match scope {
|
||||||
SkillScope::User => user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string()),
|
SkillScope::User => {
|
||||||
|
user_skills_root().ok_or_else(|| "failed to resolve home directory".to_string())
|
||||||
|
}
|
||||||
SkillScope::Project => project_skills_root(),
|
SkillScope::Project => project_skills_root(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -508,7 +546,8 @@ fn render_skill_file(name: &str, description: &str, body: &str) -> Result<String
|
|||||||
fn write_skill_file(path: &Path, name: &str, description: &str, body: &str) -> Result<(), String> {
|
fn write_skill_file(path: &Path, name: &str, description: &str, body: &str) -> Result<(), String> {
|
||||||
let content = render_skill_file(name, description, body)?;
|
let content = render_skill_file(name, description, body)?;
|
||||||
if let Some(parent) = path.parent() {
|
if let Some(parent) = path.parent() {
|
||||||
fs::create_dir_all(parent).map_err(|err| format!("failed to create skill directory: {}", err))?;
|
fs::create_dir_all(parent)
|
||||||
|
.map_err(|err| format!("failed to create skill directory: {}", err))?;
|
||||||
}
|
}
|
||||||
fs::write(path, content).map_err(|err| format!("failed to write skill file: {}", err))
|
fs::write(path, content).map_err(|err| format!("failed to write skill file: {}", err))
|
||||||
}
|
}
|
||||||
@ -556,11 +595,10 @@ struct SkillFrontmatter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
|
fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
|
||||||
let content = fs::read_to_string(path)
|
let content = fs::read_to_string(path).map_err(|e| format!("failed to read file: {}", e))?;
|
||||||
.map_err(|e| format!("failed to read file: {}", e))?;
|
|
||||||
|
|
||||||
let (frontmatter_raw, body) = split_frontmatter(&content)
|
let (frontmatter_raw, body) =
|
||||||
.ok_or_else(|| "missing YAML frontmatter block".to_string())?;
|
split_frontmatter(&content).ok_or_else(|| "missing YAML frontmatter block".to_string())?;
|
||||||
|
|
||||||
let frontmatter: SkillFrontmatter = serde_yaml::from_str(frontmatter_raw)
|
let frontmatter: SkillFrontmatter = serde_yaml::from_str(frontmatter_raw)
|
||||||
.map_err(|e| format!("invalid YAML frontmatter: {}", e))?;
|
.map_err(|e| format!("invalid YAML frontmatter: {}", e))?;
|
||||||
@ -576,11 +614,7 @@ fn parse_skill_file(path: &Path, source: SkillSource) -> Result<Skill, String> {
|
|||||||
.map(|s| s.to_string_lossy().to_string())
|
.map(|s| s.to_string_lossy().to_string())
|
||||||
.unwrap_or_else(|| "unknown-skill".to_string());
|
.unwrap_or_else(|| "unknown-skill".to_string());
|
||||||
|
|
||||||
let name = frontmatter
|
let name = frontmatter.name.unwrap_or(dir_name).trim().to_string();
|
||||||
.name
|
|
||||||
.unwrap_or(dir_name)
|
|
||||||
.trim()
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
Ok(Skill {
|
Ok(Skill {
|
||||||
name,
|
name,
|
||||||
@ -656,7 +690,10 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let skills = load_skills_from_root(&dir.path().join(".picobot").join("skills"), SkillSource::Project);
|
let skills = load_skills_from_root(
|
||||||
|
&dir.path().join(".picobot").join("skills"),
|
||||||
|
SkillSource::Project,
|
||||||
|
);
|
||||||
let catalog = SkillCatalog {
|
let catalog = SkillCatalog {
|
||||||
skills,
|
skills,
|
||||||
max_index_chars: 4000,
|
max_index_chars: 4000,
|
||||||
@ -707,7 +744,13 @@ mod tests {
|
|||||||
assert_eq!(runtime.len(), 0);
|
assert_eq!(runtime.len(), 0);
|
||||||
|
|
||||||
let created = runtime
|
let created = runtime
|
||||||
.create_skill(SkillScope::Project, "demo-skill", "demo desc", "line 1", true)
|
.create_skill(
|
||||||
|
SkillScope::Project,
|
||||||
|
"demo-skill",
|
||||||
|
"demo desc",
|
||||||
|
"line 1",
|
||||||
|
true,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(created.name, "demo-skill");
|
assert_eq!(created.name, "demo-skill");
|
||||||
assert_eq!(runtime.len(), 1);
|
assert_eq!(runtime.len(), 1);
|
||||||
@ -722,7 +765,12 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(updated.description, "updated desc");
|
assert_eq!(updated.description, "updated desc");
|
||||||
assert!(runtime.activation_payload("demo-skill").unwrap().contains("line 2"));
|
assert!(
|
||||||
|
runtime
|
||||||
|
.activation_payload("demo-skill")
|
||||||
|
.unwrap()
|
||||||
|
.contains("line 2")
|
||||||
|
);
|
||||||
|
|
||||||
let deleted_path = runtime
|
let deleted_path = runtime
|
||||||
.delete_skill(SkillScope::Project, "demo-skill", true)
|
.delete_skill(SkillScope::Project, "demo-skill", true)
|
||||||
@ -759,7 +807,11 @@ mod tests {
|
|||||||
let temp_dir = tempfile::tempdir().unwrap();
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
let _guard = CurrentDirGuard::enter(temp_dir.path());
|
let _guard = CurrentDirGuard::enter(temp_dir.path());
|
||||||
|
|
||||||
let agent_skill_dir = temp_dir.path().join(".agents").join("skills").join("demo-agent");
|
let agent_skill_dir = temp_dir
|
||||||
|
.path()
|
||||||
|
.join(".agents")
|
||||||
|
.join("skills")
|
||||||
|
.join("demo-agent");
|
||||||
fs::create_dir_all(&agent_skill_dir).unwrap();
|
fs::create_dir_all(&agent_skill_dir).unwrap();
|
||||||
fs::write(
|
fs::write(
|
||||||
agent_skill_dir.join("SKILL.md"),
|
agent_skill_dir.join("SKILL.md"),
|
||||||
|
|||||||
@ -391,7 +391,8 @@ impl SessionStore {
|
|||||||
)?;
|
)?;
|
||||||
|
|
||||||
drop(conn);
|
drop(conn);
|
||||||
self.get_session(&id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
self.get_session(&id)?
|
||||||
|
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ensure_channel_session(
|
pub fn ensure_channel_session(
|
||||||
@ -419,7 +420,8 @@ impl SessionStore {
|
|||||||
)?;
|
)?;
|
||||||
drop(conn);
|
drop(conn);
|
||||||
|
|
||||||
self.get_session(&session_id)?.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
self.get_session(&session_id)?
|
||||||
|
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
|
pub fn get_session(&self, session_id: &str) -> Result<Option<SessionRecord>, StorageError> {
|
||||||
@ -495,7 +497,10 @@ impl SessionStore {
|
|||||||
|
|
||||||
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
|
pub fn delete_session(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
|
conn.execute(
|
||||||
|
"DELETE FROM messages WHERE session_id = ?1",
|
||||||
|
params![session_id],
|
||||||
|
)?;
|
||||||
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
|
conn.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -503,7 +508,10 @@ impl SessionStore {
|
|||||||
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
pub fn clear_messages(&self, session_id: &str) -> Result<(), StorageError> {
|
||||||
let now = current_timestamp();
|
let now = current_timestamp();
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
conn.execute("DELETE FROM messages WHERE session_id = ?1", params![session_id])?;
|
conn.execute(
|
||||||
|
"DELETE FROM messages WHERE session_id = ?1",
|
||||||
|
params![session_id],
|
||||||
|
)?;
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"
|
"
|
||||||
UPDATE sessions
|
UPDATE sessions
|
||||||
@ -549,7 +557,11 @@ impl SessionStore {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn append_message(&self, session_id: &str, message: &ChatMessage) -> Result<(), StorageError> {
|
pub fn append_message(
|
||||||
|
&self,
|
||||||
|
session_id: &str,
|
||||||
|
message: &ChatMessage,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
let tx = conn.unchecked_transaction()?;
|
let tx = conn.unchecked_transaction()?;
|
||||||
|
|
||||||
@ -560,7 +572,11 @@ impl SessionStore {
|
|||||||
)?;
|
)?;
|
||||||
|
|
||||||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||||
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
let tool_calls_json = message
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map(serde_json::to_string)
|
||||||
|
.transpose()?;
|
||||||
tx.execute(
|
tx.execute(
|
||||||
"
|
"
|
||||||
INSERT INTO messages (
|
INSERT INTO messages (
|
||||||
@ -630,7 +646,8 @@ impl SessionStore {
|
|||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
let delta_messages = load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
|
let delta_messages =
|
||||||
|
load_messages_between(&tx, session_id, snapshot_end_seq, current_max_seq)?;
|
||||||
let mut next_seq = current_max_seq + 1;
|
let mut next_seq = current_max_seq + 1;
|
||||||
let now = current_timestamp();
|
let now = current_timestamp();
|
||||||
let mut inserted_count = 0_i64;
|
let mut inserted_count = 0_i64;
|
||||||
@ -782,8 +799,7 @@ impl SessionStore {
|
|||||||
)
|
)
|
||||||
.optional()?;
|
.optional()?;
|
||||||
|
|
||||||
let (id, created_at) = existing
|
let (id, created_at) = existing.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
|
||||||
.unwrap_or_else(|| (uuid::Uuid::new_v4().to_string(), now));
|
|
||||||
|
|
||||||
tx.execute(
|
tx.execute(
|
||||||
"
|
"
|
||||||
@ -881,7 +897,10 @@ impl SessionStore {
|
|||||||
LIMIT ?4
|
LIMIT ?4
|
||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
let rows = stmt.query_map(
|
||||||
|
params![scope_kind, scope_key, namespace, limit],
|
||||||
|
map_memory_record,
|
||||||
|
)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
memories.push(row?);
|
memories.push(row?);
|
||||||
}
|
}
|
||||||
@ -940,7 +959,9 @@ impl SessionStore {
|
|||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| row.get::<_, String>(0))?;
|
let rows = stmt.query_map(params![scope_kind, since_timestamp], |row| {
|
||||||
|
row.get::<_, String>(0)
|
||||||
|
})?;
|
||||||
let mut scope_keys = Vec::new();
|
let mut scope_keys = Vec::new();
|
||||||
for row in rows {
|
for row in rows {
|
||||||
scope_keys.push(row?);
|
scope_keys.push(row?);
|
||||||
@ -1010,7 +1031,10 @@ impl SessionStore {
|
|||||||
Ok(changed > 0)
|
Ok(changed > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn upsert_scheduler_job(&self, input: &SchedulerJobUpsert) -> Result<SchedulerJobRecord, StorageError> {
|
pub fn upsert_scheduler_job(
|
||||||
|
&self,
|
||||||
|
input: &SchedulerJobUpsert,
|
||||||
|
) -> Result<SchedulerJobRecord, StorageError> {
|
||||||
let now = current_timestamp();
|
let now = current_timestamp();
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@ -1067,7 +1091,10 @@ impl SessionStore {
|
|||||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_scheduler_job(&self, job_id: &str) -> Result<Option<SchedulerJobRecord>, StorageError> {
|
pub fn get_scheduler_job(
|
||||||
|
&self,
|
||||||
|
job_id: &str,
|
||||||
|
) -> Result<Option<SchedulerJobRecord>, StorageError> {
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"
|
"
|
||||||
@ -1085,7 +1112,10 @@ impl SessionStore {
|
|||||||
.map_err(StorageError::from)
|
.map_err(StorageError::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn list_scheduler_jobs(&self, enabled_only: bool) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
pub fn list_scheduler_jobs(
|
||||||
|
&self,
|
||||||
|
enabled_only: bool,
|
||||||
|
) -> Result<Vec<SchedulerJobRecord>, StorageError> {
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
let sql = if enabled_only {
|
let sql = if enabled_only {
|
||||||
"
|
"
|
||||||
@ -1195,7 +1225,10 @@ impl SessionStore {
|
|||||||
LIMIT ?5
|
LIMIT ?5
|
||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
let rows = stmt.query_map(
|
||||||
|
params![query, scope_kind, scope_key, namespace, limit],
|
||||||
|
map_memory_record,
|
||||||
|
)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
memories.push(row?);
|
memories.push(row?);
|
||||||
}
|
}
|
||||||
@ -1214,7 +1247,10 @@ impl SessionStore {
|
|||||||
LIMIT ?4
|
LIMIT ?4
|
||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
|
let rows = stmt.query_map(
|
||||||
|
params![query, scope_kind, scope_key, limit],
|
||||||
|
map_memory_record,
|
||||||
|
)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
memories.push(row?);
|
memories.push(row?);
|
||||||
}
|
}
|
||||||
@ -1256,7 +1292,10 @@ impl SessionStore {
|
|||||||
LIMIT ?5
|
LIMIT ?5
|
||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, namespace, limit], map_memory_record)?;
|
let rows = stmt.query_map(
|
||||||
|
params![query, scope_kind, scope_key, namespace, limit],
|
||||||
|
map_memory_record,
|
||||||
|
)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
memories.push(row?);
|
memories.push(row?);
|
||||||
}
|
}
|
||||||
@ -1275,7 +1314,10 @@ impl SessionStore {
|
|||||||
LIMIT ?4
|
LIMIT ?4
|
||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![query, scope_kind, scope_key, limit], map_memory_record)?;
|
let rows = stmt.query_map(
|
||||||
|
params![query, scope_kind, scope_key, limit],
|
||||||
|
map_memory_record,
|
||||||
|
)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
memories.push(row?);
|
memories.push(row?);
|
||||||
}
|
}
|
||||||
@ -1347,11 +1389,7 @@ fn map_session_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRecord
|
|||||||
fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEventRecord> {
|
fn map_skill_event_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<SkillEventRecord> {
|
||||||
let payload_json: String = row.get(4)?;
|
let payload_json: String = row.get(4)?;
|
||||||
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(err))
|
||||||
4,
|
|
||||||
rusqlite::types::Type::Text,
|
|
||||||
Box::new(err),
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(SkillEventRecord {
|
Ok(SkillEventRecord {
|
||||||
@ -1391,25 +1429,13 @@ fn map_scheduler_job_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<Schedul
|
|||||||
let last_status: Option<String> = row.get(9)?;
|
let last_status: Option<String> = row.get(9)?;
|
||||||
|
|
||||||
let schedule = serde_json::from_str(&schedule_json).map_err(|err| {
|
let schedule = serde_json::from_str(&schedule_json).map_err(|err| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
rusqlite::Error::FromSqlConversionFailure(2, rusqlite::types::Type::Text, Box::new(err))
|
||||||
2,
|
|
||||||
rusqlite::types::Type::Text,
|
|
||||||
Box::new(err),
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
let target = serde_json::from_str(&target_json).map_err(|err| {
|
let target = serde_json::from_str(&target_json).map_err(|err| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(err))
|
||||||
5,
|
|
||||||
rusqlite::types::Type::Text,
|
|
||||||
Box::new(err),
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
let payload = serde_json::from_str(&payload_json).map_err(|err| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(err))
|
||||||
6,
|
|
||||||
rusqlite::types::Type::Text,
|
|
||||||
Box::new(err),
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(SchedulerJobRecord {
|
Ok(SchedulerJobRecord {
|
||||||
@ -1472,7 +1498,10 @@ fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !has_column(conn, "messages", "reasoning_content")? {
|
if !has_column(conn, "messages", "reasoning_content")? {
|
||||||
add_column_if_missing(conn, "ALTER TABLE messages ADD COLUMN reasoning_content TEXT")?;
|
add_column_if_missing(
|
||||||
|
conn,
|
||||||
|
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1494,17 +1523,11 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !has_column(conn, "scheduler_jobs", "last_status")? {
|
if !has_column(conn, "scheduler_jobs", "last_status")? {
|
||||||
conn.execute(
|
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT", [])?;
|
||||||
"ALTER TABLE scheduler_jobs ADD COLUMN last_status TEXT",
|
|
||||||
[],
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has_column(conn, "scheduler_jobs", "last_error")? {
|
if !has_column(conn, "scheduler_jobs", "last_error")? {
|
||||||
conn.execute(
|
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT", [])?;
|
||||||
"ALTER TABLE scheduler_jobs ADD COLUMN last_error TEXT",
|
|
||||||
[],
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has_column(conn, "scheduler_jobs", "run_count")? {
|
if !has_column(conn, "scheduler_jobs", "run_count")? {
|
||||||
@ -1515,10 +1538,7 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !has_column(conn, "scheduler_jobs", "max_runs")? {
|
if !has_column(conn, "scheduler_jobs", "max_runs")? {
|
||||||
conn.execute(
|
conn.execute("ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER", [])?;
|
||||||
"ALTER TABLE scheduler_jobs ADD COLUMN max_runs INTEGER",
|
|
||||||
[],
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has_column(conn, "scheduler_jobs", "paused_at")? {
|
if !has_column(conn, "scheduler_jobs", "paused_at")? {
|
||||||
@ -1538,7 +1558,11 @@ fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn has_column(conn: &Connection, table_name: &str, column_name: &str) -> Result<bool, StorageError> {
|
fn has_column(
|
||||||
|
conn: &Connection,
|
||||||
|
table_name: &str,
|
||||||
|
column_name: &str,
|
||||||
|
) -> Result<bool, StorageError> {
|
||||||
let pragma = format!("PRAGMA table_info({})", table_name);
|
let pragma = format!("PRAGMA table_info({})", table_name);
|
||||||
let mut stmt = conn.prepare(&pragma)?;
|
let mut stmt = conn.prepare(&pragma)?;
|
||||||
let mut rows = stmt.query([])?;
|
let mut rows = stmt.query([])?;
|
||||||
@ -1557,7 +1581,10 @@ fn add_column_if_missing(conn: &Connection, sql: &str) -> Result<(), StorageErro
|
|||||||
match conn.execute(sql, []) {
|
match conn.execute(sql, []) {
|
||||||
Ok(_) => Ok(()),
|
Ok(_) => Ok(()),
|
||||||
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
|
Err(rusqlite::Error::SqliteFailure(_, Some(message)))
|
||||||
if message.contains("duplicate column name") => Ok(()),
|
if message.contains("duplicate column name") =>
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Err(error) => Err(StorageError::Database(error)),
|
Err(error) => Err(StorageError::Database(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1581,7 +1608,11 @@ fn insert_message_with_seq(
|
|||||||
message: &ChatMessage,
|
message: &ChatMessage,
|
||||||
) -> Result<(), StorageError> {
|
) -> Result<(), StorageError> {
|
||||||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||||
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
let tool_calls_json = message
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map(serde_json::to_string)
|
||||||
|
.transpose()?;
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"
|
"
|
||||||
INSERT INTO messages (
|
INSERT INTO messages (
|
||||||
@ -1638,43 +1669,47 @@ fn load_messages_between(
|
|||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let rows = stmt.query_map(params![session_id, start_seq_exclusive, end_seq_inclusive], |row| {
|
let rows = stmt.query_map(
|
||||||
let media_refs_json: String = row.get(5)?;
|
params![session_id, start_seq_exclusive, end_seq_inclusive],
|
||||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
|row| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
let media_refs_json: String = row.get(5)?;
|
||||||
media_refs_json.len(),
|
let media_refs: Vec<String> =
|
||||||
rusqlite::types::Type::Text,
|
serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||||
Box::new(err),
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
)
|
media_refs_json.len(),
|
||||||
})?;
|
rusqlite::types::Type::Text,
|
||||||
|
Box::new(err),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let tool_calls_json: Option<String> = row.get(9)?;
|
let tool_calls_json: Option<String> = row.get(9)?;
|
||||||
let tool_calls = tool_calls_json
|
let tool_calls = tool_calls_json
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.map(serde_json::from_str)
|
.map(serde_json::from_str)
|
||||||
.transpose()
|
.transpose()
|
||||||
.map_err(|err| {
|
.map_err(|err| {
|
||||||
rusqlite::Error::FromSqlConversionFailure(
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
9,
|
9,
|
||||||
rusqlite::types::Type::Text,
|
rusqlite::types::Type::Text,
|
||||||
Box::new(err),
|
Box::new(err),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(ChatMessage {
|
Ok(ChatMessage {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
role: row.get(1)?,
|
role: row.get(1)?,
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
system_context: row.get(3)?,
|
system_context: row.get(3)?,
|
||||||
reasoning_content: row.get(4)?,
|
reasoning_content: row.get(4)?,
|
||||||
media_refs,
|
media_refs,
|
||||||
timestamp: row.get(6)?,
|
timestamp: row.get(6)?,
|
||||||
tool_call_id: row.get(7)?,
|
tool_call_id: row.get(7)?,
|
||||||
tool_name: row.get(8)?,
|
tool_name: row.get(8)?,
|
||||||
tool_state: None,
|
tool_state: None,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
})
|
})
|
||||||
})?;
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
for row in rows {
|
for row in rows {
|
||||||
@ -1866,7 +1901,10 @@ mod tests {
|
|||||||
assert_eq!(messages[0].role, "assistant");
|
assert_eq!(messages[0].role, "assistant");
|
||||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
|
||||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
assert_eq!(
|
||||||
|
messages[0].tool_calls.as_ref().unwrap()[0].name,
|
||||||
|
"calculator"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1874,17 +1912,17 @@ mod tests {
|
|||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
let session = store.create_cli_session(Some("reasoning")).unwrap();
|
let session = store.create_cli_session(Some("reasoning")).unwrap();
|
||||||
|
|
||||||
let assistant = ChatMessage::assistant_with_reasoning(
|
let assistant = ChatMessage::assistant_with_reasoning("final answer", "hidden reasoning");
|
||||||
"final answer",
|
|
||||||
"hidden reasoning",
|
|
||||||
);
|
|
||||||
|
|
||||||
store.append_message(&session.id, &assistant).unwrap();
|
store.append_message(&session.id, &assistant).unwrap();
|
||||||
|
|
||||||
let messages = store.load_messages(&session.id).unwrap();
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
assert_eq!(messages.len(), 1);
|
assert_eq!(messages.len(), 1);
|
||||||
assert_eq!(messages[0].content, "final answer");
|
assert_eq!(messages[0].content, "final answer");
|
||||||
assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning"));
|
assert_eq!(
|
||||||
|
messages[0].reasoning_content.as_deref(),
|
||||||
|
Some("hidden reasoning")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1892,8 +1930,12 @@ mod tests {
|
|||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
let session = store.create_cli_session(Some("reset")).unwrap();
|
let session = store.create_cli_session(Some("reset")).unwrap();
|
||||||
|
|
||||||
store.append_message(&session.id, &ChatMessage::user("before")).unwrap();
|
store
|
||||||
store.append_message(&session.id, &ChatMessage::assistant("context")).unwrap();
|
.append_message(&session.id, &ChatMessage::user("before"))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::assistant("context"))
|
||||||
|
.unwrap();
|
||||||
store.reset_session(&session.id).unwrap();
|
store.reset_session(&session.id).unwrap();
|
||||||
|
|
||||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||||
@ -1909,7 +1951,9 @@ mod tests {
|
|||||||
assert_eq!(all_messages[0].content, "before");
|
assert_eq!(all_messages[0].content, "before");
|
||||||
assert_eq!(all_messages[1].content, "context");
|
assert_eq!(all_messages[1].content, "context");
|
||||||
|
|
||||||
store.append_message(&session.id, &ChatMessage::user("after")).unwrap();
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::user("after"))
|
||||||
|
.unwrap();
|
||||||
let active_messages = store.load_messages(&session.id).unwrap();
|
let active_messages = store.load_messages(&session.id).unwrap();
|
||||||
assert_eq!(active_messages.len(), 1);
|
assert_eq!(active_messages.len(), 1);
|
||||||
assert_eq!(active_messages[0].content, "after");
|
assert_eq!(active_messages[0].content, "after");
|
||||||
@ -2010,19 +2054,33 @@ mod tests {
|
|||||||
let store = SessionStore::in_memory().unwrap();
|
let store = SessionStore::in_memory().unwrap();
|
||||||
let session = store.create_cli_session(Some("count-users")).unwrap();
|
let session = store.create_cli_session(Some("count-users")).unwrap();
|
||||||
|
|
||||||
store.append_message(&session.id, &ChatMessage::system("agent")).unwrap();
|
store
|
||||||
store.append_message(&session.id, &ChatMessage::user("u1")).unwrap();
|
.append_message(&session.id, &ChatMessage::system("agent"))
|
||||||
store.append_message(&session.id, &ChatMessage::assistant("a1")).unwrap();
|
.unwrap();
|
||||||
store.append_message(&session.id, &ChatMessage::user("u2")).unwrap();
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::user("u1"))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::assistant("a1"))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::user("u2"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||||
|
|
||||||
store.reset_session(&session.id).unwrap();
|
store.reset_session(&session.id).unwrap();
|
||||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
|
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 0);
|
||||||
|
|
||||||
store.append_message(&session.id, &ChatMessage::system("agent-again")).unwrap();
|
store
|
||||||
store.append_message(&session.id, &ChatMessage::user("u3")).unwrap();
|
.append_message(&session.id, &ChatMessage::system("agent-again"))
|
||||||
store.append_message(&session.id, &ChatMessage::user("u4")).unwrap();
|
.unwrap();
|
||||||
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::user("u3"))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::user("u4"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
assert_eq!(store.count_active_user_messages(&session.id).unwrap(), 2);
|
||||||
}
|
}
|
||||||
@ -2052,12 +2110,20 @@ mod tests {
|
|||||||
store.append_message(&session.id, message).unwrap();
|
store.append_message(&session.id, message).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let snapshot_end_seq = store.get_session(&session.id).unwrap().unwrap().message_count;
|
let snapshot_end_seq = store
|
||||||
|
.get_session(&session.id)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap()
|
||||||
|
.message_count;
|
||||||
let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec();
|
let preserved_messages = store.load_messages(&session.id).unwrap()[3..].to_vec();
|
||||||
let preserved_system_messages = vec![agent_prompt];
|
let preserved_system_messages = vec![agent_prompt];
|
||||||
|
|
||||||
store.append_message(&session.id, &ChatMessage::user("u5")).unwrap();
|
store
|
||||||
store.append_message(&session.id, &ChatMessage::assistant("a5")).unwrap();
|
.append_message(&session.id, &ChatMessage::user("u5"))
|
||||||
|
.unwrap();
|
||||||
|
store
|
||||||
|
.append_message(&session.id, &ChatMessage::assistant("a5"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let summary_message = ChatMessage::system("[Compressed History]\n\nsummary");
|
let summary_message = ChatMessage::system("[Compressed History]\n\nsummary");
|
||||||
let compacted = store
|
let compacted = store
|
||||||
@ -2074,16 +2140,22 @@ mod tests {
|
|||||||
assert!(compacted);
|
assert!(compacted);
|
||||||
|
|
||||||
let active_messages = store.load_messages(&session.id).unwrap();
|
let active_messages = store.load_messages(&session.id).unwrap();
|
||||||
assert_eq!(active_messages.len(), 10);
|
assert_eq!(active_messages.len(), 10);
|
||||||
assert_eq!(active_messages[0].role, "system");
|
assert_eq!(active_messages[0].role, "system");
|
||||||
assert_eq!(active_messages[0].content, "agent");
|
assert_eq!(active_messages[0].content, "agent");
|
||||||
assert_eq!(active_messages[0].system_context.as_deref(), Some(SYSTEM_CONTEXT_AGENT_PROMPT));
|
assert_eq!(
|
||||||
assert_eq!(active_messages[1].role, "system");
|
active_messages[0].system_context.as_deref(),
|
||||||
assert_eq!(active_messages[1].content, "[Compressed History]\n\nsummary");
|
Some(SYSTEM_CONTEXT_AGENT_PROMPT)
|
||||||
assert_eq!(active_messages[2].content, "u2");
|
);
|
||||||
assert_eq!(active_messages[3].content, "a2");
|
assert_eq!(active_messages[1].role, "system");
|
||||||
assert_eq!(active_messages[8].content, "u5");
|
assert_eq!(
|
||||||
assert_eq!(active_messages[9].content, "a5");
|
active_messages[1].content,
|
||||||
|
"[Compressed History]\n\nsummary"
|
||||||
|
);
|
||||||
|
assert_eq!(active_messages[2].content, "u2");
|
||||||
|
assert_eq!(active_messages[3].content, "a2");
|
||||||
|
assert_eq!(active_messages[8].content, "u5");
|
||||||
|
assert_eq!(active_messages[9].content, "a5");
|
||||||
|
|
||||||
let stored = store.get_session(&session.id).unwrap().unwrap();
|
let stored = store.get_session(&session.id).unwrap().unwrap();
|
||||||
assert_eq!(stored.reset_cutoff_seq, 11);
|
assert_eq!(stored.reset_cutoff_seq, 11);
|
||||||
@ -2128,12 +2200,7 @@ mod tests {
|
|||||||
let session = store.create_cli_session(Some("skill-events")).unwrap();
|
let session = store.create_cli_session(Some("skill-events")).unwrap();
|
||||||
|
|
||||||
store
|
store
|
||||||
.append_skill_event(
|
.append_skill_event(None, "discovered", None, &serde_json::json!({"count": 2}))
|
||||||
None,
|
|
||||||
"discovered",
|
|
||||||
None,
|
|
||||||
&serde_json::json!({"count": 2}),
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
store
|
store
|
||||||
.append_skill_event(
|
.append_skill_event(
|
||||||
@ -2383,13 +2450,26 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let scope_keys = store.list_memory_scope_keys("user").unwrap();
|
let scope_keys = store.list_memory_scope_keys("user").unwrap();
|
||||||
assert_eq!(scope_keys, vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]);
|
assert_eq!(
|
||||||
|
scope_keys,
|
||||||
|
vec!["feishu:user-1".to_string(), "feishu:user-2".to_string()]
|
||||||
|
);
|
||||||
|
|
||||||
let full_scope = store.list_memories_for_scope("user", "feishu:user-1").unwrap();
|
let full_scope = store
|
||||||
|
.list_memories_for_scope("user", "feishu:user-1")
|
||||||
|
.unwrap();
|
||||||
assert_eq!(full_scope.len(), 2);
|
assert_eq!(full_scope.len(), 2);
|
||||||
assert!(full_scope.iter().all(|memory| memory.scope_key == "feishu:user-1"));
|
assert!(
|
||||||
|
full_scope
|
||||||
|
.iter()
|
||||||
|
.all(|memory| memory.scope_key == "feishu:user-1")
|
||||||
|
);
|
||||||
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
|
assert!(full_scope.iter().any(|memory| memory.memory_key == "work"));
|
||||||
assert!(full_scope.iter().any(|memory| memory.memory_key == "workflow"));
|
assert!(
|
||||||
|
full_scope
|
||||||
|
.iter()
|
||||||
|
.any(|memory| memory.memory_key == "workflow")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -15,7 +15,8 @@ use crate::tools::traits::{Tool, ToolResult};
|
|||||||
const MAX_TIMEOUT_SECS: u64 = 600;
|
const MAX_TIMEOUT_SECS: u64 = 600;
|
||||||
const MAX_OUTPUT_CHARS: usize = 50_000;
|
const MAX_OUTPUT_CHARS: usize = 50_000;
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||||
const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
const USER_ACTION_HINT: &str =
|
||||||
|
"该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
|
||||||
|
|
||||||
pub struct BashTool {
|
pub struct BashTool {
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
@ -208,7 +209,10 @@ impl Tool for BashTool {
|
|||||||
.map(|d| Path::new(d))
|
.map(|d| Path::new(d))
|
||||||
.unwrap_or_else(|| Path::new("."));
|
.unwrap_or_else(|| Path::new("."));
|
||||||
|
|
||||||
match self.run_command(command, cwd, timeout_secs, interactive).await {
|
match self
|
||||||
|
.run_command(command, cwd, timeout_secs, interactive)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(output) => Ok(ToolResult {
|
Ok(output) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output,
|
output,
|
||||||
@ -366,10 +370,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_pwd_command() {
|
async fn test_pwd_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "command": "pwd" })).await.unwrap();
|
||||||
.execute(json!({ "command": "pwd" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
@ -377,7 +378,10 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_ls_command() {
|
async fn test_ls_command() {
|
||||||
let tool = BashTool::new();
|
let tool = BashTool::new();
|
||||||
let result = tool.execute(json!({ "command": "ls -la /tmp" })).await.unwrap();
|
let result = tool
|
||||||
|
.execute(json!({ "command": "ls -la /tmp" }))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -659,10 +659,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_evaluate_missing_expression() {
|
async fn test_evaluate_missing_expression() {
|
||||||
let tool = CalculatorTool::new();
|
let tool = CalculatorTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({"function": "evaluate"})).await.unwrap();
|
||||||
.execute(json!({"function": "evaluate"}))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -268,8 +268,8 @@ impl Tool for FileEditTool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::NamedTempFile;
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_edit_simple() {
|
async fn test_edit_simple() {
|
||||||
|
|||||||
@ -218,7 +218,7 @@ impl Tool for FileReadTool {
|
|||||||
// Try to read as binary and encode as base64
|
// Try to read as binary and encode as base64
|
||||||
match std::fs::read(&resolved) {
|
match std::fs::read(&resolved) {
|
||||||
Ok(bytes) => {
|
Ok(bytes) => {
|
||||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
use base64::{Engine, engine::general_purpose::STANDARD};
|
||||||
let encoded = STANDARD.encode(&bytes);
|
let encoded = STANDARD.encode(&bytes);
|
||||||
let mime = mime_guess::from_path(&resolved)
|
let mime = mime_guess::from_path(&resolved)
|
||||||
.first_or_octet_stream()
|
.first_or_octet_stream()
|
||||||
@ -248,8 +248,8 @@ impl Tool for FileReadTool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tempfile::NamedTempFile;
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_read_simple_file() {
|
async fn test_read_simple_file() {
|
||||||
@ -308,10 +308,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_is_directory() {
|
async fn test_is_directory() {
|
||||||
let tool = FileReadTool::new();
|
let tool = FileReadTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "path": "." })).await.unwrap();
|
||||||
.execute(json!({ "path": "." }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("Not a file"));
|
assert!(result.error.unwrap().contains("Not a file"));
|
||||||
|
|||||||
@ -195,10 +195,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_write_missing_path() {
|
async fn test_write_missing_path() {
|
||||||
let tool = FileWriteTool::new();
|
let tool = FileWriteTool::new();
|
||||||
let result = tool
|
let result = tool.execute(json!({ "content": "Hello" })).await.unwrap();
|
||||||
.execute(json!({ "content": "Hello" }))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.unwrap().contains("path"));
|
assert!(result.error.unwrap().contains("path"));
|
||||||
|
|||||||
@ -50,10 +50,7 @@ impl HttpRequestTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
if !host_matches_allowlist(&host, &self.allowed_domains) {
|
||||||
return Err(format!(
|
return Err(format!("Host '{}' is not in allowed_domains", host));
|
||||||
"Host '{}' is not in allowed_domains",
|
|
||||||
host
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(url.to_string())
|
Ok(url.to_string())
|
||||||
@ -80,9 +77,7 @@ impl HttpRequestTool {
|
|||||||
for (key, value) in obj {
|
for (key, value) in obj {
|
||||||
if let Some(str_val) = value.as_str() {
|
if let Some(str_val) = value.as_str() {
|
||||||
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
|
if let Ok(name) = reqwest::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||||
if let Ok(val) =
|
if let Ok(val) = reqwest::header::HeaderValue::from_str(str_val) {
|
||||||
reqwest::header::HeaderValue::from_str(str_val)
|
|
||||||
{
|
|
||||||
header_map.insert(name, val);
|
header_map.insert(name, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,7 +188,9 @@ fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
|
|||||||
|
|
||||||
allowed_domains.iter().any(|domain| {
|
allowed_domains.iter().any(|domain| {
|
||||||
host == domain
|
host == domain
|
||||||
|| host.strip_suffix(domain).is_some_and(|prefix| prefix.ends_with('.'))
|
|| host
|
||||||
|
.strip_suffix(domain)
|
||||||
|
.is_some_and(|prefix| prefix.ends_with('.'))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,7 +201,11 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check .local TLD
|
// Check .local TLD
|
||||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
if host
|
||||||
|
.rsplit('.')
|
||||||
|
.next()
|
||||||
|
.is_some_and(|label| label == "local")
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,9 +227,7 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
|||||||
|| v4.is_broadcast()
|
|| v4.is_broadcast()
|
||||||
|| v4.is_multicast()
|
|| v4.is_multicast()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => {
|
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
||||||
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,10 +279,7 @@ impl Tool for HttpRequestTool {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let method_str = args
|
let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
|
||||||
.get("method")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.unwrap_or("GET");
|
|
||||||
|
|
||||||
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
let headers_val = args.get("headers").cloned().unwrap_or(json!({}));
|
||||||
let body = args.get("body").and_then(|v| v.as_str());
|
let body = args.get("body").and_then(|v| v.as_str());
|
||||||
|
|||||||
@ -94,7 +94,7 @@ impl Tool for MemoryManageTool {
|
|||||||
return Ok(error_result(&format!(
|
return Ok(error_result(&format!(
|
||||||
"memory '{}.{}' not found",
|
"memory '{}.{}' not found",
|
||||||
input.namespace, input.memory_key
|
input.namespace, input.memory_key
|
||||||
)))
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,9 +108,14 @@ impl Tool for MemoryManageTool {
|
|||||||
None => return Ok(error_result("Missing required parameter: key")),
|
None => return Ok(error_result("Missing required parameter: key")),
|
||||||
};
|
};
|
||||||
|
|
||||||
let deleted = self.store.delete_memory("user", &scope_key, namespace, key)?;
|
let deleted = self
|
||||||
|
.store
|
||||||
|
.delete_memory("user", &scope_key, namespace, key)?;
|
||||||
if !deleted {
|
if !deleted {
|
||||||
return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key)));
|
return Ok(error_result(&format!(
|
||||||
|
"memory '{}.{}' not found",
|
||||||
|
namespace, key
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
json!({
|
json!({
|
||||||
|
|||||||
@ -90,7 +90,9 @@ impl Tool for MemorySearchTool {
|
|||||||
.get("limit")
|
.get("limit")
|
||||||
.and_then(|value| value.as_u64())
|
.and_then(|value| value.as_u64())
|
||||||
.unwrap_or(10) as usize;
|
.unwrap_or(10) as usize;
|
||||||
let memories = self.store.list_memories("user", &scope_key, namespace, limit)?;
|
let memories = self
|
||||||
|
.store
|
||||||
|
.list_memories("user", &scope_key, namespace, limit)?;
|
||||||
json!({
|
json!({
|
||||||
"count": memories.len(),
|
"count": memories.len(),
|
||||||
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
|
||||||
@ -135,7 +137,12 @@ impl Tool for MemorySearchTool {
|
|||||||
|
|
||||||
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
match self.store.get_memory("user", &scope_key, namespace, key)? {
|
||||||
Some(memory) => memory_to_json(memory),
|
Some(memory) => memory_to_json(memory),
|
||||||
None => return Ok(error_result(&format!("memory '{}.{}' not found", namespace, key))),
|
None => {
|
||||||
|
return Ok(error_result(&format!(
|
||||||
|
"memory '{}.{}' not found",
|
||||||
|
namespace, key
|
||||||
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => return Ok(error_result("Unsupported action")),
|
_ => return Ok(error_result("Unsupported action")),
|
||||||
|
|||||||
@ -5,9 +5,7 @@ use async_trait::async_trait;
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
use crate::config::SchedulerSchedule;
|
use crate::config::SchedulerSchedule;
|
||||||
use crate::storage::{
|
use crate::storage::{SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore};
|
||||||
SchedulerJobRecord, SchedulerJobState, SchedulerJobUpsert, SessionStore,
|
|
||||||
};
|
|
||||||
use crate::tools::traits::{Tool, ToolResult};
|
use crate::tools::traits::{Tool, ToolResult};
|
||||||
|
|
||||||
pub struct SchedulerManageTool {
|
pub struct SchedulerManageTool {
|
||||||
@ -35,11 +33,7 @@ impl Tool for SchedulerManageTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
let mut allowed_agents = self
|
let mut allowed_agents = self.known_agents.iter().cloned().collect::<Vec<_>>();
|
||||||
.known_agents
|
|
||||||
.iter()
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
allowed_agents.sort();
|
allowed_agents.sort();
|
||||||
let agent_hint = if allowed_agents.is_empty() {
|
let agent_hint = if allowed_agents.is_empty() {
|
||||||
"agent_task payload.agent may be omitted or set to 'default'.".to_string()
|
"agent_task payload.agent may be omitted or set to 'default'.".to_string()
|
||||||
@ -225,8 +219,15 @@ fn build_upsert(
|
|||||||
startup_delay_secs,
|
startup_delay_secs,
|
||||||
target,
|
target,
|
||||||
payload,
|
payload,
|
||||||
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
|
enabled: args
|
||||||
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
|
.get("enabled")
|
||||||
|
.and_then(|value| value.as_bool())
|
||||||
|
.unwrap_or(true),
|
||||||
|
state: if args
|
||||||
|
.get("enabled")
|
||||||
|
.and_then(|value| value.as_bool())
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
SchedulerJobState::Scheduled
|
SchedulerJobState::Scheduled
|
||||||
} else {
|
} else {
|
||||||
SchedulerJobState::Paused
|
SchedulerJobState::Paused
|
||||||
@ -252,14 +253,28 @@ fn enrich_target_from_context(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !has_non_empty_string(&object, "channel") {
|
if !has_non_empty_string(&object, "channel") {
|
||||||
if let Some(channel_name) = context.channel_name.as_ref().filter(|value| !value.trim().is_empty()) {
|
if let Some(channel_name) = context
|
||||||
object.insert("channel".to_string(), serde_json::Value::String(channel_name.clone()));
|
.channel_name
|
||||||
|
.as_ref()
|
||||||
|
.filter(|value| !value.trim().is_empty())
|
||||||
|
{
|
||||||
|
object.insert(
|
||||||
|
"channel".to_string(),
|
||||||
|
serde_json::Value::String(channel_name.clone()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has_non_empty_string(&object, "chat_id") {
|
if !has_non_empty_string(&object, "chat_id") {
|
||||||
if let Some(chat_id) = context.chat_id.as_ref().filter(|value| !value.trim().is_empty()) {
|
if let Some(chat_id) = context
|
||||||
object.insert("chat_id".to_string(), serde_json::Value::String(chat_id.clone()));
|
.chat_id
|
||||||
|
.as_ref()
|
||||||
|
.filter(|value| !value.trim().is_empty())
|
||||||
|
{
|
||||||
|
object.insert(
|
||||||
|
"chat_id".to_string(),
|
||||||
|
serde_json::Value::String(chat_id.clone()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,7 +289,10 @@ fn has_non_empty_string(object: &serde_json::Map<String, serde_json::Value>, fie
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<()> {
|
fn validate_agent_task_payload(
|
||||||
|
payload: &serde_json::Value,
|
||||||
|
known_agents: &HashSet<String>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
|
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
|
||||||
anyhow::bail!("agent_task payload.prompt is required and must be a string")
|
anyhow::bail!("agent_task payload.prompt is required and must be a string")
|
||||||
};
|
};
|
||||||
@ -299,7 +317,8 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> St
|
|||||||
configured_agents.sort();
|
configured_agents.sort();
|
||||||
|
|
||||||
let configured_hint = if configured_agents.is_empty() {
|
let configured_hint = if configured_agents.is_empty() {
|
||||||
"No named agents are configured; use payload.agent='default' or omit payload.agent.".to_string()
|
"No named agents are configured; use payload.agent='default' or omit payload.agent."
|
||||||
|
.to_string()
|
||||||
} else {
|
} else {
|
||||||
format!(
|
format!(
|
||||||
"payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.",
|
"payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.",
|
||||||
@ -309,9 +328,7 @@ fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> St
|
|||||||
|
|
||||||
format!(
|
format!(
|
||||||
"Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.",
|
"Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.",
|
||||||
agent_name,
|
agent_name, configured_hint, agent_name,
|
||||||
configured_hint,
|
|
||||||
agent_name,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -517,7 +534,10 @@ mod tests {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(saved.kind, "silent_agent_task");
|
assert_eq!(saved.kind, "silent_agent_task");
|
||||||
assert_eq!(saved.target["session_chat_id"], "scheduler/agent.daily_summary.background");
|
assert_eq!(
|
||||||
|
saved.target["session_chat_id"],
|
||||||
|
"scheduler/agent.daily_summary.background"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -654,7 +674,9 @@ mod tests {
|
|||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result.error.as_deref(),
|
result.error.as_deref(),
|
||||||
Some("Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}")
|
Some(
|
||||||
|
"Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}"
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -668,7 +690,9 @@ mod tests {
|
|||||||
.as_str()
|
.as_str()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(payload_description.contains("avoid repeating schedule phrases or execution times"));
|
assert!(
|
||||||
|
payload_description.contains("avoid repeating schedule phrases or execution times")
|
||||||
|
);
|
||||||
assert!(payload_description.contains("每天9点"));
|
assert!(payload_description.contains("每天9点"));
|
||||||
assert!(payload_description.contains("每小时"));
|
assert!(payload_description.contains("每小时"));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -408,7 +408,10 @@ impl SchemaCleanr {
|
|||||||
|
|
||||||
match non_null.len() {
|
match non_null.len() {
|
||||||
0 => Value::String("null".to_string()),
|
0 => Value::String("null".to_string()),
|
||||||
1 => non_null.into_iter().next().unwrap_or(Value::String("null".to_string())),
|
1 => non_null
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.unwrap_or(Value::String("null".to_string())),
|
||||||
_ => Value::Array(non_null),
|
_ => Value::Array(non_null),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -83,7 +83,11 @@ impl Tool for SkillManageTool {
|
|||||||
let scope = match args.get("scope").and_then(|v| v.as_str()) {
|
let scope = match args.get("scope").and_then(|v| v.as_str()) {
|
||||||
Some(value) => match SkillScope::parse(value) {
|
Some(value) => match SkillScope::parse(value) {
|
||||||
Some(scope) => scope,
|
Some(scope) => scope,
|
||||||
None => return Ok(error_result("scope must be 'project' or 'user'; .agents sources are discovery-only")),
|
None => {
|
||||||
|
return Ok(error_result(
|
||||||
|
"scope must be 'project' or 'user'; .agents sources are discovery-only",
|
||||||
|
));
|
||||||
|
}
|
||||||
},
|
},
|
||||||
None => SkillScope::Project,
|
None => SkillScope::Project,
|
||||||
};
|
};
|
||||||
@ -91,9 +95,7 @@ impl Tool for SkillManageTool {
|
|||||||
let name = args.get("name").and_then(|v| v.as_str());
|
let name = args.get("name").and_then(|v| v.as_str());
|
||||||
|
|
||||||
let result = match action {
|
let result = match action {
|
||||||
"list" => {
|
"list" => list_skills_payload(&self.skills),
|
||||||
list_skills_payload(&self.skills)
|
|
||||||
}
|
|
||||||
"get" => {
|
"get" => {
|
||||||
let name = match name {
|
let name = match name {
|
||||||
Some(name) => name,
|
Some(name) => name,
|
||||||
@ -127,7 +129,10 @@ impl Tool for SkillManageTool {
|
|||||||
};
|
};
|
||||||
let body = args.get("body").and_then(|v| v.as_str()).unwrap_or("");
|
let body = args.get("body").and_then(|v| v.as_str()).unwrap_or("");
|
||||||
|
|
||||||
match self.skills.create_skill(scope, name, description, body, reload) {
|
match self
|
||||||
|
.skills
|
||||||
|
.create_skill(scope, name, description, body, reload)
|
||||||
|
{
|
||||||
Ok(skill) => json!({
|
Ok(skill) => json!({
|
||||||
"status": "created",
|
"status": "created",
|
||||||
"name": skill.name,
|
"name": skill.name,
|
||||||
@ -149,7 +154,10 @@ impl Tool for SkillManageTool {
|
|||||||
return Ok(error_result("update requires description or body"));
|
return Ok(error_result("update requires description or body"));
|
||||||
}
|
}
|
||||||
|
|
||||||
match self.skills.update_skill(scope, name, description, body, reload) {
|
match self
|
||||||
|
.skills
|
||||||
|
.update_skill(scope, name, description, body, reload)
|
||||||
|
{
|
||||||
Ok(skill) => json!({
|
Ok(skill) => json!({
|
||||||
"status": "updated",
|
"status": "updated",
|
||||||
"name": skill.name,
|
"name": skill.name,
|
||||||
|
|||||||
@ -99,9 +99,7 @@ fn execute_time_request(
|
|||||||
.and_then(Value::as_str)
|
.and_then(Value::as_str)
|
||||||
.unwrap_or(default_timezone);
|
.unwrap_or(default_timezone);
|
||||||
let timezone = timezone_name.parse::<chrono_tz::Tz>().map_err(|_| {
|
let timezone = timezone_name.parse::<chrono_tz::Tz>().map_err(|_| {
|
||||||
format!(
|
format!("Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai")
|
||||||
"Invalid timezone: {timezone_name}. Expected an IANA timezone like Asia/Shanghai"
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let now_local = now_utc.with_timezone(&timezone);
|
let now_local = now_utc.with_timezone(&timezone);
|
||||||
@ -168,13 +166,14 @@ fn parse_offset_request(args: &Value) -> Result<Option<OffsetRequest>, String> {
|
|||||||
let direction = direction.ok_or_else(|| {
|
let direction = direction.ok_or_else(|| {
|
||||||
"Missing required parameter: direction when requesting a relative time".to_string()
|
"Missing required parameter: direction when requesting a relative time".to_string()
|
||||||
})?;
|
})?;
|
||||||
let amount = amount
|
let amount = amount.and_then(Value::as_u64).ok_or_else(|| {
|
||||||
.and_then(Value::as_u64)
|
"Missing required parameter: amount when requesting a relative time".to_string()
|
||||||
.ok_or_else(|| "Missing required parameter: amount when requesting a relative time".to_string())?;
|
})?;
|
||||||
let amount = u32::try_from(amount)
|
let amount = u32::try_from(amount)
|
||||||
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
|
.map_err(|_| "amount is too large; expected a 32-bit unsigned integer".to_string())?;
|
||||||
let unit = unit
|
let unit = unit.ok_or_else(|| {
|
||||||
.ok_or_else(|| "Missing required parameter: unit when requesting a relative time".to_string())?;
|
"Missing required parameter: unit when requesting a relative time".to_string()
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(Some(OffsetRequest {
|
Ok(Some(OffsetRequest {
|
||||||
direction: OffsetDirection::parse(direction)?,
|
direction: OffsetDirection::parse(direction)?,
|
||||||
@ -188,10 +187,18 @@ fn apply_offset(
|
|||||||
offset: &OffsetRequest,
|
offset: &OffsetRequest,
|
||||||
) -> Result<DateTime<chrono_tz::Tz>, String> {
|
) -> Result<DateTime<chrono_tz::Tz>, String> {
|
||||||
match (offset.direction, offset.unit) {
|
match (offset.direction, offset.unit) {
|
||||||
(OffsetDirection::Future, TimeUnit::Minute) => Ok(now_local + Duration::minutes(i64::from(offset.amount))),
|
(OffsetDirection::Future, TimeUnit::Minute) => {
|
||||||
(OffsetDirection::Past, TimeUnit::Minute) => Ok(now_local - Duration::minutes(i64::from(offset.amount))),
|
Ok(now_local + Duration::minutes(i64::from(offset.amount)))
|
||||||
(OffsetDirection::Future, TimeUnit::Hour) => Ok(now_local + Duration::hours(i64::from(offset.amount))),
|
}
|
||||||
(OffsetDirection::Past, TimeUnit::Hour) => Ok(now_local - Duration::hours(i64::from(offset.amount))),
|
(OffsetDirection::Past, TimeUnit::Minute) => {
|
||||||
|
Ok(now_local - Duration::minutes(i64::from(offset.amount)))
|
||||||
|
}
|
||||||
|
(OffsetDirection::Future, TimeUnit::Hour) => {
|
||||||
|
Ok(now_local + Duration::hours(i64::from(offset.amount)))
|
||||||
|
}
|
||||||
|
(OffsetDirection::Past, TimeUnit::Hour) => {
|
||||||
|
Ok(now_local - Duration::hours(i64::from(offset.amount)))
|
||||||
|
}
|
||||||
(OffsetDirection::Future, TimeUnit::Day) => now_local
|
(OffsetDirection::Future, TimeUnit::Day) => now_local
|
||||||
.checked_add_days(Days::new(u64::from(offset.amount)))
|
.checked_add_days(Days::new(u64::from(offset.amount)))
|
||||||
.ok_or_else(|| "Failed to add days to the current time".to_string()),
|
.ok_or_else(|| "Failed to add days to the current time".to_string()),
|
||||||
@ -439,8 +446,8 @@ mod tests {
|
|||||||
|
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
let payload: Value = serde_json::from_str(&result.output).unwrap();
|
let payload: Value = serde_json::from_str(&result.output).unwrap();
|
||||||
let result_time = chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap())
|
let result_time =
|
||||||
.unwrap();
|
chrono::DateTime::parse_from_rfc3339(payload["result_time"].as_str().unwrap()).unwrap();
|
||||||
assert_eq!(result_time.hour(), 12);
|
assert_eq!(result_time.hour(), 12);
|
||||||
assert_eq!(result_time.minute(), 30);
|
assert_eq!(result_time.minute(), 30);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -239,7 +239,11 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if host.rsplit('.').next().is_some_and(|label| label == "local") {
|
if host
|
||||||
|
.rsplit('.')
|
||||||
|
.next()
|
||||||
|
.is_some_and(|label| label == "local")
|
||||||
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,7 +252,9 @@ fn is_private_host(host: &str) -> bool {
|
|||||||
std::net::IpAddr::V4(v4) => {
|
std::net::IpAddr::V4(v4) => {
|
||||||
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_unspecified()
|
||||||
}
|
}
|
||||||
std::net::IpAddr::V6(v6) => v6.is_loopback() || v6.is_unspecified() || v6.is_multicast(),
|
std::net::IpAddr::V6(v6) => {
|
||||||
|
v6.is_loopback() || v6.is_unspecified() || v6.is_multicast()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message};
|
|
||||||
use picobot::config::{Config, LLMProviderConfig};
|
use picobot::config::{Config, LLMProviderConfig};
|
||||||
|
use picobot::providers::{ChatCompletionRequest, Message, create_provider};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
fn load_config() -> Option<LLMProviderConfig> {
|
fn load_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -23,11 +23,10 @@ fn load_config() -> Option<LLMProviderConfig> {
|
|||||||
model_id: openai_model,
|
model_id: openai_model,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
max_tool_iterations: 20,
|
||||||
token_limit: 128_000,
|
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
context_summary_max_chars: 20_000,
|
|
||||||
context_tool_result_trim_chars: 20_000,
|
context_tool_result_trim_chars: 20_000,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -44,8 +43,7 @@ fn create_request(content: &str) -> ChatCompletionRequest {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_simple_completion() {
|
async fn test_openai_simple_completion() {
|
||||||
let config = load_config()
|
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
let response = provider.chat(create_request("Say 'ok'")).await.unwrap();
|
||||||
@ -59,8 +57,7 @@ async fn test_openai_simple_completion() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_conversation() {
|
async fn test_openai_conversation() {
|
||||||
let config = load_config()
|
let config = load_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -84,7 +81,9 @@ async fn test_openai_conversation() {
|
|||||||
async fn test_config_load() {
|
async fn test_config_load() {
|
||||||
// Test that config.json can be loaded and provider config created
|
// Test that config.json can be loaded and provider config created
|
||||||
let config = Config::load("config.json").expect("Failed to load config.json");
|
let config = Config::load("config.json").expect("Failed to load config.json");
|
||||||
let provider_config = config.get_provider_config("default").expect("Failed to get provider config");
|
let provider_config = config
|
||||||
|
.get_provider_config("default")
|
||||||
|
.expect("Failed to get provider config");
|
||||||
|
|
||||||
assert_eq!(provider_config.provider_type, "openai");
|
assert_eq!(provider_config.provider_type, "openai");
|
||||||
assert_eq!(provider_config.name, "aliyun");
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
use picobot::providers::{ChatCompletionRequest, Message};
|
|
||||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||||
|
use picobot::providers::{ChatCompletionRequest, Message};
|
||||||
|
|
||||||
/// Test that message with special characters is properly escaped
|
/// Test that message with special characters is properly escaped
|
||||||
#[test]
|
#[test]
|
||||||
@ -19,7 +19,9 @@ fn test_message_special_characters() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_multiline_system_prompt() {
|
fn test_multiline_system_prompt() {
|
||||||
let messages = vec![
|
let messages = vec![
|
||||||
Message::system("You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate"),
|
Message::system(
|
||||||
|
"You are a helpful assistant.\n\nFollow these rules:\n1. Be kind\n2. Be accurate",
|
||||||
|
),
|
||||||
Message::user("Hi"),
|
Message::user("Hi"),
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -33,10 +35,7 @@ fn test_multiline_system_prompt() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_chat_request_serialization() {
|
fn test_chat_request_serialization() {
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
messages: vec![
|
messages: vec![Message::system("You are helpful"), Message::user("Hello")],
|
||||||
Message::system("You are helpful"),
|
|
||||||
Message::user("Hello"),
|
|
||||||
],
|
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
tools: None,
|
tools: None,
|
||||||
@ -136,7 +135,12 @@ fn test_tool_call_outbound_serialization() {
|
|||||||
|
|
||||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||||
match decoded {
|
match decoded {
|
||||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
|
WsOutbound::ToolCall {
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
assert_eq!(tool_call_id, "call-1");
|
assert_eq!(tool_call_id, "call-1");
|
||||||
assert_eq!(tool_name, "calculator");
|
assert_eq!(tool_name, "calculator");
|
||||||
assert_eq!(arguments["expression"], "1 + 1");
|
assert_eq!(arguments["expression"], "1 + 1");
|
||||||
@ -161,7 +165,12 @@ fn test_tool_result_outbound_serialization() {
|
|||||||
|
|
||||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||||
match decoded {
|
match decoded {
|
||||||
WsOutbound::ToolResult { tool_call_id, tool_name, content, .. } => {
|
WsOutbound::ToolResult {
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
assert_eq!(tool_call_id, "call-1");
|
assert_eq!(tool_call_id, "call-1");
|
||||||
assert_eq!(tool_name, "calculator");
|
assert_eq!(tool_name, "calculator");
|
||||||
assert!(content.contains('2'));
|
assert!(content.contains('2'));
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use picobot::providers::{create_provider, ChatCompletionRequest, Message, Tool, ToolFunction};
|
|
||||||
use picobot::config::LLMProviderConfig;
|
use picobot::config::LLMProviderConfig;
|
||||||
|
use picobot::providers::{ChatCompletionRequest, Message, Tool, ToolFunction, create_provider};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
fn load_openai_config() -> Option<LLMProviderConfig> {
|
fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||||
dotenv::from_filename("tests/test.env").ok()?;
|
dotenv::from_filename("tests/test.env").ok()?;
|
||||||
@ -23,11 +23,10 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
|||||||
model_id: openai_model,
|
model_id: openai_model,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
|
context_window_tokens: None,
|
||||||
model_extra: HashMap::new(),
|
model_extra: HashMap::new(),
|
||||||
max_tool_iterations: 20,
|
max_tool_iterations: 20,
|
||||||
token_limit: 128_000,
|
|
||||||
tool_result_max_chars: 20_000,
|
tool_result_max_chars: 20_000,
|
||||||
context_summary_max_chars: 20_000,
|
|
||||||
context_tool_result_trim_chars: 20_000,
|
context_tool_result_trim_chars: 20_000,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -55,8 +54,7 @@ fn make_weather_tool() -> Tool {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_tool_call() {
|
async fn test_openai_tool_call() {
|
||||||
let config = load_openai_config()
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -70,7 +68,11 @@ async fn test_openai_tool_call() {
|
|||||||
let response = provider.chat(request).await.unwrap();
|
let response = provider.chat(request).await.unwrap();
|
||||||
|
|
||||||
// Should have tool calls
|
// Should have tool calls
|
||||||
assert!(!response.tool_calls.is_empty(), "Expected tool call, got: {}", response.content);
|
assert!(
|
||||||
|
!response.tool_calls.is_empty(),
|
||||||
|
"Expected tool call, got: {}",
|
||||||
|
response.content
|
||||||
|
);
|
||||||
|
|
||||||
let tool_call = &response.tool_calls[0];
|
let tool_call = &response.tool_calls[0];
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
@ -80,8 +82,7 @@ async fn test_openai_tool_call() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_tool_call_with_manual_execution() {
|
async fn test_openai_tool_call_with_manual_execution() {
|
||||||
let config = load_openai_config()
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
@ -94,8 +95,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let response1 = provider.chat(request1).await.unwrap();
|
let response1 = provider.chat(request1).await.unwrap();
|
||||||
let tool_call = response1.tool_calls.first()
|
let tool_call = response1.tool_calls.first().expect("Expected tool call");
|
||||||
.expect("Expected tool call");
|
|
||||||
assert_eq!(tool_call.name, "get_weather");
|
assert_eq!(tool_call.name, "get_weather");
|
||||||
|
|
||||||
// Second request with tool result
|
// Second request with tool result
|
||||||
@ -118,8 +118,7 @@ async fn test_openai_tool_call_with_manual_execution() {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
async fn test_openai_no_tool_when_not_provided() {
|
async fn test_openai_no_tool_when_not_provided() {
|
||||||
let config = load_openai_config()
|
let config = load_openai_config().expect("Please configure tests/test.env with valid API keys");
|
||||||
.expect("Please configure tests/test.env with valid API keys");
|
|
||||||
|
|
||||||
let provider = create_provider(config).expect("Failed to create provider");
|
let provider = create_provider(config).expect("Failed to create provider");
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user