diff --git a/README.md b/README.md index 77ba850..73b45c2 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ Config example: }, "payload": { "prompt": "请总结今天的项目进展,并列出明天的优先事项", + "agent": "default", "fresh_session": true, "system_prompt": "你是日报助手,输出时先给摘要,再给待办。", "sender_id": "scheduler-daily-summary", @@ -161,6 +162,7 @@ Runtime management: - agent_task reuses the normal agent pipeline: it creates a synthetic user turn from payload.prompt and runs tools, persistence, and outbound rendering through SessionManager. - agent_task payload fields: - prompt: required, synthetic user input. + - agent: optional, choose which configured agent definition to use. default or any configured agent name. - fresh_session: optional, when true reset the active chat segment before running. - system_prompt: optional, append a task-specific system message before the synthetic user turn. - sender_id: optional, overrides the synthetic sender id used for tool context and memory scoping. diff --git a/docs/PERSISTENCE.md b/docs/PERSISTENCE.md index c1603db..61bbc9f 100644 --- a/docs/PERSISTENCE.md +++ b/docs/PERSISTENCE.md @@ -185,6 +185,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据 - 内置 `internal_event` 当前包含 `session_cleanup`,用于回收超时的内存 session 缓存。 - `agent_task` 会把 `payload.prompt` 作为一次合成用户输入,交给 `SessionManager::run_scheduled_agent_task()` 执行,因此会复用持久化历史、工具调用和渠道下发链路。 - `payload.fresh_session = true` 时,会先对目标 chat 执行一次逻辑 reset,再开始本次任务运行。 +- `payload.agent` 可指定本次任务使用哪一个已配置 agent;未指定时仍使用 `default`。 - `payload.system_prompt` 会作为额外 system 消息写入本次任务上下文。 - `payload.sender_id` 会覆盖默认的 `scheduler` 发送者标识。 - `payload.metadata` 会映射到 outbound metadata,便于渠道侧做追踪或特殊处理。 diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 248d94d..e1e779e 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -24,6 +24,7 @@ const TRUNCATION_SUFFIX_LEN: usize = 200; const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "在绝大多数请求开始时,你都应先使用长期记忆检索工具 memory_search 来召回相关上下文,然后再决定如何回答或是否需要写入记忆。默认流程是:先用 memory_search(action='search');只有在你已经明确知道 namespace 和 key 时才改用 get;只有在需要浏览最近几条记忆时才用 list。即使用户没有明确提到“记忆”或“偏好”,只要请求可能与用户长期偏好、稳定事实、历史决策、持续任务或项目上下文有关,就应先搜记忆。仅以下少数情况可跳过记忆搜索:纯寒暄、一次性简单计算、完全不依赖用户历史的直接事实问答。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。如果你决定跳过记忆搜索,应先确认当前请求确实属于上述少数例外,而不是因为你忘了检索。"; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; +const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。"; /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { @@ -99,6 +100,23 @@ fn parse_pending_tool_output(output: &str) -> Option { output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string()) } +fn is_recoverable_llm_error(error: &str) -> bool { + let normalized = error.to_ascii_lowercase(); + normalized.contains("504") + || normalized.contains("gateway timeout") + || normalized.contains("stream timeout") + || normalized.contains("timed out") + || normalized.contains("timeout") +} + +fn recoverable_llm_message(error: &str) -> String { + if is_recoverable_llm_error(error) { + RECOVERABLE_LLM_ERROR_MESSAGE.to_string() + } else { + format!("模型请求失败:{}", error) + } +} + /// Loop detection result. #[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { @@ -386,11 +404,18 @@ impl AgentLoop { }; // Call LLM - let response = (*self.provider).chat(request).await - .map_err(|e| { + let response = match (*self.provider).chat(request).await { + Ok(response) => response, + Err(e) => { tracing::error!(error = %e, "LLM request failed"); - AgentError::LlmError(e.to_string()) - })?; + let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); + emitted_messages.push(assistant_message.clone()); + return Ok(AgentProcessResult { + final_response: assistant_message, + emitted_messages, + }); + } + }; #[cfg(debug_assertions)] tracing::debug!( @@ -539,11 +564,8 @@ impl AgentLoop { }) } Err(e) => { - // Fallback if summary call fails tracing::error!(error = %e, "Failed to get summary from LLM"); - let final_message = ChatMessage::assistant( - format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) - ); + let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string())); emitted_messages.push(final_message.clone()); Ok(AgentProcessResult { final_response: final_message, diff --git a/src/agent/context_compressor.rs b/src/agent/context_compressor.rs index d53224c..389415a 100644 --- a/src/agent/context_compressor.rs +++ b/src/agent/context_compressor.rs @@ -41,6 +41,7 @@ impl Default for ContextCompressionConfig { } /// Context compressor that reduces message history when it exceeds token limits. +#[derive(Clone)] pub struct ContextCompressor { config: ContextCompressionConfig, context_window: usize, diff --git a/src/config/mod.rs b/src/config/mod.rs index 74cbd8a..aaed652 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -101,6 +101,8 @@ pub struct ProviderConfig { pub api_key: String, #[serde(default)] pub extra_headers: HashMap, + #[serde(default = "default_llm_timeout_secs")] + pub llm_timeout_secs: u64, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -132,6 +134,10 @@ fn default_token_limit() -> usize { 128_000 } +fn default_llm_timeout_secs() -> u64 { + 120 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct GatewayConfig { #[serde(default = "default_gateway_host")] @@ -400,6 +406,7 @@ pub struct LLMProviderConfig { pub base_url: String, pub api_key: String, pub extra_headers: HashMap, + pub llm_timeout_secs: u64, pub model_id: String, pub temperature: Option, pub max_tokens: Option, @@ -461,6 +468,7 @@ impl Config { base_url: provider.base_url.clone(), api_key: provider.api_key.clone(), extra_headers: provider.extra_headers.clone(), + llm_timeout_secs: provider.llm_timeout_secs, model_id: model.model_id.clone(), temperature: model.temperature, max_tokens: model.max_tokens, @@ -601,6 +609,43 @@ mod tests { assert_eq!(provider_config.name, "aliyun"); assert_eq!(provider_config.model_id, "qwen-plus"); assert_eq!(provider_config.temperature, Some(0.0)); + assert_eq!(provider_config.llm_timeout_secs, 120); + } + + #[test] + fn test_provider_config_loads_custom_llm_timeout() { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://example.invalid/v1", + "api_key": "test-key", + "extra_headers": {}, + "llm_timeout_secs": 400 + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus" + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus" + } + } +}"#, + ) + .unwrap(); + + let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let provider_config = config.get_provider_config("default").unwrap(); + + assert_eq!(provider_config.llm_timeout_secs, 400); } #[test] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2f711ef..3e55cfc 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -2,6 +2,7 @@ pub mod http; pub mod session; pub mod ws; +use std::collections::HashMap; use std::sync::Arc; use axum::{routing, Router}; use tokio::net::TcpListener; @@ -9,6 +10,7 @@ use tokio::net::TcpListener; use crate::bus::{MessageBus, OutboundDispatcher}; use crate::channels::ChannelManager; use crate::config::Config; +use crate::config::LLMProviderConfig; use crate::logging; use crate::scheduler::Scheduler; use crate::skills::SkillRuntime; @@ -27,6 +29,10 @@ impl GatewayState { // Get provider config for SessionManager let provider_config = config.get_provider_config("default")?; + let mut provider_configs = HashMap::::new(); + for agent_name in config.agents.keys() { + provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?); + } // Session TTL from config (default 4 hours) let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); @@ -40,6 +46,7 @@ impl GatewayState { agent_prompt_reinject_every, show_tool_results, provider_config, + provider_configs, skills, )?; let channel_manager = ChannelManager::new(); diff --git a/src/gateway/session.rs b/src/gateway/session.rs index ceffeda..a7f009a 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -49,6 +49,7 @@ pub struct ScheduledAgentTaskOptions { pub fresh_session: bool, pub system_prompt: Option, pub metadata: HashMap, + pub agent: Option, } impl BusToolCallEmitter { @@ -244,6 +245,23 @@ impl Session { } } + fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> { + self.get_history(chat_id) + .and_then(|history| { + history + .iter() + .rev() + .find(|message| message.role == "user") + .map(|message| message.id.as_str()) + }) + } + + fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool { + self.latest_user_message_id(chat_id) + .map(|current_id| current_id == message_id) + .unwrap_or(false) + } + /// 清除所有历史 pub fn clear_all_history(&mut self) -> Result<(), AgentError> { let chat_ids: Vec = self.chat_histories.keys().cloned().collect(); @@ -296,10 +314,20 @@ impl Session { chat_id: &str, sender_id: Option<&str>, message_id: Option<&str>, + ) -> Result { + self.create_agent_with_provider_config(chat_id, sender_id, message_id, self.provider_config.clone()) + } + + pub fn create_agent_with_provider_config( + &self, + chat_id: &str, + sender_id: Option<&str>, + message_id: Option<&str>, + provider_config: LLMProviderConfig, ) -> Result { let session_id = self.persistent_session_id(chat_id); AgentLoop::with_tools_and_skills( - self.provider_config.clone(), + provider_config, self.tools.clone(), self.skills.clone(), ) @@ -368,6 +396,7 @@ fn agent_prompt_path() -> Result { pub struct SessionManager { inner: Arc>, provider_config: LLMProviderConfig, + provider_configs: Arc>, tools: Arc, skills: Arc, store: Arc, @@ -381,7 +410,7 @@ struct SessionManagerInner { session_ttl: Duration, } -fn default_tools(skills: Arc, store: Arc) -> ToolRegistry { +fn default_tools(skills: Arc, store: Arc, known_agents: HashSet) -> ToolRegistry { let mut registry = ToolRegistry::new(); registry.register(CalculatorTool::new()); registry.register(FileReadTool::new()); @@ -389,7 +418,7 @@ fn default_tools(skills: Arc, store: Arc) -> ToolReg registry.register(FileEditTool::new()); registry.register(MemorySearchTool::new(store.clone())); registry.register(MemoryManageTool::new(store.clone())); - registry.register(SchedulerManageTool::new(store)); + registry.register(SchedulerManageTool::new(store, known_agents)); registry.register(SkillListTool::new(skills.clone())); registry.register(SkillManageTool::new(skills)); registry.register(BashTool::new()); @@ -435,12 +464,14 @@ impl SessionManager { agent_prompt_reinject_every: u64, show_tool_results: bool, provider_config: LLMProviderConfig, + provider_configs: HashMap, skills: Arc, ) -> Result { let store = Arc::new( SessionStore::new() .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?, ); + let known_agents = provider_configs.keys().cloned().collect::>(); if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) { tracing::warn!(error = %err, "Failed to record skill discovery event"); @@ -453,7 +484,8 @@ impl SessionManager { session_ttl: Duration::from_secs(session_ttl_hours * 3600), })), provider_config, - tools: Arc::new(default_tools(skills.clone(), store.clone())), + provider_configs: Arc::new(provider_configs), + tools: Arc::new(default_tools(skills.clone(), store.clone(), known_agents)), skills, store, agent_prompt_reinject_every, @@ -477,6 +509,10 @@ impl SessionManager { self.skills.clone() } + pub fn provider_config_for_agent(&self, agent_name: Option<&str>) -> Result { + select_provider_config(&self.provider_config, &self.provider_configs, agent_name) + } + pub fn create_cli_session(&self, title: Option<&str>) -> Result { self.store .create_cli_session(title) @@ -640,7 +676,7 @@ impl SessionManager { .ok_or_else(|| AgentError::Other("Session not found".to_string()))?; // 处理消息 - let response = { + let (history, compressor, provider_config, agent, user_message_id) = { let mut session_guard = session.lock().await; session_guard.ensure_persistent_session(chat_id)?; @@ -669,11 +705,8 @@ impl SessionManager { session_guard.append_persisted_message(chat_id, user_message)?; let history = session_guard.get_or_create_history(chat_id).clone(); - - // 压缩历史(如果需要) - let history = session_guard.compressor - .compress_if_needed(history, &session_guard.provider_config) - .await?; + let compressor = session_guard.compressor().clone(); + let provider_config = session_guard.provider_config().clone(); session_guard.record_skill_offer(chat_id)?; @@ -682,28 +715,48 @@ impl SessionManager { if let Some(handler) = live_emitter.clone() { agent = agent.with_emitted_message_handler(handler); } - let result = agent.process(history).await?; - // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 - session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + (history, compressor, provider_config, agent, user_message_id) + }; - result - .emitted_messages - .iter() - .filter(|message| { - (!message.is_assistant_tool_call_message() || live_emitter.is_none()) - && should_display_message_to_user(self.show_tool_results, message) - }) - .flat_map(|message| { - OutboundMessage::from_chat_message( - channel_name, - chat_id, - None, - &HashMap::new(), - message, - ) - }) - .collect::>() + let history = compressor + .compress_if_needed(history, &provider_config) + .await?; + let result = agent.process(history).await?; + + let response = { + let mut session_guard = session.lock().await; + + if !session_guard.is_latest_user_message(chat_id, &user_message_id) { + tracing::warn!( + channel = %channel_name, + chat_id = %chat_id, + user_message_id = %user_message_id, + "Skipping stale agent result because a newer user message is already present" + ); + Vec::new() + } else { + // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 + session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + + result + .emitted_messages + .iter() + .filter(|message| { + (!message.is_assistant_tool_call_message() || live_emitter.is_none()) + && should_display_message_to_user(self.show_tool_results, message) + }) + .flat_map(|message| { + OutboundMessage::from_chat_message( + channel_name, + chat_id, + None, + &HashMap::new(), + message, + ) + }) + .collect::>() + } }; #[cfg(debug_assertions)] @@ -736,8 +789,9 @@ impl SessionManager { .sender_id .clone() .unwrap_or_else(|| "scheduler".to_string()); + let provider_config = self.provider_config_for_agent(options.agent.as_deref())?; - let response = { + let (history, compressor, agent, user_message_id) = { let mut session_guard = session.lock().await; session_guard.ensure_persistent_session(chat_id)?; @@ -758,31 +812,54 @@ impl SessionManager { session_guard.append_persisted_message(chat_id, user_message)?; let history = session_guard.get_or_create_history(chat_id).clone(); - let history = session_guard - .compressor - .compress_if_needed(history, &session_guard.provider_config) - .await?; + let compressor = session_guard.compressor().clone(); session_guard.record_skill_offer(chat_id)?; - let agent = session_guard.create_agent(chat_id, Some(&sender_id), Some(&user_message_id))?; - let result = agent.process(history).await?; - session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + let agent = session_guard.create_agent_with_provider_config( + chat_id, + Some(&sender_id), + Some(&user_message_id), + provider_config.clone(), + )?; - result - .emitted_messages - .iter() - .filter(|message| should_display_message_to_user(self.show_tool_results, message)) - .flat_map(|message| { - OutboundMessage::from_chat_message( - channel_name, - chat_id, - None, - &options.metadata, - message, - ) - }) - .collect::>() + (history, compressor, agent, user_message_id) + }; + + let history = compressor + .compress_if_needed(history, &provider_config) + .await?; + let result = agent.process(history).await?; + + let response = { + let mut session_guard = session.lock().await; + + if !session_guard.is_latest_user_message(chat_id, &user_message_id) { + tracing::warn!( + channel = %channel_name, + chat_id = %chat_id, + user_message_id = %user_message_id, + "Skipping stale scheduled agent result because a newer user message is already present" + ); + Vec::new() + } else { + session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; + + result + .emitted_messages + .iter() + .filter(|message| should_display_message_to_user(self.show_tool_results, message)) + .flat_map(|message| { + OutboundMessage::from_chat_message( + channel_name, + chat_id, + None, + &options.metadata, + message, + ) + }) + .collect::>() + } }; Ok(response) @@ -810,11 +887,29 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage ) } +fn select_provider_config( + default_provider_config: &LLMProviderConfig, + provider_configs: &HashMap, + agent_name: Option<&str>, +) -> Result { + match agent_name.map(str::trim).filter(|value| !value.is_empty()) { + None | Some("default") => Ok(default_provider_config.clone()), + Some(agent_name) => provider_configs + .get(agent_name) + .cloned() + .ok_or_else(|| AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))), + } +} + #[cfg(test)] mod tests { + use axum::http::StatusCode; use super::*; + use axum::{Json, Router, routing::post}; use crate::bus::MessageBus; + use serde_json::{Value, json}; use std::collections::HashMap; + use tokio::net::TcpListener; use tokio::sync::mpsc; fn test_provider_config() -> LLMProviderConfig { @@ -824,6 +919,7 @@ mod tests { base_url: "http://localhost".to_string(), api_key: "test-key".to_string(), extra_headers: HashMap::new(), + llm_timeout_secs: 120, model_id: "test-model".to_string(), temperature: Some(0.0), max_tokens: Some(32), @@ -833,6 +929,235 @@ mod tests { } } + fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig { + LLMProviderConfig { + provider_type: "openai".to_string(), + name: name.to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 120, + model_id: model_id.to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + model_extra: HashMap::new(), + max_tool_iterations: 1, + token_limit: 4096, + } + } + + #[test] + fn test_select_provider_config_uses_named_agent_override() { + let default_provider = test_provider_config_named("default-provider", "default-model"); + let provider_configs = HashMap::from([( + "planner".to_string(), + test_provider_config_named("planner-provider", "planner-model"), + )]); + + let selected = select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap(); + assert_eq!(selected.name, "planner-provider"); + assert_eq!(selected.model_id, "planner-model"); + } + + #[tokio::test] + async fn test_latest_user_message_guard_tracks_current_turn() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let (user_tx, _user_rx) = mpsc::channel(4); + let skills = Arc::new(SkillRuntime::default()); + let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new())); + let mut session = Session::new( + "feishu".to_string(), + test_provider_config(), + user_tx, + tools, + skills, + store, + 100, + ) + .await + .unwrap(); + + session.ensure_persistent_session("chat-1").unwrap(); + session.ensure_chat_loaded("chat-1").unwrap(); + + let first = session.create_user_message("first", Vec::new()); + let first_id = first.id.clone(); + session.append_persisted_message("chat-1", first).unwrap(); + assert!(session.is_latest_user_message("chat-1", &first_id)); + + let second = session.create_user_message("second", Vec::new()); + let second_id = second.id.clone(); + session.append_persisted_message("chat-1", second).unwrap(); + + assert!(!session.is_latest_user_message("chat-1", &first_id)); + assert!(session.is_latest_user_message("chat-1", &second_id)); + } + + #[test] + fn test_select_provider_config_falls_back_to_default() { + let default_provider = test_provider_config_named("default-provider", "default-model"); + let provider_configs = HashMap::new(); + + let selected = select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap(); + assert_eq!(selected.name, "default-provider"); + assert_eq!(selected.model_id, "default-model"); + } + + async fn start_mock_openai_server() -> String { + async fn handle(Json(body): Json) -> Json { + let model = body + .get("model") + .and_then(|value| value.as_str()) + .unwrap_or("unknown-model"); + + Json(json!({ + "id": "mock-response", + "model": model, + "choices": [ + { + "message": { + "content": format!("reply from {}", model), + "tool_calls": [] + } + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2 + } + })) + } + + let app = Router::new().route("/chat/completions", post(handle)); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + format!("http://{}", address) + } + + async fn start_mock_openai_504_server() -> String { + async fn handle() -> (StatusCode, &'static str) { + (StatusCode::GATEWAY_TIMEOUT, "stream timeout") + } + + let app = Router::new().route("/chat/completions", post(handle)); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + format!("http://{}", address) + } + + #[tokio::test] + async fn test_handle_message_returns_recoverable_reply_on_llm_504() { + let base_url = start_mock_openai_504_server().await; + let provider_config = LLMProviderConfig { + provider_type: "openai".to_string(), + name: "timeout-provider".to_string(), + base_url: base_url.clone(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + model_id: "timeout-model".to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + model_extra: HashMap::new(), + max_tool_iterations: 1, + token_limit: 4096, + llm_timeout_secs: 30, + }; + + let session_manager = SessionManager::new( + 4, + 100, + false, + provider_config.clone(), + HashMap::from([("default".to_string(), provider_config)]), + Arc::new(SkillRuntime::default()), + ) + .unwrap(); + + let outbound = session_manager + .handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None) + .await + .unwrap(); + + assert_eq!(outbound.len(), 1); + assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时")); + } + + #[tokio::test] + async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() { + let base_url = start_mock_openai_server().await; + let default_provider = LLMProviderConfig { + provider_type: "openai".to_string(), + name: "default-provider".to_string(), + base_url: base_url.clone(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + model_id: "default-model".to_string(), + temperature: Some(0.0), + max_tokens: Some(32), + model_extra: HashMap::new(), + max_tool_iterations: 1, + token_limit: 4096, + llm_timeout_secs: 30, + }; + let planner_provider = LLMProviderConfig { + model_id: "planner-model".to_string(), + name: "planner-provider".to_string(), + ..default_provider.clone() + }; + + let session_manager = SessionManager::new( + 4, + 100, + false, + default_provider.clone(), + HashMap::from([ + ("default".to_string(), default_provider), + ("planner".to_string(), planner_provider), + ]), + Arc::new(SkillRuntime::default()), + ) + .unwrap(); + + let planner_outbound = session_manager + .run_scheduled_agent_task( + "feishu", + "chat-planner", + "请规划今天工作", + ScheduledAgentTaskOptions { + agent: Some("planner".to_string()), + fresh_session: true, + ..Default::default() + }, + ) + .await + .unwrap(); + assert_eq!(planner_outbound.len(), 1); + assert!(planner_outbound[0].content.contains("planner-model")); + + let default_outbound = session_manager + .run_scheduled_agent_task( + "feishu", + "chat-default", + "请规划今天工作", + ScheduledAgentTaskOptions { + agent: Some("default".to_string()), + fresh_session: true, + ..Default::default() + }, + ) + .await + .unwrap(); + assert_eq!(default_outbound.len(), 1); + assert!(default_outbound[0].content.contains("default-model")); + } + #[test] fn test_should_display_message_to_user_hides_completed_tool_results_by_default() { let completed = ChatMessage::tool("call-1", "calculator", "2"); @@ -881,7 +1206,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new())); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -929,7 +1254,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new())); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -956,7 +1281,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new())); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -1001,7 +1326,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new())); let mut session = Session::new( "feishu".to_string(), test_provider_config(), @@ -1035,7 +1360,7 @@ mod tests { let store = Arc::new(SessionStore::in_memory().unwrap()); let (user_tx, _user_rx) = mpsc::channel(4); let skills = Arc::new(SkillRuntime::default()); - let tools = Arc::new(default_tools(skills.clone(), store.clone())); + let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new())); let mut session = Session::new( "feishu".to_string(), test_provider_config(), diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 8b1cefa..bcbd50a 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -2,6 +2,7 @@ use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::time::Duration; use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; @@ -58,6 +59,7 @@ pub struct AnthropicProvider { api_key: String, base_url: String, extra_headers: HashMap, + llm_timeout_secs: u64, model_id: String, temperature: Option, max_tokens: Option, @@ -70,17 +72,24 @@ impl AnthropicProvider { api_key: String, base_url: String, extra_headers: HashMap, + llm_timeout_secs: u64, model_id: String, temperature: Option, max_tokens: Option, model_extra: HashMap, ) -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(llm_timeout_secs)) + .build() + .unwrap_or_else(|_| Client::new()); + Self { - client: Client::new(), + client, name, api_key, base_url, extra_headers, + llm_timeout_secs, model_id, temperature, max_tokens, @@ -190,8 +199,21 @@ impl LLMProvider for AnthropicProvider { } let resp = req_builder.json(&body).send().await?; + let status = resp.status(); + let text = resp.text().await?; - let anthropic_resp: AnthropicResponse = resp.json().await?; + if !status.is_success() { + return Err(format!("API error {}: {}", status, text).into()); + } + + #[cfg(debug_assertions)] + { + let resp_preview: String = text.chars().take(100).collect(); + tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "Anthropic response (first 100 chars shown)"); + } + + let anthropic_resp: AnthropicResponse = serde_json::from_str(&text) + .map_err(|e| format!("decode error: {} | body: {}", e, &text))?; let mut content = String::new(); let mut tool_calls = Vec::new(); diff --git a/src/providers/mod.rs b/src/providers/mod.rs index eedab44..e4eae27 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -15,6 +15,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result config.api_key, config.base_url, config.extra_headers, + config.llm_timeout_secs, config.model_id, config.temperature, config.max_tokens, @@ -25,6 +26,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result config.api_key, config.base_url, config.extra_headers, + config.llm_timeout_secs, config.model_id, config.temperature, config.max_tokens, diff --git a/src/providers/openai.rs b/src/providers/openai.rs index b55ed8e..e13bd22 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -3,6 +3,7 @@ use reqwest::Client; use serde::Deserialize; use serde_json::{json, Value}; use std::collections::HashMap; +use std::time::Duration; use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; @@ -28,6 +29,7 @@ pub struct OpenAIProvider { api_key: String, base_url: String, extra_headers: HashMap, + llm_timeout_secs: u64, model_id: String, temperature: Option, max_tokens: Option, @@ -40,17 +42,24 @@ impl OpenAIProvider { api_key: String, base_url: String, extra_headers: HashMap, + llm_timeout_secs: u64, model_id: String, temperature: Option, max_tokens: Option, model_extra: HashMap, ) -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(llm_timeout_secs)) + .build() + .unwrap_or_else(|_| Client::new()); + Self { - client: Client::new(), + client, name, api_key, base_url, extra_headers, + llm_timeout_secs, model_id, temperature, max_tokens, @@ -209,7 +218,7 @@ impl LLMProvider for OpenAIProvider { #[cfg(debug_assertions)] { let resp_preview: String = text.chars().take(100).collect(); - tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)"); + tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)"); } if !status.is_success() { @@ -275,6 +284,7 @@ mod tests { "key".to_string(), "https://example.com/v1".to_string(), HashMap::new(), + 120, "gpt-test".to_string(), None, None, diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index eb8e5aa..5292751 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -88,6 +88,20 @@ impl Scheduler { continue; }; + if record.next_fire_at.is_none() && job.next_fire_at.is_some() { + self.store.update_scheduler_job_runtime( + &job.id, + job.state.clone(), + job.last_status.clone(), + job.last_error.as_deref(), + job.run_count, + job.last_fired_at, + job.next_fire_at, + job.paused_at, + job.completed_at, + )?; + } + if !job.is_due(now) { continue; } @@ -517,6 +531,11 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result anyhow::Result for SchedulerJobTarget { #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; + use crate::bus::MessageBus; + use crate::config::LLMProviderConfig; + use crate::gateway::session::SessionManager; + use crate::skills::SkillRuntime; + use crate::storage::{SchedulerJobUpsert, SessionStore}; #[test] fn runtime_job_skip_policy_advances_from_now() { @@ -739,4 +767,81 @@ mod tests { }); assert_eq!(job.next_fire_at, Some(1_700_000_010_000)); } + + #[tokio::test] + async fn process_tick_persists_initial_next_fire_at_for_db_created_jobs() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + store + .upsert_scheduler_job(&SchedulerJobUpsert { + id: "massage_reminder".to_string(), + kind: "outbound_message".to_string(), + schedule: serde_json::json!({ + "type": "interval", + "seconds": 60 + }), + interval_secs: 60, + startup_delay_secs: 0, + target: serde_json::json!({ + "channel": "feishu", + "chat_id": "oc_demo" + }), + payload: serde_json::json!({ + "content": "ping" + }), + enabled: true, + state: SchedulerJobState::Scheduled, + last_status: None, + last_error: None, + run_count: 0, + max_runs: Some(1), + last_fired_at: None, + next_fire_at: None, + paused_at: None, + completed_at: None, + }) + .unwrap(); + + let provider_config = LLMProviderConfig { + provider_type: "openai".to_string(), + name: "default".to_string(), + base_url: "http://localhost".to_string(), + api_key: "test-key".to_string(), + extra_headers: HashMap::new(), + llm_timeout_secs: 30, + model_id: "test-model".to_string(), + temperature: Some(0.0), + max_tokens: None, + model_extra: HashMap::new(), + token_limit: 4096, + max_tool_iterations: 4, + }; + let session_manager = SessionManager::new( + 4, + 100, + false, + provider_config.clone(), + HashMap::from([("default".to_string(), provider_config)]), + Arc::new(SkillRuntime::default()), + ) + .unwrap(); + let scheduler = Scheduler::new( + MessageBus::new(8), + SchedulerConfig { + enabled: true, + tick_resolution_ms: 1000, + worker_queue_capacity: 64, + misfire_policy: SchedulerMisfirePolicy::Skip, + jobs: Vec::new(), + }, + store.clone(), + session_manager, + ); + + scheduler.process_tick().await.unwrap(); + + let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap(); + assert!(saved.next_fire_at.is_some()); + assert_eq!(saved.run_count, 0); + assert_eq!(saved.state, SchedulerJobState::Scheduled); + } } \ No newline at end of file diff --git a/src/tools/scheduler_manage.rs b/src/tools/scheduler_manage.rs index 144fd5d..1e434a3 100644 --- a/src/tools/scheduler_manage.rs +++ b/src/tools/scheduler_manage.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; @@ -11,11 +12,15 @@ use crate::tools::traits::{Tool, ToolResult}; pub struct SchedulerManageTool { store: Arc, + known_agents: Arc>, } impl SchedulerManageTool { - pub fn new(store: Arc) -> Self { - Self { store } + pub fn new(store: Arc, known_agents: HashSet) -> Self { + Self { + store, + known_agents: Arc::new(known_agents), + } } } @@ -58,7 +63,7 @@ impl Tool for SchedulerManageTool { }, "payload": { "type": "object", - "description": "Job payload. agent_task supports prompt, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event." + "description": "Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event." }, "max_runs": { "type": ["integer", "null"] @@ -91,7 +96,7 @@ impl Tool for SchedulerManageTool { } } "put" => { - let input = build_upsert(&args)?; + let input = build_upsert(&args, &self.known_agents)?; let record = self.store.upsert_scheduler_job(&input)?; record_to_json(&record) } @@ -140,7 +145,7 @@ impl Tool for SchedulerManageTool { } } -fn build_upsert(args: &serde_json::Value) -> anyhow::Result { +fn build_upsert(args: &serde_json::Value, known_agents: &HashSet) -> anyhow::Result { let id = require_str(args, "id")?.to_string(); let kind = require_str(args, "kind")?.to_string(); let schedule_value = args @@ -158,14 +163,24 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result _ => (0, 0), }; + let payload = args.get("payload").cloned().unwrap_or_else(|| json!({})); + let target = args.get("target").cloned().unwrap_or_else(|| json!({})); + if kind == "agent_task" { + validate_agent_task_payload(&payload, known_agents)?; + validate_target_fields(&target, &["channel", "chat_id"], "agent_task")?; + } else if kind == "outbound_message" { + validate_outbound_message_payload(&payload)?; + validate_target_fields(&target, &["channel", "chat_id"], "outbound_message")?; + } + Ok(SchedulerJobUpsert { id, kind, schedule: serde_json::to_value(schedule)?, interval_secs, startup_delay_secs, - target: args.get("target").cloned().unwrap_or_else(|| json!({})), - payload: args.get("payload").cloned().unwrap_or_else(|| json!({})), + target, + payload, enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true), state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) { SchedulerJobState::Scheduled @@ -183,6 +198,57 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result }) } +fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet) -> anyhow::Result<()> { + let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else { + anyhow::bail!("agent_task payload.prompt is required and must be a string") + }; + if prompt.trim().is_empty() { + anyhow::bail!("agent_task payload.prompt cannot be empty") + } + + let Some(agent_name) = payload.get("agent").and_then(|value| value.as_str()) else { + return Ok(()); + }; + + let normalized = agent_name.trim(); + if normalized.is_empty() || normalized == "default" || known_agents.contains(normalized) { + return Ok(()); + } + + anyhow::bail!("Unknown agent '{}' for agent_task payload.agent", normalized) +} + +fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> { + let Some(content) = payload.get("content").and_then(|value| value.as_str()) else { + anyhow::bail!("outbound_message payload.content is required and must be a string") + }; + if content.trim().is_empty() { + anyhow::bail!("outbound_message payload.content cannot be empty") + } + Ok(()) +} + +fn validate_target_fields( + target: &serde_json::Value, + required_fields: &[&str], + kind: &str, +) -> anyhow::Result<()> { + let object = target + .as_object() + .ok_or_else(|| anyhow::anyhow!("{} target must be an object", kind))?; + + for field in required_fields { + let Some(value) = object.get(*field).and_then(|value| value.as_str()) else { + anyhow::bail!("{} target.{} is required and must be a string", kind, field) + }; + if value.trim().is_empty() { + anyhow::bail!("{} target.{} cannot be empty", kind, field) + } + } + + Ok(()) +} + fn record_to_json(record: &SchedulerJobRecord) -> serde_json::Value { json!({ "id": record.id, @@ -255,7 +321,7 @@ mod tests { #[tokio::test] async fn test_scheduler_manage_put_and_get() { let store = Arc::new(SessionStore::in_memory().unwrap()); - let tool = SchedulerManageTool::new(store); + let tool = SchedulerManageTool::new(store, HashSet::new()); let put_result = tool .execute(json!({ @@ -293,7 +359,7 @@ mod tests { #[tokio::test] async fn test_scheduler_manage_put_agent_task() { let store = Arc::new(SessionStore::in_memory().unwrap()); - let tool = SchedulerManageTool::new(store); + let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()])); let put_result = tool .execute(json!({ @@ -309,7 +375,8 @@ mod tests { "chat_id": "oc_demo" }, "payload": { - "prompt": "请总结今天待办" + "prompt": "请总结今天待办", + "agent": "planner" } })) .await @@ -318,4 +385,59 @@ mod tests { assert!(put_result.success); assert!(put_result.output.contains("agent_task")); } + + #[tokio::test] + async fn test_scheduler_manage_rejects_outbound_message_without_target() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store, HashSet::new()); + + let put_result = tool + .execute(json!({ + "action": "put", + "id": "massage_reminder", + "kind": "outbound_message", + "schedule": { + "type": "interval", + "seconds": 60 + }, + "payload": { + "content": "⏰ 时间到了!该去按摩了!💆" + } + })) + .await; + + assert!(put_result.is_err()); + let error = put_result.err().unwrap().to_string(); + assert!(error.contains("outbound_message target.channel is required")); + } + + #[tokio::test] + async fn test_scheduler_manage_rejects_unknown_agent_task_agent() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()])); + + let put_result = tool + .execute(json!({ + "action": "put", + "id": "agent.daily_summary", + "kind": "agent_task", + "schedule": { + "type": "cron", + "expression": "0 9 * * *" + }, + "target": { + "channel": "feishu", + "chat_id": "oc_demo" + }, + "payload": { + "prompt": "请总结今天待办", + "agent": "missing-agent" + } + })) + .await; + + assert!(put_result.is_err()); + let error = put_result.err().unwrap().to_string(); + assert!(error.contains("Unknown agent 'missing-agent'")); + } } \ No newline at end of file diff --git a/tests/test_integration.rs b/tests/test_integration.rs index 09f705e..d4203a4 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -19,6 +19,7 @@ fn load_config() -> Option { base_url: openai_base_url, api_key: openai_api_key, extra_headers: HashMap::new(), + llm_timeout_secs: 120, model_id: openai_model, temperature: Some(0.0), max_tokens: Some(100), diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 39ead1d..7eaff89 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -19,6 +19,7 @@ fn load_openai_config() -> Option { base_url: openai_base_url, api_key: openai_api_key, extra_headers: HashMap::new(), + llm_timeout_secs: 120, model_id: openai_model, temperature: Some(0.0), max_tokens: Some(100),