feat: add llm_timeout_secs to provider configuration and implement timeout handling
- Introduced llm_timeout_secs in ProviderConfig and LLMProviderConfig to specify timeout for LLM requests. - Updated OpenAIProvider and AnthropicProvider to utilize the timeout setting when creating HTTP clients. - Enhanced error handling for API responses to include timeout information. - Modified SessionManager to support agent-specific provider configurations, allowing for more flexible agent management. - Added tests to verify the correct behavior of timeout settings and agent task validation.
This commit is contained in:
parent
1ffdcab585
commit
f3f369b329
@ -126,6 +126,7 @@ Config example:
|
|||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
"prompt": "请总结今天的项目进展,并列出明天的优先事项",
|
"prompt": "请总结今天的项目进展,并列出明天的优先事项",
|
||||||
|
"agent": "default",
|
||||||
"fresh_session": true,
|
"fresh_session": true,
|
||||||
"system_prompt": "你是日报助手,输出时先给摘要,再给待办。",
|
"system_prompt": "你是日报助手,输出时先给摘要,再给待办。",
|
||||||
"sender_id": "scheduler-daily-summary",
|
"sender_id": "scheduler-daily-summary",
|
||||||
@ -161,6 +162,7 @@ Runtime management:
|
|||||||
- agent_task reuses the normal agent pipeline: it creates a synthetic user turn from payload.prompt and runs tools, persistence, and outbound rendering through SessionManager.
|
- agent_task reuses the normal agent pipeline: it creates a synthetic user turn from payload.prompt and runs tools, persistence, and outbound rendering through SessionManager.
|
||||||
- agent_task payload fields:
|
- agent_task payload fields:
|
||||||
- prompt: required, synthetic user input.
|
- prompt: required, synthetic user input.
|
||||||
|
- agent: optional, choose which configured agent definition to use. default or any configured agent name.
|
||||||
- fresh_session: optional, when true reset the active chat segment before running.
|
- fresh_session: optional, when true reset the active chat segment before running.
|
||||||
- system_prompt: optional, append a task-specific system message before the synthetic user turn.
|
- system_prompt: optional, append a task-specific system message before the synthetic user turn.
|
||||||
- sender_id: optional, overrides the synthetic sender id used for tool context and memory scoping.
|
- sender_id: optional, overrides the synthetic sender id used for tool context and memory scoping.
|
||||||
|
|||||||
@ -185,6 +185,7 @@ PicoBot 使用 SQLite 持久化会话和消息历史,当前只有一份数据
|
|||||||
- 内置 `internal_event` 当前包含 `session_cleanup`,用于回收超时的内存 session 缓存。
|
- 内置 `internal_event` 当前包含 `session_cleanup`,用于回收超时的内存 session 缓存。
|
||||||
- `agent_task` 会把 `payload.prompt` 作为一次合成用户输入,交给 `SessionManager::run_scheduled_agent_task()` 执行,因此会复用持久化历史、工具调用和渠道下发链路。
|
- `agent_task` 会把 `payload.prompt` 作为一次合成用户输入,交给 `SessionManager::run_scheduled_agent_task()` 执行,因此会复用持久化历史、工具调用和渠道下发链路。
|
||||||
- `payload.fresh_session = true` 时,会先对目标 chat 执行一次逻辑 reset,再开始本次任务运行。
|
- `payload.fresh_session = true` 时,会先对目标 chat 执行一次逻辑 reset,再开始本次任务运行。
|
||||||
|
- `payload.agent` 可指定本次任务使用哪一个已配置 agent;未指定时仍使用 `default`。
|
||||||
- `payload.system_prompt` 会作为额外 system 消息写入本次任务上下文。
|
- `payload.system_prompt` 会作为额外 system 消息写入本次任务上下文。
|
||||||
- `payload.sender_id` 会覆盖默认的 `scheduler` 发送者标识。
|
- `payload.sender_id` 会覆盖默认的 `scheduler` 发送者标识。
|
||||||
- `payload.metadata` 会映射到 outbound metadata,便于渠道侧做追踪或特殊处理。
|
- `payload.metadata` 会映射到 outbound metadata,便于渠道侧做追踪或特殊处理。
|
||||||
|
|||||||
@ -24,6 +24,7 @@ const TRUNCATION_SUFFIX_LEN: usize = 200;
|
|||||||
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "在绝大多数请求开始时,你都应先使用长期记忆检索工具 memory_search 来召回相关上下文,然后再决定如何回答或是否需要写入记忆。默认流程是:先用 memory_search(action='search');只有在你已经明确知道 namespace 和 key 时才改用 get;只有在需要浏览最近几条记忆时才用 list。即使用户没有明确提到“记忆”或“偏好”,只要请求可能与用户长期偏好、稳定事实、历史决策、持续任务或项目上下文有关,就应先搜记忆。仅以下少数情况可跳过记忆搜索:纯寒暄、一次性简单计算、完全不依赖用户历史的直接事实问答。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。如果你决定跳过记忆搜索,应先确认当前请求确实属于上述少数例外,而不是因为你忘了检索。";
|
const MEMORY_TOOL_USAGE_SYSTEM_PROMPT: &str = "在绝大多数请求开始时,你都应先使用长期记忆检索工具 memory_search 来召回相关上下文,然后再决定如何回答或是否需要写入记忆。默认流程是:先用 memory_search(action='search');只有在你已经明确知道 namespace 和 key 时才改用 get;只有在需要浏览最近几条记忆时才用 list。即使用户没有明确提到“记忆”或“偏好”,只要请求可能与用户长期偏好、稳定事实、历史决策、持续任务或项目上下文有关,就应先搜记忆。仅以下少数情况可跳过记忆搜索:纯寒暄、一次性简单计算、完全不依赖用户历史的直接事实问答。写入或修改记忆时,再使用 memory_manage。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions,并优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。检索时应提供 queries 数组,尽量同时放入中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 queries=['email', '邮件', 'email_folder_preference']。如果你决定跳过记忆搜索,应先确认当前请求确实属于上述少数例外,而不是因为你忘了检索。";
|
||||||
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
|
||||||
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
|
||||||
|
const RECOVERABLE_LLM_ERROR_MESSAGE: &str = "模型服务暂时不可用或响应超时。请稍后重试。";
|
||||||
|
|
||||||
/// Build content blocks from text and media paths
|
/// Build content blocks from text and media paths
|
||||||
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
|
||||||
@ -99,6 +100,23 @@ fn parse_pending_tool_output(output: &str) -> Option<String> {
|
|||||||
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_recoverable_llm_error(error: &str) -> bool {
|
||||||
|
let normalized = error.to_ascii_lowercase();
|
||||||
|
normalized.contains("504")
|
||||||
|
|| normalized.contains("gateway timeout")
|
||||||
|
|| normalized.contains("stream timeout")
|
||||||
|
|| normalized.contains("timed out")
|
||||||
|
|| normalized.contains("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recoverable_llm_message(error: &str) -> String {
|
||||||
|
if is_recoverable_llm_error(error) {
|
||||||
|
RECOVERABLE_LLM_ERROR_MESSAGE.to_string()
|
||||||
|
} else {
|
||||||
|
format!("模型请求失败:{}", error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Loop detection result.
|
/// Loop detection result.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
enum LoopDetectionResult {
|
enum LoopDetectionResult {
|
||||||
@ -386,11 +404,18 @@ impl AgentLoop {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Call LLM
|
// Call LLM
|
||||||
let response = (*self.provider).chat(request).await
|
let response = match (*self.provider).chat(request).await {
|
||||||
.map_err(|e| {
|
Ok(response) => response,
|
||||||
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "LLM request failed");
|
tracing::error!(error = %e, "LLM request failed");
|
||||||
AgentError::LlmError(e.to_string())
|
let assistant_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||||
})?;
|
emitted_messages.push(assistant_message.clone());
|
||||||
|
return Ok(AgentProcessResult {
|
||||||
|
final_response: assistant_message,
|
||||||
|
emitted_messages,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
@ -539,11 +564,8 @@ impl AgentLoop {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback if summary call fails
|
|
||||||
tracing::error!(error = %e, "Failed to get summary from LLM");
|
tracing::error!(error = %e, "Failed to get summary from LLM");
|
||||||
let final_message = ChatMessage::assistant(
|
let final_message = ChatMessage::assistant(recoverable_llm_message(&e.to_string()));
|
||||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
|
||||||
);
|
|
||||||
emitted_messages.push(final_message.clone());
|
emitted_messages.push(final_message.clone());
|
||||||
Ok(AgentProcessResult {
|
Ok(AgentProcessResult {
|
||||||
final_response: final_message,
|
final_response: final_message,
|
||||||
|
|||||||
@ -41,6 +41,7 @@ impl Default for ContextCompressionConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Context compressor that reduces message history when it exceeds token limits.
|
/// Context compressor that reduces message history when it exceeds token limits.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ContextCompressor {
|
pub struct ContextCompressor {
|
||||||
config: ContextCompressionConfig,
|
config: ContextCompressionConfig,
|
||||||
context_window: usize,
|
context_window: usize,
|
||||||
|
|||||||
@ -101,6 +101,8 @@ pub struct ProviderConfig {
|
|||||||
pub api_key: String,
|
pub api_key: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub extra_headers: HashMap<String, String>,
|
pub extra_headers: HashMap<String, String>,
|
||||||
|
#[serde(default = "default_llm_timeout_secs")]
|
||||||
|
pub llm_timeout_secs: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@ -132,6 +134,10 @@ fn default_token_limit() -> usize {
|
|||||||
128_000
|
128_000
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_llm_timeout_secs() -> u64 {
|
||||||
|
120
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct GatewayConfig {
|
pub struct GatewayConfig {
|
||||||
#[serde(default = "default_gateway_host")]
|
#[serde(default = "default_gateway_host")]
|
||||||
@ -400,6 +406,7 @@ pub struct LLMProviderConfig {
|
|||||||
pub base_url: String,
|
pub base_url: String,
|
||||||
pub api_key: String,
|
pub api_key: String,
|
||||||
pub extra_headers: HashMap<String, String>,
|
pub extra_headers: HashMap<String, String>,
|
||||||
|
pub llm_timeout_secs: u64,
|
||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
@ -461,6 +468,7 @@ impl Config {
|
|||||||
base_url: provider.base_url.clone(),
|
base_url: provider.base_url.clone(),
|
||||||
api_key: provider.api_key.clone(),
|
api_key: provider.api_key.clone(),
|
||||||
extra_headers: provider.extra_headers.clone(),
|
extra_headers: provider.extra_headers.clone(),
|
||||||
|
llm_timeout_secs: provider.llm_timeout_secs,
|
||||||
model_id: model.model_id.clone(),
|
model_id: model.model_id.clone(),
|
||||||
temperature: model.temperature,
|
temperature: model.temperature,
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
@ -601,6 +609,43 @@ mod tests {
|
|||||||
assert_eq!(provider_config.name, "aliyun");
|
assert_eq!(provider_config.name, "aliyun");
|
||||||
assert_eq!(provider_config.model_id, "qwen-plus");
|
assert_eq!(provider_config.model_id, "qwen-plus");
|
||||||
assert_eq!(provider_config.temperature, Some(0.0));
|
assert_eq!(provider_config.temperature, Some(0.0));
|
||||||
|
assert_eq!(provider_config.llm_timeout_secs, 120);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_config_loads_custom_llm_timeout() {
|
||||||
|
let file = tempfile::NamedTempFile::new().unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
file.path(),
|
||||||
|
r#"{
|
||||||
|
"providers": {
|
||||||
|
"aliyun": {
|
||||||
|
"type": "openai",
|
||||||
|
"base_url": "https://example.invalid/v1",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"extra_headers": {},
|
||||||
|
"llm_timeout_secs": 400
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {
|
||||||
|
"model_id": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"default": {
|
||||||
|
"provider": "aliyun",
|
||||||
|
"model": "qwen-plus"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = Config::load(file.path().to_str().unwrap()).unwrap();
|
||||||
|
let provider_config = config.get_provider_config("default").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(provider_config.llm_timeout_secs, 400);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -2,6 +2,7 @@ pub mod http;
|
|||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use axum::{routing, Router};
|
use axum::{routing, Router};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
@ -9,6 +10,7 @@ use tokio::net::TcpListener;
|
|||||||
use crate::bus::{MessageBus, OutboundDispatcher};
|
use crate::bus::{MessageBus, OutboundDispatcher};
|
||||||
use crate::channels::ChannelManager;
|
use crate::channels::ChannelManager;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::scheduler::Scheduler;
|
use crate::scheduler::Scheduler;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
@ -27,6 +29,10 @@ impl GatewayState {
|
|||||||
|
|
||||||
// Get provider config for SessionManager
|
// Get provider config for SessionManager
|
||||||
let provider_config = config.get_provider_config("default")?;
|
let provider_config = config.get_provider_config("default")?;
|
||||||
|
let mut provider_configs = HashMap::<String, LLMProviderConfig>::new();
|
||||||
|
for agent_name in config.agents.keys() {
|
||||||
|
provider_configs.insert(agent_name.clone(), config.get_provider_config(agent_name)?);
|
||||||
|
}
|
||||||
|
|
||||||
// Session TTL from config (default 4 hours)
|
// Session TTL from config (default 4 hours)
|
||||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||||
@ -40,6 +46,7 @@ impl GatewayState {
|
|||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
show_tool_results,
|
show_tool_results,
|
||||||
provider_config,
|
provider_config,
|
||||||
|
provider_configs,
|
||||||
skills,
|
skills,
|
||||||
)?;
|
)?;
|
||||||
let channel_manager = ChannelManager::new();
|
let channel_manager = ChannelManager::new();
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@ -49,6 +49,7 @@ pub struct ScheduledAgentTaskOptions {
|
|||||||
pub fresh_session: bool,
|
pub fresh_session: bool,
|
||||||
pub system_prompt: Option<String>,
|
pub system_prompt: Option<String>,
|
||||||
pub metadata: HashMap<String, String>,
|
pub metadata: HashMap<String, String>,
|
||||||
|
pub agent: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BusToolCallEmitter {
|
impl BusToolCallEmitter {
|
||||||
@ -244,6 +245,23 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn latest_user_message_id(&self, chat_id: &str) -> Option<&str> {
|
||||||
|
self.get_history(chat_id)
|
||||||
|
.and_then(|history| {
|
||||||
|
history
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|message| message.role == "user")
|
||||||
|
.map(|message| message.id.as_str())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_latest_user_message(&self, chat_id: &str, message_id: &str) -> bool {
|
||||||
|
self.latest_user_message_id(chat_id)
|
||||||
|
.map(|current_id| current_id == message_id)
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
/// 清除所有历史
|
/// 清除所有历史
|
||||||
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||||||
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||||||
@ -296,10 +314,20 @@ impl Session {
|
|||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
sender_id: Option<&str>,
|
sender_id: Option<&str>,
|
||||||
message_id: Option<&str>,
|
message_id: Option<&str>,
|
||||||
|
) -> Result<AgentLoop, AgentError> {
|
||||||
|
self.create_agent_with_provider_config(chat_id, sender_id, message_id, self.provider_config.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_agent_with_provider_config(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
sender_id: Option<&str>,
|
||||||
|
message_id: Option<&str>,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
) -> Result<AgentLoop, AgentError> {
|
) -> Result<AgentLoop, AgentError> {
|
||||||
let session_id = self.persistent_session_id(chat_id);
|
let session_id = self.persistent_session_id(chat_id);
|
||||||
AgentLoop::with_tools_and_skills(
|
AgentLoop::with_tools_and_skills(
|
||||||
self.provider_config.clone(),
|
provider_config,
|
||||||
self.tools.clone(),
|
self.tools.clone(),
|
||||||
self.skills.clone(),
|
self.skills.clone(),
|
||||||
)
|
)
|
||||||
@ -368,6 +396,7 @@ fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
|
|||||||
pub struct SessionManager {
|
pub struct SessionManager {
|
||||||
inner: Arc<Mutex<SessionManagerInner>>,
|
inner: Arc<Mutex<SessionManagerInner>>,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
|
provider_configs: Arc<HashMap<String, LLMProviderConfig>>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
@ -381,7 +410,7 @@ struct SessionManagerInner {
|
|||||||
session_ttl: Duration,
|
session_ttl: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolRegistry {
|
fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>, known_agents: HashSet<String>) -> ToolRegistry {
|
||||||
let mut registry = ToolRegistry::new();
|
let mut registry = ToolRegistry::new();
|
||||||
registry.register(CalculatorTool::new());
|
registry.register(CalculatorTool::new());
|
||||||
registry.register(FileReadTool::new());
|
registry.register(FileReadTool::new());
|
||||||
@ -389,7 +418,7 @@ fn default_tools(skills: Arc<SkillRuntime>, store: Arc<SessionStore>) -> ToolReg
|
|||||||
registry.register(FileEditTool::new());
|
registry.register(FileEditTool::new());
|
||||||
registry.register(MemorySearchTool::new(store.clone()));
|
registry.register(MemorySearchTool::new(store.clone()));
|
||||||
registry.register(MemoryManageTool::new(store.clone()));
|
registry.register(MemoryManageTool::new(store.clone()));
|
||||||
registry.register(SchedulerManageTool::new(store));
|
registry.register(SchedulerManageTool::new(store, known_agents));
|
||||||
registry.register(SkillListTool::new(skills.clone()));
|
registry.register(SkillListTool::new(skills.clone()));
|
||||||
registry.register(SkillManageTool::new(skills));
|
registry.register(SkillManageTool::new(skills));
|
||||||
registry.register(BashTool::new());
|
registry.register(BashTool::new());
|
||||||
@ -435,12 +464,14 @@ impl SessionManager {
|
|||||||
agent_prompt_reinject_every: u64,
|
agent_prompt_reinject_every: u64,
|
||||||
show_tool_results: bool,
|
show_tool_results: bool,
|
||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
let store = Arc::new(
|
let store = Arc::new(
|
||||||
SessionStore::new()
|
SessionStore::new()
|
||||||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||||||
);
|
);
|
||||||
|
let known_agents = provider_configs.keys().cloned().collect::<HashSet<_>>();
|
||||||
|
|
||||||
if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) {
|
if let Err(err) = store.append_skill_event(None, "discovered", None, &skills.discovery_event_payload()) {
|
||||||
tracing::warn!(error = %err, "Failed to record skill discovery event");
|
tracing::warn!(error = %err, "Failed to record skill discovery event");
|
||||||
@ -453,7 +484,8 @@ impl SessionManager {
|
|||||||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||||||
})),
|
})),
|
||||||
provider_config,
|
provider_config,
|
||||||
tools: Arc::new(default_tools(skills.clone(), store.clone())),
|
provider_configs: Arc::new(provider_configs),
|
||||||
|
tools: Arc::new(default_tools(skills.clone(), store.clone(), known_agents)),
|
||||||
skills,
|
skills,
|
||||||
store,
|
store,
|
||||||
agent_prompt_reinject_every,
|
agent_prompt_reinject_every,
|
||||||
@ -477,6 +509,10 @@ impl SessionManager {
|
|||||||
self.skills.clone()
|
self.skills.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn provider_config_for_agent(&self, agent_name: Option<&str>) -> Result<LLMProviderConfig, AgentError> {
|
||||||
|
select_provider_config(&self.provider_config, &self.provider_configs, agent_name)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
|
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
|
||||||
self.store
|
self.store
|
||||||
.create_cli_session(title)
|
.create_cli_session(title)
|
||||||
@ -640,7 +676,7 @@ impl SessionManager {
|
|||||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||||
|
|
||||||
// 处理消息
|
// 处理消息
|
||||||
let response = {
|
let (history, compressor, provider_config, agent, user_message_id) = {
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
session_guard.ensure_persistent_session(chat_id)?;
|
session_guard.ensure_persistent_session(chat_id)?;
|
||||||
@ -669,11 +705,8 @@ impl SessionManager {
|
|||||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||||
|
|
||||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||||
|
let compressor = session_guard.compressor().clone();
|
||||||
// 压缩历史(如果需要)
|
let provider_config = session_guard.provider_config().clone();
|
||||||
let history = session_guard.compressor
|
|
||||||
.compress_if_needed(history, &session_guard.provider_config)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
session_guard.record_skill_offer(chat_id)?;
|
session_guard.record_skill_offer(chat_id)?;
|
||||||
|
|
||||||
@ -682,28 +715,48 @@ impl SessionManager {
|
|||||||
if let Some(handler) = live_emitter.clone() {
|
if let Some(handler) = live_emitter.clone() {
|
||||||
agent = agent.with_emitted_message_handler(handler);
|
agent = agent.with_emitted_message_handler(handler);
|
||||||
}
|
}
|
||||||
let result = agent.process(history).await?;
|
|
||||||
|
|
||||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
(history, compressor, provider_config, agent, user_message_id)
|
||||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
};
|
||||||
|
|
||||||
result
|
let history = compressor
|
||||||
.emitted_messages
|
.compress_if_needed(history, &provider_config)
|
||||||
.iter()
|
.await?;
|
||||||
.filter(|message| {
|
let result = agent.process(history).await?;
|
||||||
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
|
|
||||||
&& should_display_message_to_user(self.show_tool_results, message)
|
let response = {
|
||||||
})
|
let mut session_guard = session.lock().await;
|
||||||
.flat_map(|message| {
|
|
||||||
OutboundMessage::from_chat_message(
|
if !session_guard.is_latest_user_message(chat_id, &user_message_id) {
|
||||||
channel_name,
|
tracing::warn!(
|
||||||
chat_id,
|
channel = %channel_name,
|
||||||
None,
|
chat_id = %chat_id,
|
||||||
&HashMap::new(),
|
user_message_id = %user_message_id,
|
||||||
message,
|
"Skipping stale agent result because a newer user message is already present"
|
||||||
)
|
);
|
||||||
})
|
Vec::new()
|
||||||
.collect::<Vec<_>>()
|
} else {
|
||||||
|
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||||
|
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||||
|
|
||||||
|
result
|
||||||
|
.emitted_messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| {
|
||||||
|
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
|
||||||
|
&& should_display_message_to_user(self.show_tool_results, message)
|
||||||
|
})
|
||||||
|
.flat_map(|message| {
|
||||||
|
OutboundMessage::from_chat_message(
|
||||||
|
channel_name,
|
||||||
|
chat_id,
|
||||||
|
None,
|
||||||
|
&HashMap::new(),
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -736,8 +789,9 @@ impl SessionManager {
|
|||||||
.sender_id
|
.sender_id
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "scheduler".to_string());
|
.unwrap_or_else(|| "scheduler".to_string());
|
||||||
|
let provider_config = self.provider_config_for_agent(options.agent.as_deref())?;
|
||||||
|
|
||||||
let response = {
|
let (history, compressor, agent, user_message_id) = {
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
|
||||||
session_guard.ensure_persistent_session(chat_id)?;
|
session_guard.ensure_persistent_session(chat_id)?;
|
||||||
@ -758,31 +812,54 @@ impl SessionManager {
|
|||||||
session_guard.append_persisted_message(chat_id, user_message)?;
|
session_guard.append_persisted_message(chat_id, user_message)?;
|
||||||
|
|
||||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||||
let history = session_guard
|
let compressor = session_guard.compressor().clone();
|
||||||
.compressor
|
|
||||||
.compress_if_needed(history, &session_guard.provider_config)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
session_guard.record_skill_offer(chat_id)?;
|
session_guard.record_skill_offer(chat_id)?;
|
||||||
|
|
||||||
let agent = session_guard.create_agent(chat_id, Some(&sender_id), Some(&user_message_id))?;
|
let agent = session_guard.create_agent_with_provider_config(
|
||||||
let result = agent.process(history).await?;
|
chat_id,
|
||||||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
Some(&sender_id),
|
||||||
|
Some(&user_message_id),
|
||||||
|
provider_config.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
result
|
(history, compressor, agent, user_message_id)
|
||||||
.emitted_messages
|
};
|
||||||
.iter()
|
|
||||||
.filter(|message| should_display_message_to_user(self.show_tool_results, message))
|
let history = compressor
|
||||||
.flat_map(|message| {
|
.compress_if_needed(history, &provider_config)
|
||||||
OutboundMessage::from_chat_message(
|
.await?;
|
||||||
channel_name,
|
let result = agent.process(history).await?;
|
||||||
chat_id,
|
|
||||||
None,
|
let response = {
|
||||||
&options.metadata,
|
let mut session_guard = session.lock().await;
|
||||||
message,
|
|
||||||
)
|
if !session_guard.is_latest_user_message(chat_id, &user_message_id) {
|
||||||
})
|
tracing::warn!(
|
||||||
.collect::<Vec<_>>()
|
channel = %channel_name,
|
||||||
|
chat_id = %chat_id,
|
||||||
|
user_message_id = %user_message_id,
|
||||||
|
"Skipping stale scheduled agent result because a newer user message is already present"
|
||||||
|
);
|
||||||
|
Vec::new()
|
||||||
|
} else {
|
||||||
|
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||||
|
|
||||||
|
result
|
||||||
|
.emitted_messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| should_display_message_to_user(self.show_tool_results, message))
|
||||||
|
.flat_map(|message| {
|
||||||
|
OutboundMessage::from_chat_message(
|
||||||
|
channel_name,
|
||||||
|
chat_id,
|
||||||
|
None,
|
||||||
|
&options.metadata,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
@ -810,11 +887,29 @@ fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn select_provider_config(
|
||||||
|
default_provider_config: &LLMProviderConfig,
|
||||||
|
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||||
|
agent_name: Option<&str>,
|
||||||
|
) -> Result<LLMProviderConfig, AgentError> {
|
||||||
|
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||||
|
None | Some("default") => Ok(default_provider_config.clone()),
|
||||||
|
Some(agent_name) => provider_configs
|
||||||
|
.get(agent_name)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use axum::http::StatusCode;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use axum::{Json, Router, routing::post};
|
||||||
use crate::bus::MessageBus;
|
use crate::bus::MessageBus;
|
||||||
|
use serde_json::{Value, json};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
fn test_provider_config() -> LLMProviderConfig {
|
fn test_provider_config() -> LLMProviderConfig {
|
||||||
@ -824,6 +919,7 @@ mod tests {
|
|||||||
base_url: "http://localhost".to_string(),
|
base_url: "http://localhost".to_string(),
|
||||||
api_key: "test-key".to_string(),
|
api_key: "test-key".to_string(),
|
||||||
extra_headers: HashMap::new(),
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
model_id: "test-model".to_string(),
|
model_id: "test-model".to_string(),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(32),
|
max_tokens: Some(32),
|
||||||
@ -833,6 +929,235 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||||
|
LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
token_limit: 4096,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_provider_config_uses_named_agent_override() {
|
||||||
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
|
let provider_configs = HashMap::from([(
|
||||||
|
"planner".to_string(),
|
||||||
|
test_provider_config_named("planner-provider", "planner-model"),
|
||||||
|
)]);
|
||||||
|
|
||||||
|
let selected = select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||||
|
assert_eq!(selected.name, "planner-provider");
|
||||||
|
assert_eq!(selected.model_id, "planner-model");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_latest_user_message_guard_tracks_current_turn() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
|
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||||
|
let mut session = Session::new(
|
||||||
|
"feishu".to_string(),
|
||||||
|
test_provider_config(),
|
||||||
|
user_tx,
|
||||||
|
tools,
|
||||||
|
skills,
|
||||||
|
store,
|
||||||
|
100,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
session.ensure_persistent_session("chat-1").unwrap();
|
||||||
|
session.ensure_chat_loaded("chat-1").unwrap();
|
||||||
|
|
||||||
|
let first = session.create_user_message("first", Vec::new());
|
||||||
|
let first_id = first.id.clone();
|
||||||
|
session.append_persisted_message("chat-1", first).unwrap();
|
||||||
|
assert!(session.is_latest_user_message("chat-1", &first_id));
|
||||||
|
|
||||||
|
let second = session.create_user_message("second", Vec::new());
|
||||||
|
let second_id = second.id.clone();
|
||||||
|
session.append_persisted_message("chat-1", second).unwrap();
|
||||||
|
|
||||||
|
assert!(!session.is_latest_user_message("chat-1", &first_id));
|
||||||
|
assert!(session.is_latest_user_message("chat-1", &second_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_provider_config_falls_back_to_default() {
|
||||||
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
|
let provider_configs = HashMap::new();
|
||||||
|
|
||||||
|
let selected = select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||||
|
assert_eq!(selected.name, "default-provider");
|
||||||
|
assert_eq!(selected.model_id, "default-model");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_mock_openai_server() -> String {
|
||||||
|
async fn handle(Json(body): Json<Value>) -> Json<Value> {
|
||||||
|
let model = body
|
||||||
|
.get("model")
|
||||||
|
.and_then(|value| value.as_str())
|
||||||
|
.unwrap_or("unknown-model");
|
||||||
|
|
||||||
|
Json(json!({
|
||||||
|
"id": "mock-response",
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": format!("reply from {}", model),
|
||||||
|
"tool_calls": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 1,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 2
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = Router::new().route("/chat/completions", post(handle));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let address = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
format!("http://{}", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_mock_openai_504_server() -> String {
|
||||||
|
async fn handle() -> (StatusCode, &'static str) {
|
||||||
|
(StatusCode::GATEWAY_TIMEOUT, "stream timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = Router::new().route("/chat/completions", post(handle));
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let address = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
format!("http://{}", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_message_returns_recoverable_reply_on_llm_504() {
|
||||||
|
let base_url = start_mock_openai_504_server().await;
|
||||||
|
let provider_config = LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "timeout-provider".to_string(),
|
||||||
|
base_url: base_url.clone(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
model_id: "timeout-model".to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
token_limit: 4096,
|
||||||
|
llm_timeout_secs: 30,
|
||||||
|
};
|
||||||
|
|
||||||
|
let session_manager = SessionManager::new(
|
||||||
|
4,
|
||||||
|
100,
|
||||||
|
false,
|
||||||
|
provider_config.clone(),
|
||||||
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
|
Arc::new(SkillRuntime::default()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let outbound = session_manager
|
||||||
|
.handle_message("feishu", "user-1", "chat-1", "hello", Vec::new(), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(outbound.len(), 1);
|
||||||
|
assert!(outbound[0].content.contains("模型服务暂时不可用或响应超时"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_scheduled_agent_task_uses_task_specific_agent_provider() {
|
||||||
|
let base_url = start_mock_openai_server().await;
|
||||||
|
let default_provider = LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "default-provider".to_string(),
|
||||||
|
base_url: base_url.clone(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
model_id: "default-model".to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
token_limit: 4096,
|
||||||
|
llm_timeout_secs: 30,
|
||||||
|
};
|
||||||
|
let planner_provider = LLMProviderConfig {
|
||||||
|
model_id: "planner-model".to_string(),
|
||||||
|
name: "planner-provider".to_string(),
|
||||||
|
..default_provider.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let session_manager = SessionManager::new(
|
||||||
|
4,
|
||||||
|
100,
|
||||||
|
false,
|
||||||
|
default_provider.clone(),
|
||||||
|
HashMap::from([
|
||||||
|
("default".to_string(), default_provider),
|
||||||
|
("planner".to_string(), planner_provider),
|
||||||
|
]),
|
||||||
|
Arc::new(SkillRuntime::default()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let planner_outbound = session_manager
|
||||||
|
.run_scheduled_agent_task(
|
||||||
|
"feishu",
|
||||||
|
"chat-planner",
|
||||||
|
"请规划今天工作",
|
||||||
|
ScheduledAgentTaskOptions {
|
||||||
|
agent: Some("planner".to_string()),
|
||||||
|
fresh_session: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(planner_outbound.len(), 1);
|
||||||
|
assert!(planner_outbound[0].content.contains("planner-model"));
|
||||||
|
|
||||||
|
let default_outbound = session_manager
|
||||||
|
.run_scheduled_agent_task(
|
||||||
|
"feishu",
|
||||||
|
"chat-default",
|
||||||
|
"请规划今天工作",
|
||||||
|
ScheduledAgentTaskOptions {
|
||||||
|
agent: Some("default".to_string()),
|
||||||
|
fresh_session: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(default_outbound.len(), 1);
|
||||||
|
assert!(default_outbound[0].content.contains("default-model"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
fn test_should_display_message_to_user_hides_completed_tool_results_by_default() {
|
||||||
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
let completed = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
@ -881,7 +1206,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -929,7 +1254,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -956,7 +1281,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -1001,7 +1326,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
@ -1035,7 +1360,7 @@ mod tests {
|
|||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
let skills = Arc::new(SkillRuntime::default());
|
||||||
let tools = Arc::new(default_tools(skills.clone(), store.clone()));
|
let tools = Arc::new(default_tools(skills.clone(), store.clone(), HashSet::new()));
|
||||||
let mut session = Session::new(
|
let mut session = Session::new(
|
||||||
"feishu".to_string(),
|
"feishu".to_string(),
|
||||||
test_provider_config(),
|
test_provider_config(),
|
||||||
|
|||||||
@ -2,6 +2,7 @@ use async_trait::async_trait;
|
|||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||||
@ -58,6 +59,7 @@ pub struct AnthropicProvider {
|
|||||||
api_key: String,
|
api_key: String,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
extra_headers: HashMap<String, String>,
|
extra_headers: HashMap<String, String>,
|
||||||
|
llm_timeout_secs: u64,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
@ -70,17 +72,24 @@ impl AnthropicProvider {
|
|||||||
api_key: String,
|
api_key: String,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
extra_headers: HashMap<String, String>,
|
extra_headers: HashMap<String, String>,
|
||||||
|
llm_timeout_secs: u64,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(llm_timeout_secs))
|
||||||
|
.build()
|
||||||
|
.unwrap_or_else(|_| Client::new());
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
client: Client::new(),
|
client,
|
||||||
name,
|
name,
|
||||||
api_key,
|
api_key,
|
||||||
base_url,
|
base_url,
|
||||||
extra_headers,
|
extra_headers,
|
||||||
|
llm_timeout_secs,
|
||||||
model_id,
|
model_id,
|
||||||
temperature,
|
temperature,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
@ -190,8 +199,21 @@ impl LLMProvider for AnthropicProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await?;
|
let resp = req_builder.json(&body).send().await?;
|
||||||
|
let status = resp.status();
|
||||||
|
let text = resp.text().await?;
|
||||||
|
|
||||||
let anthropic_resp: AnthropicResponse = resp.json().await?;
|
if !status.is_success() {
|
||||||
|
return Err(format!("API error {}: {}", status, text).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let resp_preview: String = text.chars().take(100).collect();
|
||||||
|
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "Anthropic response (first 100 chars shown)");
|
||||||
|
}
|
||||||
|
|
||||||
|
let anthropic_resp: AnthropicResponse = serde_json::from_str(&text)
|
||||||
|
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
||||||
|
|
||||||
let mut content = String::new();
|
let mut content = String::new();
|
||||||
let mut tool_calls = Vec::new();
|
let mut tool_calls = Vec::new();
|
||||||
|
|||||||
@ -15,6 +15,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>
|
|||||||
config.api_key,
|
config.api_key,
|
||||||
config.base_url,
|
config.base_url,
|
||||||
config.extra_headers,
|
config.extra_headers,
|
||||||
|
config.llm_timeout_secs,
|
||||||
config.model_id,
|
config.model_id,
|
||||||
config.temperature,
|
config.temperature,
|
||||||
config.max_tokens,
|
config.max_tokens,
|
||||||
@ -25,6 +26,7 @@ pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>
|
|||||||
config.api_key,
|
config.api_key,
|
||||||
config.base_url,
|
config.base_url,
|
||||||
config.extra_headers,
|
config.extra_headers,
|
||||||
|
config.llm_timeout_secs,
|
||||||
config.model_id,
|
config.model_id,
|
||||||
config.temperature,
|
config.temperature,
|
||||||
config.max_tokens,
|
config.max_tokens,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use reqwest::Client;
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
@ -28,6 +29,7 @@ pub struct OpenAIProvider {
|
|||||||
api_key: String,
|
api_key: String,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
extra_headers: HashMap<String, String>,
|
extra_headers: HashMap<String, String>,
|
||||||
|
llm_timeout_secs: u64,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
@ -40,17 +42,24 @@ impl OpenAIProvider {
|
|||||||
api_key: String,
|
api_key: String,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
extra_headers: HashMap<String, String>,
|
extra_headers: HashMap<String, String>,
|
||||||
|
llm_timeout_secs: u64,
|
||||||
model_id: String,
|
model_id: String,
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
model_extra: HashMap<String, serde_json::Value>,
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(llm_timeout_secs))
|
||||||
|
.build()
|
||||||
|
.unwrap_or_else(|_| Client::new());
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
client: Client::new(),
|
client,
|
||||||
name,
|
name,
|
||||||
api_key,
|
api_key,
|
||||||
base_url,
|
base_url,
|
||||||
extra_headers,
|
extra_headers,
|
||||||
|
llm_timeout_secs,
|
||||||
model_id,
|
model_id,
|
||||||
temperature,
|
temperature,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
@ -209,7 +218,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
let resp_preview: String = text.chars().take(100).collect();
|
let resp_preview: String = text.chars().take(100).collect();
|
||||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
|
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), timeout_secs = self.llm_timeout_secs, "LLM response (first 100 chars shown)");
|
||||||
}
|
}
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
@ -275,6 +284,7 @@ mod tests {
|
|||||||
"key".to_string(),
|
"key".to_string(),
|
||||||
"https://example.com/v1".to_string(),
|
"https://example.com/v1".to_string(),
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
|
120,
|
||||||
"gpt-test".to_string(),
|
"gpt-test".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|||||||
@ -88,6 +88,20 @@ impl Scheduler {
|
|||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if record.next_fire_at.is_none() && job.next_fire_at.is_some() {
|
||||||
|
self.store.update_scheduler_job_runtime(
|
||||||
|
&job.id,
|
||||||
|
job.state.clone(),
|
||||||
|
job.last_status.clone(),
|
||||||
|
job.last_error.as_deref(),
|
||||||
|
job.run_count,
|
||||||
|
job.last_fired_at,
|
||||||
|
job.next_fire_at,
|
||||||
|
job.paused_at,
|
||||||
|
job.completed_at,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
if !job.is_due(now) {
|
if !job.is_due(now) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -517,6 +531,11 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
|
|||||||
.get("system_prompt")
|
.get("system_prompt")
|
||||||
.and_then(|value| value.as_str())
|
.and_then(|value| value.as_str())
|
||||||
.map(ToString::to_string);
|
.map(ToString::to_string);
|
||||||
|
let agent = job
|
||||||
|
.payload
|
||||||
|
.get("agent")
|
||||||
|
.and_then(|value| value.as_str())
|
||||||
|
.map(ToString::to_string);
|
||||||
let metadata = parse_metadata_map(job.payload.get("metadata"))?;
|
let metadata = parse_metadata_map(job.payload.get("metadata"))?;
|
||||||
|
|
||||||
Ok(ScheduledAgentTaskOptions {
|
Ok(ScheduledAgentTaskOptions {
|
||||||
@ -524,6 +543,7 @@ fn parse_scheduled_agent_task_options(job: &RuntimeJob) -> anyhow::Result<Schedu
|
|||||||
fresh_session,
|
fresh_session,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
metadata,
|
metadata,
|
||||||
|
agent,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -616,6 +636,7 @@ mod agent_task_tests {
|
|||||||
},
|
},
|
||||||
payload: serde_json::json!({
|
payload: serde_json::json!({
|
||||||
"prompt": "请总结今天待办",
|
"prompt": "请总结今天待办",
|
||||||
|
"agent": "planner",
|
||||||
"sender_id": "scheduler-bot",
|
"sender_id": "scheduler-bot",
|
||||||
"fresh_session": true,
|
"fresh_session": true,
|
||||||
"system_prompt": "你是日报助手",
|
"system_prompt": "你是日报助手",
|
||||||
@ -640,6 +661,7 @@ mod agent_task_tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let options = parse_scheduled_agent_task_options(&job).unwrap();
|
let options = parse_scheduled_agent_task_options(&job).unwrap();
|
||||||
|
assert_eq!(options.agent.as_deref(), Some("planner"));
|
||||||
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
assert_eq!(options.sender_id.as_deref(), Some("scheduler-bot"));
|
||||||
assert!(options.fresh_session);
|
assert!(options.fresh_session);
|
||||||
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
assert_eq!(options.system_prompt.as_deref(), Some("你是日报助手"));
|
||||||
@ -660,6 +682,12 @@ impl TryFrom<serde_json::Value> for SchedulerJobTarget {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use crate::bus::MessageBus;
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use crate::gateway::session::SessionManager;
|
||||||
|
use crate::skills::SkillRuntime;
|
||||||
|
use crate::storage::{SchedulerJobUpsert, SessionStore};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn runtime_job_skip_policy_advances_from_now() {
|
fn runtime_job_skip_policy_advances_from_now() {
|
||||||
@ -739,4 +767,81 @@ mod tests {
|
|||||||
});
|
});
|
||||||
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
|
assert_eq!(job.next_fire_at, Some(1_700_000_010_000));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_tick_persists_initial_next_fire_at_for_db_created_jobs() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
store
|
||||||
|
.upsert_scheduler_job(&SchedulerJobUpsert {
|
||||||
|
id: "massage_reminder".to_string(),
|
||||||
|
kind: "outbound_message".to_string(),
|
||||||
|
schedule: serde_json::json!({
|
||||||
|
"type": "interval",
|
||||||
|
"seconds": 60
|
||||||
|
}),
|
||||||
|
interval_secs: 60,
|
||||||
|
startup_delay_secs: 0,
|
||||||
|
target: serde_json::json!({
|
||||||
|
"channel": "feishu",
|
||||||
|
"chat_id": "oc_demo"
|
||||||
|
}),
|
||||||
|
payload: serde_json::json!({
|
||||||
|
"content": "ping"
|
||||||
|
}),
|
||||||
|
enabled: true,
|
||||||
|
state: SchedulerJobState::Scheduled,
|
||||||
|
last_status: None,
|
||||||
|
last_error: None,
|
||||||
|
run_count: 0,
|
||||||
|
max_runs: Some(1),
|
||||||
|
last_fired_at: None,
|
||||||
|
next_fire_at: None,
|
||||||
|
paused_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let provider_config = LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "default".to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 30,
|
||||||
|
model_id: "test-model".to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: None,
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
token_limit: 4096,
|
||||||
|
max_tool_iterations: 4,
|
||||||
|
};
|
||||||
|
let session_manager = SessionManager::new(
|
||||||
|
4,
|
||||||
|
100,
|
||||||
|
false,
|
||||||
|
provider_config.clone(),
|
||||||
|
HashMap::from([("default".to_string(), provider_config)]),
|
||||||
|
Arc::new(SkillRuntime::default()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let scheduler = Scheduler::new(
|
||||||
|
MessageBus::new(8),
|
||||||
|
SchedulerConfig {
|
||||||
|
enabled: true,
|
||||||
|
tick_resolution_ms: 1000,
|
||||||
|
worker_queue_capacity: 64,
|
||||||
|
misfire_policy: SchedulerMisfirePolicy::Skip,
|
||||||
|
jobs: Vec::new(),
|
||||||
|
},
|
||||||
|
store.clone(),
|
||||||
|
session_manager,
|
||||||
|
);
|
||||||
|
|
||||||
|
scheduler.process_tick().await.unwrap();
|
||||||
|
|
||||||
|
let saved = store.get_scheduler_job("massage_reminder").unwrap().unwrap();
|
||||||
|
assert!(saved.next_fire_at.is_some());
|
||||||
|
assert_eq!(saved.run_count, 0);
|
||||||
|
assert_eq!(saved.state, SchedulerJobState::Scheduled);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@ -11,11 +12,15 @@ use crate::tools::traits::{Tool, ToolResult};
|
|||||||
|
|
||||||
pub struct SchedulerManageTool {
|
pub struct SchedulerManageTool {
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
|
known_agents: Arc<HashSet<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SchedulerManageTool {
|
impl SchedulerManageTool {
|
||||||
pub fn new(store: Arc<SessionStore>) -> Self {
|
pub fn new(store: Arc<SessionStore>, known_agents: HashSet<String>) -> Self {
|
||||||
Self { store }
|
Self {
|
||||||
|
store,
|
||||||
|
known_agents: Arc::new(known_agents),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +63,7 @@ impl Tool for SchedulerManageTool {
|
|||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "Job payload. agent_task supports prompt, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
|
"description": "Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
|
||||||
},
|
},
|
||||||
"max_runs": {
|
"max_runs": {
|
||||||
"type": ["integer", "null"]
|
"type": ["integer", "null"]
|
||||||
@ -91,7 +96,7 @@ impl Tool for SchedulerManageTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"put" => {
|
"put" => {
|
||||||
let input = build_upsert(&args)?;
|
let input = build_upsert(&args, &self.known_agents)?;
|
||||||
let record = self.store.upsert_scheduler_job(&input)?;
|
let record = self.store.upsert_scheduler_job(&input)?;
|
||||||
record_to_json(&record)
|
record_to_json(&record)
|
||||||
}
|
}
|
||||||
@ -140,7 +145,7 @@ impl Tool for SchedulerManageTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert> {
|
fn build_upsert(args: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<SchedulerJobUpsert> {
|
||||||
let id = require_str(args, "id")?.to_string();
|
let id = require_str(args, "id")?.to_string();
|
||||||
let kind = require_str(args, "kind")?.to_string();
|
let kind = require_str(args, "kind")?.to_string();
|
||||||
let schedule_value = args
|
let schedule_value = args
|
||||||
@ -158,14 +163,24 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert>
|
|||||||
_ => (0, 0),
|
_ => (0, 0),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let payload = args.get("payload").cloned().unwrap_or_else(|| json!({}));
|
||||||
|
let target = args.get("target").cloned().unwrap_or_else(|| json!({}));
|
||||||
|
if kind == "agent_task" {
|
||||||
|
validate_agent_task_payload(&payload, known_agents)?;
|
||||||
|
validate_target_fields(&target, &["channel", "chat_id"], "agent_task")?;
|
||||||
|
} else if kind == "outbound_message" {
|
||||||
|
validate_outbound_message_payload(&payload)?;
|
||||||
|
validate_target_fields(&target, &["channel", "chat_id"], "outbound_message")?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(SchedulerJobUpsert {
|
Ok(SchedulerJobUpsert {
|
||||||
id,
|
id,
|
||||||
kind,
|
kind,
|
||||||
schedule: serde_json::to_value(schedule)?,
|
schedule: serde_json::to_value(schedule)?,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
startup_delay_secs,
|
startup_delay_secs,
|
||||||
target: args.get("target").cloned().unwrap_or_else(|| json!({})),
|
target,
|
||||||
payload: args.get("payload").cloned().unwrap_or_else(|| json!({})),
|
payload,
|
||||||
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
|
enabled: args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true),
|
||||||
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
|
state: if args.get("enabled").and_then(|value| value.as_bool()).unwrap_or(true) {
|
||||||
SchedulerJobState::Scheduled
|
SchedulerJobState::Scheduled
|
||||||
@ -183,6 +198,57 @@ fn build_upsert(args: &serde_json::Value) -> anyhow::Result<SchedulerJobUpsert>
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashSet<String>) -> anyhow::Result<()> {
|
||||||
|
let Some(prompt) = payload.get("prompt").and_then(|value| value.as_str()) else {
|
||||||
|
anyhow::bail!("agent_task payload.prompt is required and must be a string")
|
||||||
|
};
|
||||||
|
if prompt.trim().is_empty() {
|
||||||
|
anyhow::bail!("agent_task payload.prompt cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(agent_name) = payload.get("agent").and_then(|value| value.as_str()) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let normalized = agent_name.trim();
|
||||||
|
if normalized.is_empty() || normalized == "default" || known_agents.contains(normalized) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Unknown agent '{}' for agent_task payload.agent", normalized)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> {
|
||||||
|
let Some(content) = payload.get("content").and_then(|value| value.as_str()) else {
|
||||||
|
anyhow::bail!("outbound_message payload.content is required and must be a string")
|
||||||
|
};
|
||||||
|
if content.trim().is_empty() {
|
||||||
|
anyhow::bail!("outbound_message payload.content cannot be empty")
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_target_fields(
|
||||||
|
target: &serde_json::Value,
|
||||||
|
required_fields: &[&str],
|
||||||
|
kind: &str,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let object = target
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("{} target must be an object", kind))?;
|
||||||
|
|
||||||
|
for field in required_fields {
|
||||||
|
let Some(value) = object.get(*field).and_then(|value| value.as_str()) else {
|
||||||
|
anyhow::bail!("{} target.{} is required and must be a string", kind, field)
|
||||||
|
};
|
||||||
|
if value.trim().is_empty() {
|
||||||
|
anyhow::bail!("{} target.{} cannot be empty", kind, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn record_to_json(record: &SchedulerJobRecord) -> serde_json::Value {
|
fn record_to_json(record: &SchedulerJobRecord) -> serde_json::Value {
|
||||||
json!({
|
json!({
|
||||||
"id": record.id,
|
"id": record.id,
|
||||||
@ -255,7 +321,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_scheduler_manage_put_and_get() {
|
async fn test_scheduler_manage_put_and_get() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = SchedulerManageTool::new(store);
|
let tool = SchedulerManageTool::new(store, HashSet::new());
|
||||||
|
|
||||||
let put_result = tool
|
let put_result = tool
|
||||||
.execute(json!({
|
.execute(json!({
|
||||||
@ -293,7 +359,7 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_scheduler_manage_put_agent_task() {
|
async fn test_scheduler_manage_put_agent_task() {
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
let tool = SchedulerManageTool::new(store);
|
let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()]));
|
||||||
|
|
||||||
let put_result = tool
|
let put_result = tool
|
||||||
.execute(json!({
|
.execute(json!({
|
||||||
@ -309,7 +375,8 @@ mod tests {
|
|||||||
"chat_id": "oc_demo"
|
"chat_id": "oc_demo"
|
||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
"prompt": "请总结今天待办"
|
"prompt": "请总结今天待办",
|
||||||
|
"agent": "planner"
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
.await
|
.await
|
||||||
@ -318,4 +385,59 @@ mod tests {
|
|||||||
assert!(put_result.success);
|
assert!(put_result.success);
|
||||||
assert!(put_result.output.contains("agent_task"));
|
assert!(put_result.output.contains("agent_task"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_scheduler_manage_rejects_outbound_message_without_target() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = SchedulerManageTool::new(store, HashSet::new());
|
||||||
|
|
||||||
|
let put_result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"action": "put",
|
||||||
|
"id": "massage_reminder",
|
||||||
|
"kind": "outbound_message",
|
||||||
|
"schedule": {
|
||||||
|
"type": "interval",
|
||||||
|
"seconds": 60
|
||||||
|
},
|
||||||
|
"payload": {
|
||||||
|
"content": "⏰ 时间到了!该去按摩了!💆"
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(put_result.is_err());
|
||||||
|
let error = put_result.err().unwrap().to_string();
|
||||||
|
assert!(error.contains("outbound_message target.channel is required"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_scheduler_manage_rejects_unknown_agent_task_agent() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()]));
|
||||||
|
|
||||||
|
let put_result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"action": "put",
|
||||||
|
"id": "agent.daily_summary",
|
||||||
|
"kind": "agent_task",
|
||||||
|
"schedule": {
|
||||||
|
"type": "cron",
|
||||||
|
"expression": "0 9 * * *"
|
||||||
|
},
|
||||||
|
"target": {
|
||||||
|
"channel": "feishu",
|
||||||
|
"chat_id": "oc_demo"
|
||||||
|
},
|
||||||
|
"payload": {
|
||||||
|
"prompt": "请总结今天待办",
|
||||||
|
"agent": "missing-agent"
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(put_result.is_err());
|
||||||
|
let error = put_result.err().unwrap().to_string();
|
||||||
|
assert!(error.contains("Unknown agent 'missing-agent'"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@ -19,6 +19,7 @@ fn load_config() -> Option<LLMProviderConfig> {
|
|||||||
base_url: openai_base_url,
|
base_url: openai_base_url,
|
||||||
api_key: openai_api_key,
|
api_key: openai_api_key,
|
||||||
extra_headers: HashMap::new(),
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
model_id: openai_model,
|
model_id: openai_model,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
|
|||||||
@ -19,6 +19,7 @@ fn load_openai_config() -> Option<LLMProviderConfig> {
|
|||||||
base_url: openai_base_url,
|
base_url: openai_base_url,
|
||||||
api_key: openai_api_key,
|
api_key: openai_api_key,
|
||||||
extra_headers: HashMap::new(),
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
model_id: openai_model,
|
model_id: openai_model,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
max_tokens: Some(100),
|
max_tokens: Some(100),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user