feat: add llm_timeout_secs to provider configuration and implement timeout handling
- Introduced llm_timeout_secs in ProviderConfig and LLMProviderConfig to specify timeout for LLM requests. - Updated OpenAIProvider and AnthropicProvider to utilize the timeout setting when creating HTTP clients. - Enhanced error handling for API responses to include timeout information. - Modified SessionManager to support agent-specific provider configurations, allowing for more flexible agent management. - Added tests to verify the correct behavior of timeout settings and agent task validation.
This commit is contained in:
parent
1ffdcab585
commit
f3f369b329
@ -126,6 +126,7 @@ Config example:
|
||||
},
|
||||
"payload": {
|
||||
"prompt": "请总结今天的项目进展,并列出明天的优先事项",
|
||||
"agent": "default",
|
||||
"fresh_session": true,
|
||||
"system_prompt": "你是日报助手,输出时先给摘要,再给待办。",
|
||||
"sender_id": "scheduler-daily-summary",
|
||||
@ -161,6 +162,7 @@ Runtime management:
|
||||
- agent_task reuses the normal agent pipeline: it creates a synthetic user turn from payload.prompt and runs tools, persistence, and outbound rendering through SessionManager.
|
||||
- agent_task payload fields:
|
||||
- prompt: required, synthetic user input.
|
||||
- agent: optional, choose which configured agent definition to use. default or any configured agent name.
|
||||
- fresh_session: optional, when true reset the active chat segment before running.
|
||||
- system_prompt: optional, append a task-specific system message before the synthetic user turn.
|
||||
- sender_id: optional, overrides the synthetic sender id used for tool context and memory scoping.
|
||||
|
||||
@ -185,6 +185,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
||||
- 内置 `internal_event` 当前包含 `session_cleanup`,用于回收超时的内存 session 缓存。
|
||||
- `agent_task` 会把 `payload.prompt` 作为一次合成用户输入,交给 `SessionManager::run_scheduled_agent_task()` 执行,因此会复用持久化历史、工具调用和渠道下发链路。
|
||||
- `payload.fresh_session = true` 时,会先对目标 chat 执行一次逻辑 reset,再开始本次任务运行。
|
||||
- `payload.agent` 可指定本次任务使用哪一个已配置 agent;未指定时仍使用 `default`。
|
||||
- `payload.system_prompt` 会作为额外 system 消息写入本次任务上下文。
|
||||
- `payload.sender_id` 会覆盖默认的 `scheduler` 发送者标识。
|
||||
- `payload.metadata` 会映射到 outbound metadata,便于渠道侧做追踪或特殊处理。
|
||||
|
||||
@ -24,6 +24,7 @@ const TRUNCATION_SUFFIX_LEN: usize = 200;
|
||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "在绝大多数请求开始时,你都应先使用长期记忆检索工具 memory_search 来召回相关上下文,然后再决定如何回答或是否需要写入记忆。默认流程是:先用 memory_search(action='search');只有在你已经明确知道 namespace 和 key 时才改用 get;只有在需要浏览最近几条记忆时才用 list。即使用户没有明确提到“记忆”或“偏好”,只要请求可能与用户长期偏好、稳定事实、历史决策、持续任务或项目上下文有关,就应先搜记忆。仅以下少数情况可跳过记忆搜索:纯寒暄、一次性简单计算、完全不依赖用户历史的直接事实问答。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。如果你决定跳过记忆搜索,应先确认当前请求确实属于上述少数例外,而不是因为你忘了检索。";
|
||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||
|
||||
/// Build content blocks from text and media paths
|
||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||
@ -99,6 +100,23 @@ fn parse_pending_tool_output(output: &str) -> Option<String> {
|
||||
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
||||
}
|
||||
|
||||
fn is_recoverable_llm_error(error: &str) -> bool {
|
||||
let normalized = error.to_ascii_lowercase();
|
||||
normalized.contains("504")
|
||||
|| normalized.contains("gateway timeout")
|
||||
|| normalized.contains("stream timeout")
|
||||
|| normalized.contains("timed out")
|
||||
|| normalized.contains("timeout")
|
||||
}
|
||||
|
||||
fn recoverable_llm_message(error: &str) -> String {
|
||||
if is_recoverable_llm_error(error) {
|
||||
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
|
||||
} else {
|
||||
format!("模型请求失败:{}", error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Loop detection result.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum LoopDetectionResult {
|
||||
@ -386,11 +404,18 @@ impl AgentLoop {
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = (*self.provider).chat(request).await
|
||||
.map_err(|e| {
|
||||
let response = match (*self.provider).chat(request).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "LLM request failed");
|
||||
AgentError::LlmError(e.to_string())
|
||||
})?;
|
||||
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
return Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
emitted_messages,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
@ -539,11 +564,8 @@ impl AgentLoop {
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
// Fallback if summary call fails
|
||||
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||
let final_message = ChatMessage::assistant(
|
||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||
);
|
||||
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||
emitted_messages.push(final_message.clone());
|
||||
Ok(AgentProcessResult {
|
||||
final_response: final_message,
|
||||
|
||||
@ -41,6 +41,7 @@ impl Default for ContextCompressionConfig {
|
||||
}
|
||||
|
||||
/// Context compressor that reduces message history when it exceeds token limits.
|
||||
#[derive(Clone)]
|
||||
pub struct ContextCompressor {
|
||||
config: ContextCompressionConfig,
|
||||
context_window: usize,
|
||||
|
||||
@ -101,6 +101,8 @@ pub struct ProviderConfig {
|
||||
pub api_key: String,
|
||||
#[serde(default)]
|
||||
pub extra_headers: HashMap<String, String>,
|
||||
#[serde(default = "default_llm_timeout_secs")]
|
||||
pub llm_timeout_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@ -132,6 +134,10 @@ fn default_token_limit() -> usize {
|
||||
128_000
|
||||
}
|
||||
|
||||
fn default_llm_timeout_secs() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GatewayConfig {
|
||||
#[serde(default = "default_gateway_host")]
|
||||
@ -400,6 +406,7 @@ pub struct LLMProviderConfig {
|
||||
pub base_url: String,
|
||||
pub api_key: String,
|
||||
pub extra_headers: HashMap<String, String>,
|
||||
pub llm_timeout_secs: u64,
|
||||
pub model_id: String,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
@ -461,6 +468,7 @@ impl Config {
|
||||
base_url: provider.base_url.clone(),
|
||||
api_key: provider.api_key.clone(),
|
||||
extra_headers: provider.extra_headers.clone(),
|
||||
llm_timeout_secs: provider.llm_timeout_secs,
|
||||
model_id: model.model_id.clone(),
|
||||
temperature: model.temperature,
|
||||
max_tokens: model.max_tokens,
|
||||
@ -601,6 +609,43 @@ mod tests {
|
||||
assert_eq!(provider_config.name, "aliyun");
|
||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||
assert_eq!(provider_config.temperature, Some(0.0));
|
||||
assert_eq!(provider_config.llm_timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_config_loads_custom_llm_timeout() {
|
||||
let file = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
file.path(),
|
||||
r#"{
|
||||
"providers": {
|
||||
"aliyun": {
|
||||
"type": "openai",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"api_key": "test-key",
|
||||
"extra_headers": {},
|
||||
"llm_timeout_secs": 400
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"qwen-plus": {
|
||||
"model_id": "qwen-plus"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"default": {
|
||||
"provider": "aliyun",
|
||||
"model": "qwen-plus"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||
let provider_config = config.get_provider_config("default").unwrap();
|
||||
|
||||
assert_eq!(provider_config.llm_timeout_secs, 400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -2,6 +2,7 @@ pub mod http;
|
||||
pub mod session;
|
||||
pub mod ws;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
@ -9,6 +10,7 @@ use tokio::net::TcpListener;
|
||||
use crate::bus::{MessageBus, OutboundDispatcher};
|
||||
use crate::channels::ChannelManager;
|
||||
use crate::config::Config;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::logging;
|
||||
use crate::scheduler::Scheduler;
|
||||
use crate::skills::SkillRuntime;
|
||||
@ -27,6 +29,10 @@ impl GatewayState {
|
||||
|
||||
// Get provider config for SessionManager
|
||||
let provider_config = config.get_provider_config("default")?;
|
||||
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
|
||||
for agent_name in config.agents.keys() {
|
||||
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
||||
}
|
||||
|
||||
// Session TTL from config (default 4 hours)
|
||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||
@ -40,6 +46,7 @@ impl GatewayState {
|
||||
agent_prompt_reinject_every,
|
||||
show_tool_results,
|
||||
provider_config,
|
||||
provider_configs,
|
||||
skills,
|
||||
)?;
|
||||
let channel_manager = ChannelManager::new();
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
@ -49,6 +49,7 @@ pub struct ScheduledAgentTaskOptions {
|
||||
pub fresh_session: bool,
|
||||
pub system_prompt: Option<String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
pub agent: Option<String>,
|
||||
}
|
||||
|
||||
impl BusToolCallEmitter {
|
||||
@ -244,6 +245,23 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
||||
self.get_history(chat_id)
|
||||
.and_then(|history| {
|
||||
history
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|message| message.role == "user")
|
||||
.map(|message| message.id.as_str())
|
||||
})
|
||||
}
|
||||
|
||||
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
||||
self.latest_user_message_id(chat_id)
|
||||
.map(|current_id| current_id == message_id)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// 清除所有历史
|
||||
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||||
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||||
@ -296,10 +314,20 @@ impl Session {
|
||||
chat_id: &str,
|
||||
sender_id: Option<&str>,
|
||||
message_id: Option<&str>,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
self.create_agent_with_provider_config(chat_id, sender_id, message_id, self.provider_config.clone())
|
||||
}
|
||||
|
||||
pub fn create_agent_with_provider_config(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
sender_id: Option<&str>,
|
||||
message_id: Option<&str>,
|
||||
provider_config: LLMProviderConfig,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
let session_id = self.persistent_session_id(chat_id);
|
||||
AgentLoop::with_tools_and_skills(
|
||||
self.provider_config.clone(),
|
||||
provider_config,
|
||||
self.tools.clone(),
|
||||
self.skills.clone(),
|
||||
)
|
||||
@ -368,6 +396,7 @@ fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
|
||||
pub struct SessionManager {
|
||||
inner: Arc<Mutex<SessionManagerInner>>,
|
||||
provider_config: LLMProviderConfig,
|
||||
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
store: Arc<SessionStore>,
|
||||
@ -381,7 +410,7 @@ struct SessionManagerInner {
|
||||
session_ttl: Duration,
|
||||
}
|
||||
|
||||
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolRegistry {
|
||||
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>, known_agents: HashSet<String>) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(FileReadTool::new());
|
||||
@ -389,7 +418,7 @@ fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolReg
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(MemorySearchTool::new(store.clone()));
|
||||
registry.register(MemoryManageTool::new(store.clone()));
|
||||
registry.register(SchedulerManageTool::new(store));
|
||||
registry.register(SchedulerManageTool::new(store, known_agents));
|
||||
registry.register(SkillListTool::new(skills.clone()));
|
||||
registry.register(SkillManageTool::new(skills));
|
||||
registry.register(BashTool::new());
|
||||
@ -435,12 +464,14 @@ impl SessionManager {
|
||||
agent_prompt_reinject_every: u64,
|
||||
show_tool_results: bool,
|
||||
provider_config: LLMProviderConfig,
|
||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let store = Arc::new(
|
||||
SessionStore::new()
|
||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||
);
|
||||
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||
|
||||
if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) {
|
||||
tracing::warn!(error = %err, "Failed to record skill discovery event");
|
||||
@ -453,7 +484,8 @@ impl SessionManager {
|
||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||||
})),
|
||||
provider_config,
|
||||
tools: Arc::new(default_tools(skills.clone(), store.clone())),
|
||||
provider_configs: Arc::new(provider_configs),
|
||||
tools: Arc::new(default_tools(skills.clone(), store.clone(), known_agents)),
|
||||
skills,
|
||||
store,
|
||||
agent_prompt_reinject_every,
|
||||
@ -477,6 +509,10 @@ impl SessionManager {
|
||||
self.skills.clone()
|
||||
}
|
||||
|
||||
pub fn provider_config_for_agent(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
|
||||
select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
|
||||
}
|
||||
|
||||
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
|
||||
self.store
|
||||
.create_cli_session(title)
|
||||
@ -640,7 +676,7 @@ impl SessionManager {
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
|
||||
// 处理消息
|
||||
let response = {
|
||||
let (history, compressor, provider_config, agent, user_message_id) = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
session_guard.ensure_persistent_session(chat_id)?;
|
||||
@ -669,11 +705,8 @@ impl SessionManager {
|
||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
|
||||
// 压缩历史(如果需要)
|
||||
let history = session_guard.compressor
|
||||
.compress_if_needed(history, &session_guard.provider_config)
|
||||
.await?;
|
||||
let compressor = session_guard.compressor().clone();
|
||||
let provider_config = session_guard.provider_config().clone();
|
||||
|
||||
session_guard.record_skill_offer(chat_id)?;
|
||||
|
||||
@ -682,8 +715,27 @@ impl SessionManager {
|
||||
if let Some(handler) = live_emitter.clone() {
|
||||
agent = agent.with_emitted_message_handler(handler);
|
||||
}
|
||||
|
||||
(history, compressor, provider_config, agent, user_message_id)
|
||||
};
|
||||
|
||||
let history = compressor
|
||||
.compress_if_needed(history, &provider_config)
|
||||
.await?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
let response = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
if !session_guard.is_latest_user_message(chat_id, &user_message_id) {
|
||||
tracing::warn!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
user_message_id = %user_message_id,
|
||||
"Skipping stale agent result because a newer user message is already present"
|
||||
);
|
||||
Vec::new()
|
||||
} else {
|
||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
|
||||
@ -704,6 +756,7 @@ impl SessionManager {
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@ -736,8 +789,9 @@ impl SessionManager {
|
||||
.sender_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "scheduler".to_string());
|
||||
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
|
||||
|
||||
let response = {
|
||||
let (history, compressor, agent, user_message_id) = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
session_guard.ensure_persistent_session(chat_id)?;
|
||||
@ -758,15 +812,37 @@ impl SessionManager {
|
||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
let history = session_guard
|
||||
.compressor
|
||||
.compress_if_needed(history, &session_guard.provider_config)
|
||||
.await?;
|
||||
let compressor = session_guard.compressor().clone();
|
||||
|
||||
session_guard.record_skill_offer(chat_id)?;
|
||||
|
||||
let agent = session_guard.create_agent(chat_id, Some(&sender_id), Some(&user_message_id))?;
|
||||
let agent = session_guard.create_agent_with_provider_config(
|
||||
chat_id,
|
||||
Some(&sender_id),
|
||||
Some(&user_message_id),
|
||||
provider_config.clone(),
|
||||
)?;
|
||||
|
||||
(history, compressor, agent, user_message_id)
|
||||
};
|
||||
|
||||
let history = compressor
|
||||
.compress_if_needed(history, &provider_config)
|
||||
.await?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
let response = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
if !session_guard.is_latest_user_message(chat_id, &user_message_id) {
|
||||
tracing::warn!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
user_message_id = %user_message_id,
|
||||
"Skipping stale scheduled agent result because a newer user message is already present"
|
||||
);
|
||||
Vec::new()
|
||||
} else {
|
||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
|
||||
result
|
||||
@ -783,6 +859,7 @@ impl SessionManager {
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
@ -810,11 +887,29 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage
|
||||
)
|
||||
}
|
||||
|
||||
fn select_provider_config(
|
||||
default_provider_config: &LLMProviderConfig,
|
||||
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||
agent_name: Option<&str>,
|
||||
) -> Result<LLMProviderConfig, AgentError> {
|
||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
None | Some("default") => Ok(default_provider_config.clone()),
|
||||
Some(agent_name) => provider_configs
|
||||
.get(agent_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::http::StatusCode;
|
||||
use super::*;
|
||||
use axum::{Json, Router, routing::post};
|
||||
use crate::bus::MessageBus;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
fn test_provider_config() -> LLMProviderConfig {
|
||||
@ -824,6 +919,7 @@ mod tests {
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
@ -833,6 +929,235 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: name.to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
token_limit: 4096,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_uses_named_agent_override() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::from([(
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]);
|
||||
|
||||
let selected = select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||
assert_eq!(selected.name, "planner-provider");
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store,
|
||||
100,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
|
||||
let first = session.create_user_message("first", Vec::new());
|
||||
let first_id = first.id.clone();
|
||||
session.append_persisted_message("chat-1", first).unwrap();
|
||||
assert!(session.is_latest_user_message("chat-1", &first_id));
|
||||
|
||||
let second = session.create_user_message("second", Vec::new());
|
||||
let second_id = second.id.clone();
|
||||
session.append_persisted_message("chat-1", second).unwrap();
|
||||
|
||||
assert!(!session.is_latest_user_message("chat-1", &first_id));
|
||||
assert!(session.is_latest_user_message("chat-1", &second_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::new();
|
||||
|
||||
let selected = select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
assert_eq!(selected.model_id, "default-model");
|
||||
}
|
||||
|
||||
async fn start_mock_openai_server() -> String {
|
||||
async fn handle(Json(body): Json<Value>) -> Json<Value> {
|
||||
let model = body
|
||||
.get("model")
|
||||
.and_then(|value| value.as_str())
|
||||
.unwrap_or("unknown-model");
|
||||
|
||||
Json(json!({
|
||||
"id": "mock-response",
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": format!("reply from {}", model),
|
||||
"tool_calls": []
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
let app = Router::new().route("/chat/completions", post(handle));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let address = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
format!("http://{}", address)
|
||||
}
|
||||
|
||||
async fn start_mock_openai_504_server() -> String {
|
||||
async fn handle() -> (StatusCode, &'static str) {
|
||||
(StatusCode::GATEWAY_TIMEOUT, "stream timeout")
|
||||
}
|
||||
|
||||
let app = Router::new().route("/chat/completions", post(handle));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let address = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
format!("http://{}", address)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_message_returns_recoverable_reply_on_llm_504() {
|
||||
let base_url = start_mock_openai_504_server().await;
|
||||
let provider_config = LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: "timeout-provider".to_string(),
|
||||
base_url: base_url.clone(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
model_id: "timeout-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
token_limit: 4096,
|
||||
llm_timeout_secs: 30,
|
||||
};
|
||||
|
||||
let session_manager = SessionManager::new(
|
||||
4,
|
||||
100,
|
||||
false,
|
||||
provider_config.clone(),
|
||||
HashMap::from([("default".to_string(), provider_config)]),
|
||||
Arc::new(SkillRuntime::default()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let outbound = session_manager
|
||||
.handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() {
|
||||
let base_url = start_mock_openai_server().await;
|
||||
let default_provider = LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: "default-provider".to_string(),
|
||||
base_url: base_url.clone(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
model_id: "default-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
token_limit: 4096,
|
||||
llm_timeout_secs: 30,
|
||||
};
|
||||
let planner_provider = LLMProviderConfig {
|
||||
model_id: "planner-model".to_string(),
|
||||
name: "planner-provider".to_string(),
|
||||
..default_provider.clone()
|
||||
};
|
||||
|
||||
let session_manager = SessionManager::new(
|
||||
4,
|
||||
100,
|
||||
false,
|
||||
default_provider.clone(),
|
||||
HashMap::from([
|
||||
("default".to_string(), default_provider),
|
||||
("planner".to_string(), planner_provider),
|
||||
]),
|
||||
Arc::new(SkillRuntime::default()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let planner_outbound = session_manager
|
||||
.run_scheduled_agent_task(
|
||||
"feishu",
|
||||
"chat-planner",
|
||||
"请规划今天工作",
|
||||
ScheduledAgentTaskOptions {
|
||||
agent: Some("planner".to_string()),
|
||||
fresh_session: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(planner_outbound.len(), 1);
|
||||
assert!(planner_outbound[0].content.contains("planner-model"));
|
||||
|
||||
let default_outbound = session_manager
|
||||
.run_scheduled_agent_task(
|
||||
"feishu",
|
||||
"chat-default",
|
||||
"请规划今天工作",
|
||||
ScheduledAgentTaskOptions {
|
||||
agent: Some("default".to_string()),
|
||||
fresh_session: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(default_outbound.len(), 1);
|
||||
assert!(default_outbound[0].content.contains("default-model"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
||||
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
||||
@ -881,7 +1206,7 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -929,7 +1254,7 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -956,7 +1281,7 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -1001,7 +1326,7 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
@ -1035,7 +1360,7 @@ mod tests {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
||||
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||
let mut session = Session::new(
|
||||
"feishu".to_string(),
|
||||
test_provider_config(),
|
||||
|
||||
@ -2,6 +2,7 @@ use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
@ -58,6 +59,7 @@ pub struct AnthropicProvider {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
extra_headers: HashMap<String, String>,
|
||||
llm_timeout_secs: u64,
|
||||
model_id: String,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
@ -70,17 +72,24 @@ impl AnthropicProvider {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
extra_headers: HashMap<String, String>,
|
||||
llm_timeout_secs: u64,
|
||||
model_id: String,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
model_extra: HashMap<String, serde_json::Value>,
|
||||
) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(llm_timeout_secs))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new());
|
||||
|
||||
Self {
|
||||
client: Client::new(),
|
||||
client,
|
||||
name,
|
||||
api_key,
|
||||
base_url,
|
||||
extra_headers,
|
||||
llm_timeout_secs,
|
||||
model_id,
|
||||
temperature,
|
||||
max_tokens,
|
||||
@ -190,8 +199,21 @@ impl LLMProvider for AnthropicProvider {
|
||||
}
|
||||
|
||||
let resp = req_builder.json(&body).send().await?;
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
let anthropic_resp: AnthropicResponse = resp.json().await?;
|
||||
if !status.is_success() {
|
||||
return Err(format!("API error {}: {}", status, text).into());
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let resp_preview: String = text.chars().take(100).collect();
|
||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "Anthropic response (first 100 chars shown)");
|
||||
}
|
||||
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&text)
|
||||
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
||||
|
||||
let mut content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
@ -15,6 +15,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>
|
||||
config.api_key,
|
||||
config.base_url,
|
||||
config.extra_headers,
|
||||
config.llm_timeout_secs,
|
||||
config.model_id,
|
||||
config.temperature,
|
||||
config.max_tokens,
|
||||
@ -25,6 +26,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>
|
||||
config.api_key,
|
||||
config.base_url,
|
||||
config.extra_headers,
|
||||
config.llm_timeout_secs,
|
||||
config.model_id,
|
||||
config.temperature,
|
||||
config.max_tokens,
|
||||
|
||||
@ -3,6 +3,7 @@ use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
@ -28,6 +29,7 @@ pub struct OpenAIProvider {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
extra_headers: HashMap<String, String>,
|
||||
llm_timeout_secs: u64,
|
||||
model_id: String,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
@ -40,17 +42,24 @@ impl OpenAIProvider {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
extra_headers: HashMap<String, String>,
|
||||
llm_timeout_secs: u64,
|
||||
model_id: String,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
model_extra: HashMap<String, serde_json::Value>,
|
||||
) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(llm_timeout_secs))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new());
|
||||
|
||||
Self {
|
||||
client: Client::new(),
|
||||
client,
|
||||
name,
|
||||
api_key,
|
||||
base_url,
|
||||
extra_headers,
|
||||
llm_timeout_secs,
|
||||
model_id,
|
||||
temperature,
|
||||
max_tokens,
|
||||
@ -209,7 +218,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let resp_preview: String = text.chars().take(100).collect();
|
||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
|
||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)");
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
@ -275,6 +284,7 @@ mod tests {
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
|
||||
@ -88,6 +88,20 @@ impl Scheduler {
|
||||
continue;
|
||||
};
|
||||
|
||||
if record.next_fire_at.is_none() && job.next_fire_at.is_some() {
|
||||
self.store.update_scheduler_job_runtime(
|
||||
&job.id,
|
||||
job.state.clone(),
|
||||
job.last_status.clone(),
|
||||
job.last_error.as_deref(),
|
||||
job.run_count,
|
||||
job.last_fired_at,
|
||||
job.next_fire_at,
|
||||
job.paused_at,
|
||||
job.completed_at,
|
||||
)?;
|
||||
}
|
||||
|
||||
if !job.is_due(now) {
|
||||
continue;
|
||||
}
|
||||
@ -517,6 +531,11 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
|
||||
.get("system_prompt")
|
||||
.and_then(|value| value.as_str())
|
||||
.map(ToString::to_string);
|
||||
let agent = job
|
||||
.payload
|
||||
.get("agent")
|
||||
.and_then(|value| value.as_str())
|
||||
.map(ToString::to_string);
|
||||
let metadata = parse_metadata_map(job.payload.get("metadata"))?;
|
||||
|
||||
Ok(ScheduledAgentTaskOptions {
|
||||
@ -524,6 +543,7 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
|
||||
fresh_session,
|
||||
system_prompt,
|
||||
metadata,
|
||||
agent,
|
||||
})
|
||||
}
|
||||
|
||||
@ -616,6 +636,7 @@ mod agent_task_tests {
|
||||
},
|
||||
payload: serde_json::json!({
|
||||
"prompt": "请总结今天待办",
|
||||
"agent": "planner",
|
||||
"sender_id": "scheduler-bot",
|
||||
"fresh_session": true,
|
||||
"system_prompt": "你是日报助手",
|
||||
@ -640,6 +661,7 @@ mod agent_task_tests {
|
||||
};
|
||||
|
||||
let options = parse_scheduled_agent_task_options(&job).unwrap();
|
||||
assert_eq!(options.agent.as_deref(), Some("planner"));
|
||||
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
||||
assert!(options.fresh_session);
|
||||
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
||||
@ -660,6 +682,12 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use crate::bus::MessageBus;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::session::SessionManager;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
||||
|
||||
#[test]
|
||||
fn runtime_job_skip_policy_advances_from_now() {
|
||||
@ -739,4 +767,81 @@ mod tests {
|
||||
});
|
||||
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_tick_persists_initial_next_fire_at_for_db_created_jobs() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
store
|
||||
.upsert_scheduler_job(&SchedulerJobUpsert {
|
||||
id: "massage_reminder".to_string(),
|
||||
kind: "outbound_message".to_string(),
|
||||
schedule: serde_json::json!({
|
||||
"type": "interval",
|
||||
"seconds": 60
|
||||
}),
|
||||
interval_secs: 60,
|
||||
startup_delay_secs: 0,
|
||||
target: serde_json::json!({
|
||||
"channel": "feishu",
|
||||
"chat_id": "oc_demo"
|
||||
}),
|
||||
payload: serde_json::json!({
|
||||
"content": "ping"
|
||||
}),
|
||||
enabled: true,
|
||||
state: SchedulerJobState::Scheduled,
|
||||
last_status: None,
|
||||
last_error: None,
|
||||
run_count: 0,
|
||||
max_runs: Some(1),
|
||||
last_fired_at: None,
|
||||
next_fire_at: None,
|
||||
paused_at: None,
|
||||
completed_at: None,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let provider_config = LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: "default".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 30,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
token_limit: 4096,
|
||||
max_tool_iterations: 4,
|
||||
};
|
||||
let session_manager = SessionManager::new(
|
||||
4,
|
||||
100,
|
||||
false,
|
||||
provider_config.clone(),
|
||||
HashMap::from([("default".to_string(), provider_config)]),
|
||||
Arc::new(SkillRuntime::default()),
|
||||
)
|
||||
.unwrap();
|
||||
let scheduler = Scheduler::new(
|
||||
MessageBus::new(8),
|
||||
SchedulerConfig {
|
||||
enabled: true,
|
||||
tick_resolution_ms: 1000,
|
||||
worker_queue_capacity: 64,
|
||||
misfire_policy: SchedulerMisfirePolicy::Skip,
|
||||
jobs: Vec::new(),
|
||||
},
|
||||
store.clone(),
|
||||
session_manager,
|
||||
);
|
||||
|
||||
scheduler.process_tick().await.unwrap();
|
||||
|
||||
let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap();
|
||||
assert!(saved.next_fire_at.is_some());
|
||||
assert_eq!(saved.run_count, 0);
|
||||
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,4 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
@ -11,11 +12,15 @@ use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct SchedulerManageTool {
|
||||
store: Arc<SessionStore>,
|
||||
known_agents: Arc<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl SchedulerManageTool {
|
||||
pub fn new(store: Arc<SessionStore>) -> Self {
|
||||
Self { store }
|
||||
pub fn new(store: Arc<SessionStore>, known_agents: HashSet<String>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
known_agents: Arc::new(known_agents),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,7 +63,7 @@ impl Tool for SchedulerManageTool {
|
||||
},
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"description": "Job payload. agent_task supports prompt, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
|
||||
"description": "Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
|
||||
},
|
||||
"max_runs": {
|
||||
"type": ["integer", "null"]
|
||||
@ -91,7 +96,7 @@ impl Tool for SchedulerManageTool {
|
||||
}
|
||||
}
|
||||
"put" => {
|
||||
let input = build_upsert(&args)?;
|
||||
let input = build_upsert(&args, &self.known_agents)?;
|
||||
let record = self.store.upsert_scheduler_job(&input)?;
|
||||
record_to_json(&record)
|
||||
}
|
||||
@ -140,7 +145,7 @@ impl Tool for SchedulerManageTool {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert> {
|
||||
fn build_upsert(args: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<SchedulerJobUpsert> {
|
||||
let id = require_str(args, "id")?.to_string();
|
||||
let kind = require_str(args, "kind")?.to_string();
|
||||
let schedule_value = args
|
||||
@ -158,14 +163,24 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert>
|
||||
_ => (0, 0),
|
||||
};
|
||||
|
||||
let payload = args.get("payload").cloned().unwrap_or_else(|| json!({}));
|
||||
let target = args.get("target").cloned().unwrap_or_else(|| json!({}));
|
||||
if kind == "agent_task" {
|
||||
validate_agent_task_payload(&payload, known_agents)?;
|
||||
validate_target_fields(&target, &["channel", "chat_id"], "agent_task")?;
|
||||
} else if kind == "outbound_message" {
|
||||
validate_outbound_message_payload(&payload)?;
|
||||
validate_target_fields(&target, &["channel", "chat_id"], "outbound_message")?;
|
||||
}
|
||||
|
||||
Ok(SchedulerJobUpsert {
|
||||
id,
|
||||
kind,
|
||||
schedule: serde_json::to_value(schedule)?,
|
||||
interval_secs,
|
||||
startup_delay_secs,
|
||||
target: args.get("target").cloned().unwrap_or_else(|| json!({})),
|
||||
payload: args.get("payload").cloned().unwrap_or_else(|| json!({})),
|
||||
target,
|
||||
payload,
|
||||
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
|
||||
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
|
||||
SchedulerJobState::Scheduled
|
||||
@ -183,6 +198,57 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert>
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<()> {
|
||||
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
|
||||
anyhow::bail!("agent_task payload.prompt is required and must be a string")
|
||||
};
|
||||
if prompt.trim().is_empty() {
|
||||
anyhow::bail!("agent_task payload.prompt cannot be empty")
|
||||
}
|
||||
|
||||
let Some(agent_name) = payload.get("agent").and_then(|value| value.as_str()) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let normalized = agent_name.trim();
|
||||
if normalized.is_empty() || normalized == "default" || known_agents.contains(normalized) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
anyhow::bail!("Unknown agent '{}' for agent_task payload.agent", normalized)
|
||||
}
|
||||
|
||||
fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> {
|
||||
let Some(content) = payload.get("content").and_then(|value| value.as_str()) else {
|
||||
anyhow::bail!("outbound_message payload.content is required and must be a string")
|
||||
};
|
||||
if content.trim().is_empty() {
|
||||
anyhow::bail!("outbound_message payload.content cannot be empty")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_target_fields(
|
||||
target: &serde_json::Value,
|
||||
required_fields: &[&str],
|
||||
kind: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let object = target
|
||||
.as_object()
|
||||
.ok_or_else(|| anyhow::anyhow!("{} target must be an object", kind))?;
|
||||
|
||||
for field in required_fields {
|
||||
let Some(value) = object.get(*field).and_then(|value| value.as_str()) else {
|
||||
anyhow::bail!("{} target.{} is required and must be a string", kind, field)
|
||||
};
|
||||
if value.trim().is_empty() {
|
||||
anyhow::bail!("{} target.{} cannot be empty", kind, field)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn record_to_json(record: &SchedulerJobRecord) -> serde_json::Value {
|
||||
json!({
|
||||
"id": record.id,
|
||||
@ -255,7 +321,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_scheduler_manage_put_and_get() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = SchedulerManageTool::new(store);
|
||||
let tool = SchedulerManageTool::new(store, HashSet::new());
|
||||
|
||||
let put_result = tool
|
||||
.execute(json!({
|
||||
@ -293,7 +359,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_scheduler_manage_put_agent_task() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = SchedulerManageTool::new(store);
|
||||
let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()]));
|
||||
|
||||
let put_result = tool
|
||||
.execute(json!({
|
||||
@ -309,7 +375,8 @@ mod tests {
|
||||
"chat_id": "oc_demo"
|
||||
},
|
||||
"payload": {
|
||||
"prompt": "请总结今天待办"
|
||||
"prompt": "请总结今天待办",
|
||||
"agent": "planner"
|
||||
}
|
||||
}))
|
||||
.await
|
||||
@ -318,4 +385,59 @@ mod tests {
|
||||
assert!(put_result.success);
|
||||
assert!(put_result.output.contains("agent_task"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scheduler_manage_rejects_outbound_message_without_target() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = SchedulerManageTool::new(store, HashSet::new());
|
||||
|
||||
let put_result = tool
|
||||
.execute(json!({
|
||||
"action": "put",
|
||||
"id": "massage_reminder",
|
||||
"kind": "outbound_message",
|
||||
"schedule": {
|
||||
"type": "interval",
|
||||
"seconds": 60
|
||||
},
|
||||
"payload": {
|
||||
"content": "⏰ 时间到了!该去按摩了!💆"
|
||||
}
|
||||
}))
|
||||
.await;
|
||||
|
||||
assert!(put_result.is_err());
|
||||
let error = put_result.err().unwrap().to_string();
|
||||
assert!(error.contains("outbound_message target.channel is required"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scheduler_manage_rejects_unknown_agent_task_agent() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()]));
|
||||
|
||||
let put_result = tool
|
||||
.execute(json!({
|
||||
"action": "put",
|
||||
"id": "agent.daily_summary",
|
||||
"kind": "agent_task",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 9 * * *"
|
||||
},
|
||||
"target": {
|
||||
"channel": "feishu",
|
||||
"chat_id": "oc_demo"
|
||||
},
|
||||
"payload": {
|
||||
"prompt": "请总结今天待办",
|
||||
"agent": "missing-agent"
|
||||
}
|
||||
}))
|
||||
.await;
|
||||
|
||||
assert!(put_result.is_err());
|
||||
let error = put_result.err().unwrap().to_string();
|
||||
assert!(error.contains("Unknown agent 'missing-agent'"));
|
||||
}
|
||||
}
|
||||
@ -19,6 +19,7 @@ fn load_config() -> Option<LLMProviderConfig> {
|
||||
base_url: openai_base_url,
|
||||
api_key: openai_api_key,
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
|
||||
@ -19,6 +19,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
||||
base_url: openai_base_url,
|
||||
api_key: openai_api_key,
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: openai_model,
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(100),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user