feat: add llm_timeout_secs to provider configuration and implement timeout handling

- Introduced llm_timeout_secs in ProviderConfig and LLMProviderConfig to specify timeout for LLM requests.
- Updated OpenAIProvider and AnthropicProvider to utilize the timeout setting when creating HTTP clients.
- Enhanced error handling for API responses to include timeout information.
- Modified SessionManager to support agent-specific provider configurations, allowing for more flexible agent management.
- Added tests to verify the correct behavior of timeout settings and agent task validation.
This commit is contained in:
ooodc 2026-04-23 09:23:15 +08:00
parent 1ffdcab585
commit f3f369b329
14 changed files with 746 additions and 80 deletions

View File

@ -126,6 +126,7 @@ Config example:
},
"payload": {
"prompt": "请总结今天的项目进展,并列出明天的优先事项",
"agent": "default",
"fresh_session": true,
"system_prompt": "你是日报助手,输出时先给摘要,再给待办。",
"sender_id": "scheduler-daily-summary",
@ -161,6 +162,7 @@ Runtime management:
- agent_task reuses the normal agent pipeline: it creates a synthetic user turn from payload.prompt and runs tools, persistence, and outbound rendering through SessionManager.
- agent_task payload fields:
- prompt: required, synthetic user input.
- agent: optional, selects which configured agent definition runs the task: default or any configured agent name. When omitted, default is used.
- fresh_session: optional, when true reset the active chat segment before running.
- system_prompt: optional, append a task-specific system message before the synthetic user turn.
- sender_id: optional, overrides the synthetic sender id used for tool context and memory scoping.

View File

@ -185,6 +185,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
- 内置 `internal_event` 当前包含 `session_cleanup`,用于回收超时的内存 session 缓存。
- `agent_task` 会把 `payload.prompt` 作为一次合成用户输入,交给 `SessionManager::run_scheduled_agent_task()` 执行,因此会复用持久化历史、工具调用和渠道下发链路。
- `payload.fresh_session = true` 时,会先对目标 chat 执行一次逻辑 reset,再开始本次任务运行。
- `payload.agent` 可指定本次任务使用哪一个已配置 agent;未指定时仍使用 `default`。
- `payload.system_prompt` 会作为额外 system 消息写入本次任务上下文。
- `payload.sender_id` 会覆盖默认的 `scheduler` 发送者标识。
- `payload.metadata` 会映射到 outbound metadata,便于渠道侧做追踪或特殊处理。

View File

@ -24,6 +24,7 @@ const TRUNCATION_SUFFIX_LEN: usize = 200;
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "在绝大多数请求开始时,你都应先使用长期记忆检索工具 memory_search 来召回相关上下文,然后再决定如何回答或是否需要写入记忆。默认流程是:先用 memory_search(action='search');只有在你已经明确知道 namespace 和 key 时才改用 get只有在需要浏览最近几条记忆时才用 list。即使用户没有明确提到“记忆”或“偏好”只要请求可能与用户长期偏好、稳定事实、历史决策、持续任务或项目上下文有关就应先搜记忆。仅以下少数情况可跳过记忆搜索纯寒暄、一次性简单计算、完全不依赖用户历史的直接事实问答。写入或修改记忆时再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespacepreferences、profile、tasks、decisions并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。如果你决定跳过记忆搜索,应先确认当前请求确实属于上述少数例外,而不是因为你忘了检索。";
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
/// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
@ -99,6 +100,23 @@ fn parse_pending_tool_output(output: &str) -> Option<String> {
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
}
/// Returns `true` when an LLM error string looks like a transient timeout.
///
/// The check is case-insensitive (ASCII) and matches:
/// - any occurrence of `timeout` (which also covers `gateway timeout` and
///   `stream timeout`) or `timed out`;
/// - the HTTP status `504`, but only when it appears as a standalone number.
///
/// Requiring `504` to stand alone keeps unrelated errors that merely contain
/// those digits inside a longer number (token counts, request ids, ...) from
/// being reported to the user as a timeout.
fn is_recoverable_llm_error(error: &str) -> bool {
    let normalized = error.to_ascii_lowercase();

    if normalized.contains("timeout") || normalized.contains("timed out") {
        return true;
    }

    // `match_indices` yields byte offsets on char boundaries, so slicing at
    // `start` / `end` below cannot panic.
    normalized.match_indices("504").any(|(start, matched)| {
        let end = start + matched.len();
        let digit_before = normalized[..start]
            .bytes()
            .next_back()
            .map_or(false, |b| b.is_ascii_digit());
        let digit_after = normalized[end..]
            .bytes()
            .next()
            .map_or(false, |b| b.is_ascii_digit());
        !digit_before && !digit_after
    })
}
fn recoverable_llm_message(error: &str) -> String {
if is_recoverable_llm_error(error) {
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
} else {
format!("模型请求失败:{}", error)
}
}
/// Loop detection result.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
@ -386,11 +404,18 @@ impl AgentLoop {
};
// Call LLM
let response = (*self.provider).chat(request).await
.map_err(|e| {
let response = match (*self.provider).chat(request).await {
Ok(response) => response,
Err(e) => {
tracing::error!(error = %e, "LLM request failed");
AgentError::LlmError(e.to_string())
})?;
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
};
#[cfg(debug_assertions)]
tracing::debug!(
@ -539,11 +564,8 @@ impl AgentLoop {
})
}
Err(e) => {
// Fallback if summary call fails
tracing::error!(error = %e, "Failed to get summary from LLM");
let final_message = ChatMessage::assistant(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
);
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
emitted_messages.push(final_message.clone());
Ok(AgentProcessResult {
final_response: final_message,

View File

@ -41,6 +41,7 @@ impl Default for ContextCompressionConfig {
}
/// Context compressor that reduces message history when it exceeds token limits.
#[derive(Clone)]
pub struct ContextCompressor {
config: ContextCompressionConfig,
context_window: usize,

View File

@ -101,6 +101,8 @@ pub struct ProviderConfig {
pub api_key: String,
#[serde(default)]
pub extra_headers: HashMap<String, String>,
#[serde(default = "default_llm_timeout_secs")]
pub llm_timeout_secs: u64,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -132,6 +134,10 @@ fn default_token_limit() -> usize {
128_000
}
/// Serde default for `llm_timeout_secs`: the LLM request timeout, in seconds,
/// used when the provider entry does not set one.
fn default_llm_timeout_secs() -> u64 {
    const DEFAULT_LLM_TIMEOUT_SECS: u64 = 120;
    DEFAULT_LLM_TIMEOUT_SECS
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GatewayConfig {
#[serde(default = "default_gateway_host")]
@ -400,6 +406,7 @@ pub struct LLMProviderConfig {
pub base_url: String,
pub api_key: String,
pub extra_headers: HashMap<String, String>,
pub llm_timeout_secs: u64,
pub model_id: String,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
@ -461,6 +468,7 @@ impl Config {
base_url: provider.base_url.clone(),
api_key: provider.api_key.clone(),
extra_headers: provider.extra_headers.clone(),
llm_timeout_secs: provider.llm_timeout_secs,
model_id: model.model_id.clone(),
temperature: model.temperature,
max_tokens: model.max_tokens,
@ -601,6 +609,43 @@ mod tests {
assert_eq!(provider_config.name, "aliyun");
assert_eq!(provider_config.model_id, "qwen-plus");
assert_eq!(provider_config.temperature, Some(0.0));
assert_eq!(provider_config.llm_timeout_secs, 120);
}
// Verifies that an explicit `llm_timeout_secs` on a provider entry is carried
// through `Config::load` and `get_provider_config` into the resolved config.
#[test]
fn test_provider_config_loads_custom_llm_timeout() {
// Write a minimal config (one provider, one model, one `default` agent) to a
// temporary file; the provider sets `llm_timeout_secs` to 400.
let file = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
file.path(),
r#"{
"providers": {
"aliyun": {
"type": "openai",
"base_url": "https://example.invalid/v1",
"api_key": "test-key",
"extra_headers": {},
"llm_timeout_secs": 400
}
},
"models": {
"qwen-plus": {
"model_id": "qwen-plus"
}
},
"agents": {
"default": {
"provider": "aliyun",
"model": "qwen-plus"
}
}
}"#,
)
.unwrap();
// Load the file and resolve the provider config for the `default` agent.
let config = Config::load(file.path().to_str().unwrap()).unwrap();
let provider_config = config.get_provider_config("default").unwrap();
// The configured value must be the one that reaches the resolved config.
assert_eq!(provider_config.llm_timeout_secs, 400);
}
#[test]

View File

@ -2,6 +2,7 @@ pub mod http;
pub mod session;
pub mod ws;
use std::collections::HashMap;
use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
@ -9,6 +10,7 @@ use tokio::net::TcpListener;
use crate::bus::{MessageBus, OutboundDispatcher};
use crate::channels::ChannelManager;
use crate::config::Config;
use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
@ -27,6 +29,10 @@ impl GatewayState {
// Get provider config for SessionManager
let provider_config = config.get_provider_config("default")?;
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
for agent_name in config.agents.keys() {
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
}
// Session TTL from config (default 4 hours)
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
@ -40,6 +46,7 @@ impl GatewayState {
agent_prompt_reinject_every,
show_tool_results,
provider_config,
provider_configs,
skills,
)?;
let channel_manager = ChannelManager::new();

View File

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::sync::Arc;
use std::time::{Duration, Instant};
@ -49,6 +49,7 @@ pub struct ScheduledAgentTaskOptions {
pub fresh_session: bool,
pub system_prompt: Option<String>,
pub metadata: HashMap<String, String>,
pub agent: Option<String>,
}
impl BusToolCallEmitter {
@ -244,6 +245,23 @@ impl Session {
}
}
/// Returns the id of the most recent message with role `"user"` in the
/// history of `chat_id`, or `None` when the chat has no history or no user
/// message.
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
    let history = self.get_history(chat_id)?;
    // Walk from the newest entry backwards and stop at the first user turn.
    for entry in history.iter().rev() {
        if entry.role == "user" {
            return Some(entry.id.as_str());
        }
    }
    None
}
/// Reports whether `message_id` is the id of the newest user message in
/// `chat_id`. Returns `false` when the chat has no user message at all.
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
    match self.latest_user_message_id(chat_id) {
        Some(newest_id) => newest_id == message_id,
        None => false,
    }
}
/// 清除所有历史
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
@ -296,10 +314,20 @@ impl Session {
chat_id: &str,
sender_id: Option<&str>,
message_id: Option<&str>,
) -> Result<AgentLoop, AgentError> {
self.create_agent_with_provider_config(chat_id, sender_id, message_id, self.provider_config.clone())
}
pub fn create_agent_with_provider_config(
&self,
chat_id: &str,
sender_id: Option<&str>,
message_id: Option<&str>,
provider_config: LLMProviderConfig,
) -> Result<AgentLoop, AgentError> {
let session_id = self.persistent_session_id(chat_id);
AgentLoop::with_tools_and_skills(
self.provider_config.clone(),
provider_config,
self.tools.clone(),
self.skills.clone(),
)
@ -368,6 +396,7 @@ fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
pub struct SessionManager {
inner: Arc<Mutex<SessionManagerInner>>,
provider_config: LLMProviderConfig,
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
tools: Arc<ToolRegistry>,
skills: Arc<SkillRuntime>,
store: Arc<SessionStore>,
@ -381,7 +410,7 @@ struct SessionManagerInner {
session_ttl: Duration,
}
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolRegistry {
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>, known_agents: HashSet<String>) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(FileReadTool::new());
@ -389,7 +418,7 @@ fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolReg
registry.register(FileEditTool::new());
registry.register(MemorySearchTool::new(store.clone()));
registry.register(MemoryManageTool::new(store.clone()));
registry.register(SchedulerManageTool::new(store));
registry.register(SchedulerManageTool::new(store, known_agents));
registry.register(SkillListTool::new(skills.clone()));
registry.register(SkillManageTool::new(skills));
registry.register(BashTool::new());
@ -435,12 +464,14 @@ impl SessionManager {
agent_prompt_reinject_every: u64,
show_tool_results: bool,
provider_config: LLMProviderConfig,
provider_configs: HashMap<String, LLMProviderConfig>,
skills: Arc<SkillRuntime>,
) -> Result<Self, AgentError> {
let store = Arc::new(
SessionStore::new()
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
);
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) {
tracing::warn!(error = %err, "Failed to record skill discovery event");
@ -453,7 +484,8 @@ impl SessionManager {
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
})),
provider_config,
tools: Arc::new(default_tools(skills.clone(), store.clone())),
provider_configs: Arc::new(provider_configs),
tools: Arc::new(default_tools(skills.clone(), store.clone(), known_agents)),
skills,
store,
agent_prompt_reinject_every,
@ -477,6 +509,10 @@ impl SessionManager {
self.skills.clone()
}
/// Resolves the provider configuration to use for `agent_name`.
///
/// Delegates to `select_provider_config`, passing this manager's default
/// provider configuration and its per-agent configuration map, and returns
/// whatever that function returns (including its error for unknown agents).
pub fn provider_config_for_agent(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
}
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
self.store
.create_cli_session(title)
@ -640,7 +676,7 @@ impl SessionManager {
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
// 处理消息
let response = {
let (history, compressor, provider_config, agent, user_message_id) = {
let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(chat_id)?;
@ -669,11 +705,8 @@ impl SessionManager {
session_guard.append_persisted_message(chat_id, user_message)?;
let history = session_guard.get_or_create_history(chat_id).clone();
// 压缩历史(如果需要)
let history = session_guard.compressor
.compress_if_needed(history, &session_guard.provider_config)
.await?;
let compressor = session_guard.compressor().clone();
let provider_config = session_guard.provider_config().clone();
session_guard.record_skill_offer(chat_id)?;
@ -682,8 +715,27 @@ impl SessionManager {
if let Some(handler) = live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
(history, compressor, provider_config, agent, user_message_id)
};
let history = compressor
.compress_if_needed(history, &provider_config)
.await?;
let result = agent.process(history).await?;
let response = {
let mut session_guard = session.lock().await;
if !session_guard.is_latest_user_message(chat_id, &user_message_id) {
tracing::warn!(
channel = %channel_name,
chat_id = %chat_id,
user_message_id = %user_message_id,
"Skipping stale agent result because a newer user message is already present"
);
Vec::new()
} else {
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
@ -704,6 +756,7 @@ impl SessionManager {
)
})
.collect::<Vec<_>>()
}
};
#[cfg(debug_assertions)]
@ -736,8 +789,9 @@ impl SessionManager {
.sender_id
.clone()
.unwrap_or_else(|| "scheduler".to_string());
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
let response = {
let (history, compressor, agent, user_message_id) = {
let mut session_guard = session.lock().await;
session_guard.ensure_persistent_session(chat_id)?;
@ -758,15 +812,37 @@ impl SessionManager {
session_guard.append_persisted_message(chat_id, user_message)?;
let history = session_guard.get_or_create_history(chat_id).clone();
let history = session_guard
.compressor
.compress_if_needed(history, &session_guard.provider_config)
.await?;
let compressor = session_guard.compressor().clone();
session_guard.record_skill_offer(chat_id)?;
let agent = session_guard.create_agent(chat_id, Some(&sender_id), Some(&user_message_id))?;
let agent = session_guard.create_agent_with_provider_config(
chat_id,
Some(&sender_id),
Some(&user_message_id),
provider_config.clone(),
)?;
(history, compressor, agent, user_message_id)
};
let history = compressor
.compress_if_needed(history, &provider_config)
.await?;
let result = agent.process(history).await?;
let response = {
let mut session_guard = session.lock().await;
if !session_guard.is_latest_user_message(chat_id, &user_message_id) {
tracing::warn!(
channel = %channel_name,
chat_id = %chat_id,
user_message_id = %user_message_id,
"Skipping stale scheduled agent result because a newer user message is already present"
);
Vec::new()
} else {
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
result
@ -783,6 +859,7 @@ impl SessionManager {
)
})
.collect::<Vec<_>>()
}
};
Ok(response)
@ -810,11 +887,29 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage
)
}
/// Picks the provider configuration for an optional agent name.
///
/// The name is trimmed first. A missing, blank, or `"default"` name yields a
/// clone of `default_provider_config`; any other name is looked up in
/// `provider_configs`.
///
/// # Errors
///
/// Returns `AgentError::Other` when a non-default name is not present in
/// `provider_configs`.
fn select_provider_config(
    default_provider_config: &LLMProviderConfig,
    provider_configs: &HashMap<String, LLMProviderConfig>,
    agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
    let requested = agent_name
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty());

    match requested {
        Some(name) if name != "default" => match provider_configs.get(name) {
            Some(found) => Ok(found.clone()),
            None => Err(AgentError::Other(format!(
                "Scheduled agent '{}' not found",
                name
            ))),
        },
        _ => Ok(default_provider_config.clone()),
    }
}
#[cfg(test)]
mod tests {
use axum::http::StatusCode;
use super::*;
use axum::{Json, Router, routing::post};
use crate::bus::MessageBus;
use serde_json::{Value, json};
use std::collections::HashMap;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
fn test_provider_config() -> LLMProviderConfig {
@ -824,6 +919,7 @@ mod tests {
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
@ -833,6 +929,235 @@ mod tests {
}
}
/// Test fixture: an OpenAI-type provider config pointing at `http://localhost`
/// whose `name` and `model_id` are supplied by the caller.
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
    LLMProviderConfig {
        // Caller-controlled identity.
        name: String::from(name),
        model_id: String::from(model_id),
        // Fixed connection settings shared by every fixture instance.
        provider_type: String::from("openai"),
        base_url: String::from("http://localhost"),
        api_key: String::from("test-key"),
        extra_headers: HashMap::new(),
        llm_timeout_secs: 120,
        // Fixed generation limits.
        temperature: Some(0.0),
        max_tokens: Some(32),
        model_extra: HashMap::new(),
        max_tool_iterations: 1,
        token_limit: 4096,
    }
}
/// A non-default agent name must resolve to that agent's own provider config
/// rather than the default one.
#[test]
fn test_select_provider_config_uses_named_agent_override() {
    let fallback = test_provider_config_named("default-provider", "default-model");
    let mut per_agent = HashMap::new();
    per_agent.insert(
        "planner".to_string(),
        test_provider_config_named("planner-provider", "planner-model"),
    );

    let chosen = select_provider_config(&fallback, &per_agent, Some("planner")).unwrap();

    assert_eq!(chosen.name, "planner-provider");
    assert_eq!(chosen.model_id, "planner-model");
}
// Verifies that `Session::is_latest_user_message` follows the newest user
// message: after a second user message is appended, the first id is no longer
// reported as latest and the second id is.
#[tokio::test]
async fn test_latest_user_message_guard_tracks_current_turn() {
// Build a session backed by an in-memory store and the default tool set
// (with no known agents).
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store,
100,
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
// First user turn: its id must be reported as the latest.
let first = session.create_user_message("first", Vec::new());
let first_id = first.id.clone();
session.append_persisted_message("chat-1", first).unwrap();
assert!(session.is_latest_user_message("chat-1", &first_id));
// Second user turn supersedes the first.
let second = session.create_user_message("second", Vec::new());
let second_id = second.id.clone();
session.append_persisted_message("chat-1", second).unwrap();
assert!(!session.is_latest_user_message("chat-1", &first_id));
assert!(session.is_latest_user_message("chat-1", &second_id));
}
/// Requesting the `"default"` agent must return the default provider config
/// even when the per-agent map is empty.
#[test]
fn test_select_provider_config_falls_back_to_default() {
    let fallback = test_provider_config_named("default-provider", "default-model");
    let no_overrides = HashMap::new();

    let chosen = select_provider_config(&fallback, &no_overrides, Some("default")).unwrap();

    assert_eq!(chosen.name, "default-provider");
    assert_eq!(chosen.model_id, "default-model");
}
/// Spawns a mock chat-completions server on an ephemeral loopback port and
/// returns its base URL (`http://<addr>`).
///
/// The single `POST /chat/completions` route answers with a minimal response
/// whose content is `reply from <model>`, where `<model>` is the `model` field
/// of the request body (or `unknown-model` when absent or not a string).
async fn start_mock_openai_server() -> String {
    /// Echoes the requested model back in a fixed-shape completion payload.
    async fn handle(Json(body): Json<Value>) -> Json<Value> {
        let model = match body.get("model").and_then(|value| value.as_str()) {
            Some(requested) => requested,
            None => "unknown-model",
        };
        let reply = json!({
            "id": "mock-response",
            "model": model,
            "choices": [
                {
                    "message": {
                        "content": format!("reply from {}", model),
                        "tool_calls": []
                    }
                }
            ],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2
            }
        });
        Json(reply)
    }

    // Port 0 lets the OS pick a free port; read it back before serving.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base_url = format!("http://{}", listener.local_addr().unwrap());
    let app = Router::new().route("/chat/completions", post(handle));
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });
    base_url
}
/// Spawns a mock server on an ephemeral loopback port whose
/// `POST /chat/completions` route always answers `504 Gateway Timeout` with
/// the body `stream timeout`, and returns its base URL (`http://<addr>`).
async fn start_mock_openai_504_server() -> String {
    /// Always fails with a gateway-timeout status.
    async fn handle() -> (StatusCode, &'static str) {
        (StatusCode::GATEWAY_TIMEOUT, "stream timeout")
    }

    // Port 0 lets the OS pick a free port; read it back before serving.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base_url = format!("http://{}", listener.local_addr().unwrap());
    let app = Router::new().route("/chat/completions", post(handle));
    tokio::spawn(async move {
        axum::serve(listener, app).await.unwrap();
    });
    base_url
}
// End-to-end check: when the provider endpoint answers with a 504,
// `SessionManager::handle_message` still succeeds and returns a single
// outbound message carrying the recoverable-error text.
#[tokio::test]
async fn test_handle_message_returns_recoverable_reply_on_llm_504() {
// Point the provider at a local mock server (started by
// `start_mock_openai_504_server`).
let base_url = start_mock_openai_504_server().await;
let provider_config = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "timeout-provider".to_string(),
base_url: base_url.clone(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
model_id: "timeout-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
model_extra: HashMap::new(),
max_tool_iterations: 1,
token_limit: 4096,
llm_timeout_secs: 30,
};
// The same config is used both as the default and as the `default` agent entry.
let session_manager = SessionManager::new(
4,
100,
false,
provider_config.clone(),
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
)
.unwrap();
let outbound = session_manager
.handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None)
.await
.unwrap();
// Exactly one reply, and it is the friendly "temporarily unavailable" text.
assert_eq!(outbound.len(), 1);
assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时"));
}
// Verifies that `run_scheduled_agent_task` uses the provider config of the
// agent named in `ScheduledAgentTaskOptions::agent`: the reply content is
// expected to contain the model id of the selected agent.
#[tokio::test]
async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() {
// Both providers talk to the same mock server; they differ only in
// `name` and `model_id`.
let base_url = start_mock_openai_server().await;
let default_provider = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "default-provider".to_string(),
base_url: base_url.clone(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
model_id: "default-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
model_extra: HashMap::new(),
max_tool_iterations: 1,
token_limit: 4096,
llm_timeout_secs: 30,
};
let planner_provider = LLMProviderConfig {
model_id: "planner-model".to_string(),
name: "planner-provider".to_string(),
..default_provider.clone()
};
// Register two agents: `default` and `planner`.
let session_manager = SessionManager::new(
4,
100,
false,
default_provider.clone(),
HashMap::from([
("default".to_string(), default_provider),
("planner".to_string(), planner_provider),
]),
Arc::new(SkillRuntime::default()),
)
.unwrap();
// Task that explicitly selects the `planner` agent.
let planner_outbound = session_manager
.run_scheduled_agent_task(
"feishu",
"chat-planner",
"请规划今天工作",
ScheduledAgentTaskOptions {
agent: Some("planner".to_string()),
fresh_session: true,
..Default::default()
},
)
.await
.unwrap();
assert_eq!(planner_outbound.len(), 1);
assert!(planner_outbound[0].content.contains("planner-model"));
// Task that explicitly selects the `default` agent, in a separate chat.
let default_outbound = session_manager
.run_scheduled_agent_task(
"feishu",
"chat-default",
"请规划今天工作",
ScheduledAgentTaskOptions {
agent: Some("default".to_string()),
fresh_session: true,
..Default::default()
},
)
.await
.unwrap();
assert_eq!(default_outbound.len(), 1);
assert!(default_outbound[0].content.contains("default-model"));
}
#[test]
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
let completed = ChatMessage::tool("call-1", "calculator", "2");
@ -881,7 +1206,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -929,7 +1254,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -956,7 +1281,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -1001,7 +1326,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),
@ -1035,7 +1360,7 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
let mut session = Session::new(
"feishu".to_string(),
test_provider_config(),

View File

@ -2,6 +2,7 @@ use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
@ -58,6 +59,7 @@ pub struct AnthropicProvider {
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
llm_timeout_secs: u64,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
@ -70,17 +72,24 @@ impl AnthropicProvider {
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
llm_timeout_secs: u64,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>,
) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(llm_timeout_secs))
.build()
.unwrap_or_else(|_| Client::new());
Self {
client: Client::new(),
client,
name,
api_key,
base_url,
extra_headers,
llm_timeout_secs,
model_id,
temperature,
max_tokens,
@ -190,8 +199,21 @@ impl LLMProvider for AnthropicProvider {
}
let resp = req_builder.json(&body).send().await?;
let status = resp.status();
let text = resp.text().await?;
let anthropic_resp: AnthropicResponse = resp.json().await?;
if !status.is_success() {
return Err(format!("API error {}: {}", status, text).into());
}
#[cfg(debug_assertions)]
{
let resp_preview: String = text.chars().take(100).collect();
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "Anthropic response (first 100 chars shown)");
}
let anthropic_resp: AnthropicResponse = serde_json::from_str(&text)
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
let mut content = String::new();
let mut tool_calls = Vec::new();

View File

@ -15,6 +15,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>
config.api_key,
config.base_url,
config.extra_headers,
config.llm_timeout_secs,
config.model_id,
config.temperature,
config.max_tokens,
@ -25,6 +26,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>
config.api_key,
config.base_url,
config.extra_headers,
config.llm_timeout_secs,
config.model_id,
config.temperature,
config.max_tokens,

View File

@ -3,6 +3,7 @@ use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::Duration;
use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
@ -28,6 +29,7 @@ pub struct OpenAIProvider {
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
llm_timeout_secs: u64,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
@ -40,17 +42,24 @@ impl OpenAIProvider {
api_key: String,
base_url: String,
extra_headers: HashMap<String, String>,
llm_timeout_secs: u64,
model_id: String,
temperature: Option<f32>,
max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>,
) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(llm_timeout_secs))
.build()
.unwrap_or_else(|_| Client::new());
Self {
client: Client::new(),
client,
name,
api_key,
base_url,
extra_headers,
llm_timeout_secs,
model_id,
temperature,
max_tokens,
@ -209,7 +218,7 @@ impl LLMProvider for OpenAIProvider {
#[cfg(debug_assertions)]
{
let resp_preview: String = text.chars().take(100).collect();
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)");
}
if !status.is_success() {
@ -275,6 +284,7 @@ mod tests {
"key".to_string(),
"https://example.com/v1".to_string(),
HashMap::new(),
120,
"gpt-test".to_string(),
None,
None,

View File

@ -88,6 +88,20 @@ impl Scheduler {
continue;
};
if record.next_fire_at.is_none() && job.next_fire_at.is_some() {
self.store.update_scheduler_job_runtime(
&job.id,
job.state.clone(),
job.last_status.clone(),
job.last_error.as_deref(),
job.run_count,
job.last_fired_at,
job.next_fire_at,
job.paused_at,
job.completed_at,
)?;
}
if !job.is_due(now) {
continue;
}
@ -517,6 +531,11 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
.get("system_prompt")
.and_then(|value| value.as_str())
.map(ToString::to_string);
let agent = job
.payload
.get("agent")
.and_then(|value| value.as_str())
.map(ToString::to_string);
let metadata = parse_metadata_map(job.payload.get("metadata"))?;
Ok(ScheduledAgentTaskOptions {
@ -524,6 +543,7 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
fresh_session,
system_prompt,
metadata,
agent,
})
}
@ -616,6 +636,7 @@ mod agent_task_tests {
},
payload: serde_json::json!({
"prompt": "请总结今天待办",
"agent": "planner",
"sender_id": "scheduler-bot",
"fresh_session": true,
"system_prompt": "你是日报助手",
@ -640,6 +661,7 @@ mod agent_task_tests {
};
let options = parse_scheduled_agent_task_options(&job).unwrap();
assert_eq!(options.agent.as_deref(), Some("planner"));
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
assert!(options.fresh_session);
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
@ -660,6 +682,12 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use crate::bus::MessageBus;
use crate::config::LLMProviderConfig;
use crate::gateway::session::SessionManager;
use crate::skills::SkillRuntime;
use crate::storage::{SchedulerJobUpsert, SessionStore};
#[test]
fn runtime_job_skip_policy_advances_from_now() {
@ -739,4 +767,81 @@ mod tests {
});
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
}
// Verifies that a job stored with `next_fire_at: None` gets a `next_fire_at`
// written back to the store after one `process_tick`, without being run
// (run count stays 0, state stays `Scheduled`).
#[tokio::test]
async fn process_tick_persists_initial_next_fire_at_for_db_created_jobs() {
// Insert a job directly into the store, with no `next_fire_at` set.
let store = Arc::new(SessionStore::in_memory().unwrap());
store
.upsert_scheduler_job(&SchedulerJobUpsert {
id: "massage_reminder".to_string(),
kind: "outbound_message".to_string(),
schedule: serde_json::json!({
"type": "interval",
"seconds": 60
}),
interval_secs: 60,
startup_delay_secs: 0,
target: serde_json::json!({
"channel": "feishu",
"chat_id": "oc_demo"
}),
payload: serde_json::json!({
"content": "ping"
}),
enabled: true,
state: SchedulerJobState::Scheduled,
last_status: None,
last_error: None,
run_count: 0,
max_runs: Some(1),
last_fired_at: None,
next_fire_at: None,
paused_at: None,
completed_at: None,
})
.unwrap();
// Provider config needed to construct the SessionManager the scheduler takes.
let provider_config = LLMProviderConfig {
provider_type: "openai".to_string(),
name: "default".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 30,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: None,
model_extra: HashMap::new(),
token_limit: 4096,
max_tool_iterations: 4,
};
let session_manager = SessionManager::new(
4,
100,
false,
provider_config.clone(),
HashMap::from([("default".to_string(), provider_config)]),
Arc::new(SkillRuntime::default()),
)
.unwrap();
// The scheduler config declares no jobs of its own; the only job is the
// one inserted into the store above.
let scheduler = Scheduler::new(
MessageBus::new(8),
SchedulerConfig {
enabled: true,
tick_resolution_ms: 1000,
worker_queue_capacity: 64,
misfire_policy: SchedulerMisfirePolicy::Skip,
jobs: Vec::new(),
},
store.clone(),
session_manager,
);
scheduler.process_tick().await.unwrap();
// Re-read the job: the schedule must now be persisted and the job untouched otherwise.
let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap();
assert!(saved.next_fire_at.is_some());
assert_eq!(saved.run_count, 0);
assert_eq!(saved.state, SchedulerJobState::Scheduled);
}
}

View File

@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
@ -11,11 +12,15 @@ use crate::tools::traits::{Tool, ToolResult};
pub struct SchedulerManageTool {
store: Arc<SessionStore>,
known_agents: Arc<HashSet<String>>,
}
impl SchedulerManageTool {
pub fn new(store: Arc<SessionStore>) -> Self {
Self { store }
pub fn new(store: Arc<SessionStore>, known_agents: HashSet<String>) -> Self {
Self {
store,
known_agents: Arc::new(known_agents),
}
}
}
@ -58,7 +63,7 @@ impl Tool for SchedulerManageTool {
},
"payload": {
"type": "object",
"description": "Job payload. agent_task supports prompt, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
"description": "Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
},
"max_runs": {
"type": ["integer", "null"]
@ -91,7 +96,7 @@ impl Tool for SchedulerManageTool {
}
}
"put" => {
let input = build_upsert(&args)?;
let input = build_upsert(&args, &self.known_agents)?;
let record = self.store.upsert_scheduler_job(&input)?;
record_to_json(&record)
}
@ -140,7 +145,7 @@ impl Tool for SchedulerManageTool {
}
}
fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert> {
fn build_upsert(args: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<SchedulerJobUpsert> {
let id = require_str(args, "id")?.to_string();
let kind = require_str(args, "kind")?.to_string();
let schedule_value = args
@ -158,14 +163,24 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert>
_ => (0, 0),
};
let payload = args.get("payload").cloned().unwrap_or_else(|| json!({}));
let target = args.get("target").cloned().unwrap_or_else(|| json!({}));
if kind == "agent_task" {
validate_agent_task_payload(&payload, known_agents)?;
validate_target_fields(&target, &["channel", "chat_id"], "agent_task")?;
} else if kind == "outbound_message" {
validate_outbound_message_payload(&payload)?;
validate_target_fields(&target, &["channel", "chat_id"], "outbound_message")?;
}
Ok(SchedulerJobUpsert {
id,
kind,
schedule: serde_json::to_value(schedule)?,
interval_secs,
startup_delay_secs,
target: args.get("target").cloned().unwrap_or_else(|| json!({})),
payload: args.get("payload").cloned().unwrap_or_else(|| json!({})),
target,
payload,
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
SchedulerJobState::Scheduled
@ -183,6 +198,57 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert>
})
}
fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<()> {
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
anyhow::bail!("agent_task payload.prompt is required and must be a string")
};
if prompt.trim().is_empty() {
anyhow::bail!("agent_task payload.prompt cannot be empty")
}
let Some(agent_name) = payload.get("agent").and_then(|value| value.as_str()) else {
return Ok(());
};
let normalized = agent_name.trim();
if normalized.is_empty() || normalized == "default" || known_agents.contains(normalized) {
return Ok(());
}
anyhow::bail!("Unknown agent '{}' for agent_task payload.agent", normalized)
}
/// Validates the payload of an `outbound_message` job.
///
/// `payload.content` must be present, must be a JSON string, and must contain
/// at least one non-whitespace character.
///
/// # Errors
///
/// Returns an error when `content` is missing or not a string, or when it is
/// blank after trimming.
fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> {
    match payload.get("content").and_then(serde_json::Value::as_str) {
        None => anyhow::bail!("outbound_message payload.content is required and must be a string"),
        Some(text) if text.trim().is_empty() => {
            anyhow::bail!("outbound_message payload.content cannot be empty")
        }
        Some(_) => Ok(()),
    }
}
fn validate_target_fields(
target: &serde_json::Value,
required_fields: &[&str],
kind: &str,
) -> anyhow::Result<()> {
let object = target
.as_object()
.ok_or_else(|| anyhow::anyhow!("{} target must be an object", kind))?;
for field in required_fields {
let Some(value) = object.get(*field).and_then(|value| value.as_str()) else {
anyhow::bail!("{} target.{} is required and must be a string", kind, field)
};
if value.trim().is_empty() {
anyhow::bail!("{} target.{} cannot be empty", kind, field)
}
}
Ok(())
}
fn record_to_json(record: &SchedulerJobRecord) -> serde_json::Value {
json!({
"id": record.id,
@ -255,7 +321,7 @@ mod tests {
#[tokio::test]
async fn test_scheduler_manage_put_and_get() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store);
let tool = SchedulerManageTool::new(store, HashSet::new());
let put_result = tool
.execute(json!({
@ -293,7 +359,7 @@ mod tests {
#[tokio::test]
async fn test_scheduler_manage_put_agent_task() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store);
let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()]));
let put_result = tool
.execute(json!({
@ -309,7 +375,8 @@ mod tests {
"chat_id": "oc_demo"
},
"payload": {
"prompt": "请总结今天待办"
"prompt": "请总结今天待办",
"agent": "planner"
}
}))
.await
@ -318,4 +385,59 @@ mod tests {
assert!(put_result.success);
assert!(put_result.output.contains("agent_task"));
}
/// A `put` for an `outbound_message` job that omits `target` must fail with
/// an error naming the missing `target.channel` field.
#[tokio::test]
async fn test_scheduler_manage_rejects_outbound_message_without_target() {
    let tool = SchedulerManageTool::new(
        Arc::new(SessionStore::in_memory().unwrap()),
        HashSet::new(),
    );

    // Note: no "target" key at all in the request.
    let request = json!({
        "action": "put",
        "id": "massage_reminder",
        "kind": "outbound_message",
        "schedule": {
            "type": "interval",
            "seconds": 60
        },
        "payload": {
            "content": "⏰ 时间到了!该去按摩了!💆"
        }
    });
    let outcome = tool.execute(request).await;

    assert!(outcome.is_err());
    let message = outcome.err().map(|failure| failure.to_string()).unwrap();
    assert!(message.contains("outbound_message target.channel is required"));
}
/// A `put` for an `agent_task` job whose `payload.agent` is not among the
/// tool's known agents must fail with an "Unknown agent" error.
#[tokio::test]
async fn test_scheduler_manage_rejects_unknown_agent_task_agent() {
    // Only "planner" is registered; the request below asks for another name.
    let known_agents: HashSet<String> = std::iter::once("planner".to_string()).collect();
    let tool = SchedulerManageTool::new(
        Arc::new(SessionStore::in_memory().unwrap()),
        known_agents,
    );

    let request = json!({
        "action": "put",
        "id": "agent.daily_summary",
        "kind": "agent_task",
        "schedule": {
            "type": "cron",
            "expression": "0 9 * * *"
        },
        "target": {
            "channel": "feishu",
            "chat_id": "oc_demo"
        },
        "payload": {
            "prompt": "请总结今天待办",
            "agent": "missing-agent"
        }
    });
    let outcome = tool.execute(request).await;

    assert!(outcome.is_err());
    let message = outcome.err().map(|failure| failure.to_string()).unwrap();
    assert!(message.contains("Unknown agent 'missing-agent'"));
}
}

View File

@ -19,6 +19,7 @@ fn load_config() -> Option<LLMProviderConfig> {
base_url: openai_base_url,
api_key: openai_api_key,
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: openai_model,
temperature: Some(0.0),
max_tokens: Some(100),

View File

@ -19,6 +19,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
base_url: openai_base_url,
api_key: openai_api_key,
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: openai_model,
temperature: Some(0.0),
max_tokens: Some(100),