feat: 添加执行服务和提示管理功能,重构相关模块以优化代码结构

This commit is contained in:
ooodc 2026-04-28 10:51:54 +08:00
parent 73dab09bfe
commit 33f5a4cbd2
4 changed files with 410 additions and 289 deletions

218
src/gateway/execution.rs Normal file
View File

@ -0,0 +1,218 @@
use std::collections::HashMap;
use crate::agent::{AgentError, AgentProcessResult};
use crate::bus::message::ToolMessageState;
use crate::bus::{ChatMessage, OutboundMessage};
use crate::config::LLMProviderConfig;
use super::session::Session;
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
match system_prompt
.map(str::trim)
.filter(|value| !value.is_empty())
{
Some(system_prompt) => format!(
"{}\n\n任务专属要求:{}",
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
),
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
}
}
pub(crate) fn select_provider_config(
default_provider_config: &LLMProviderConfig,
provider_configs: &HashMap<String, LLMProviderConfig>,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(default_provider_config.clone()),
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
pub(crate) struct AgentExecutionService {
show_tool_results: bool,
}
pub(crate) struct FinalizeAgentResultRequest<'a> {
pub(crate) channel_name: &'a str,
pub(crate) chat_id: &'a str,
pub(crate) user_message: &'a ChatMessage,
pub(crate) result: AgentProcessResult,
pub(crate) metadata: &'a HashMap<String, String>,
pub(crate) suppress_live_tool_calls: bool,
pub(crate) execution_kind: &'a str,
}
pub(crate) struct FinalizedAgentResult {
pub(crate) outbound_messages: Vec<OutboundMessage>,
pub(crate) should_schedule_compaction: bool,
}
impl AgentExecutionService {
pub(crate) fn new(show_tool_results: bool) -> Self {
Self { show_tool_results }
}
pub(crate) fn finalize_result(
&self,
session: &mut Session,
request: FinalizeAgentResultRequest<'_>,
) -> Result<FinalizedAgentResult, AgentError> {
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
session.stale_result_diagnostics(request.chat_id);
tracing::warn!(
channel = %request.channel_name,
chat_id = %request.chat_id,
user_message_id = %request.user_message.id,
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
execution_kind = %request.execution_kind,
"Skipping stale agent result because a newer user message is already present"
);
return Ok(FinalizedAgentResult {
outbound_messages: Vec::new(),
should_schedule_compaction: false,
});
}
session
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
let outbound_messages = request
.result
.emitted_messages
.iter()
.filter(|message| {
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
&& should_display_message_to_user(self.show_tool_results, message)
})
.flat_map(|message| {
OutboundMessage::from_chat_message(
request.channel_name,
request.chat_id,
None,
request.metadata,
message,
)
})
.collect();
Ok(FinalizedAgentResult {
outbound_messages,
should_schedule_compaction: true,
})
}
}
pub(crate) fn should_display_message_to_user(
show_tool_results: bool,
message: &ChatMessage,
) -> bool {
if message.role != "tool" {
return true;
}
show_tool_results
|| matches!(
message
.tool_state
.as_ref()
.unwrap_or(&ToolMessageState::Completed),
ToolMessageState::PendingUserAction
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::ChatMessage;
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_provider_config_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]);
let selected =
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_select_provider_config_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::new();
let selected =
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
#[test]
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
assert!(prompt.contains("任务专属要求:只汇报异常"));
}
#[test]
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
let prompt = compose_scheduled_task_system_prompt(Some(" "));
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
assert!(!prompt.contains("任务专属要求"));
}
#[test]
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
let message = ChatMessage::tool_with_state(
"call-1",
"approval",
"需要用户确认",
ToolMessageState::PendingUserAction,
);
assert!(should_display_message_to_user(false, &message));
}
#[test]
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
let message = ChatMessage::tool("call-1", "calculator", "2");
assert!(!should_display_message_to_user(false, &message));
assert!(should_display_message_to_user(true, &message));
}
}

View File

@ -1,5 +1,7 @@
pub mod execution;
pub mod http;
pub mod processor;
pub mod prompt;
pub mod session;
pub mod ws;

149
src/gateway/prompt.rs Normal file
View File

@ -0,0 +1,149 @@
use std::fs;
use std::path::{Path, PathBuf};
use crate::agent::AgentError;
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
if !path.exists() {
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
}
let content = fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
let trimmed = content.trim();
if trimmed.is_empty() {
return Ok(None);
}
Ok(Some(trimmed.to_string()))
}
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
let path = agent_prompt_path()?;
let existing = if path.exists() {
fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
} else {
DEFAULT_AGENT_PROMPT.to_string()
};
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
write_agent_prompt(&path, &updated)
}
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
let managed_block = render_managed_agent_memory_block(markdown_body);
if let (Some(start), Some(end)) = (
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
) {
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
let mut updated = String::new();
updated.push_str(existing[..start].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[end..].trim_start());
return updated.trim().to_string() + "\n";
}
if let Some(reply_rules_index) = existing.find("## 回复规则") {
let mut updated = String::new();
updated.push_str(existing[..reply_rules_index].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[reply_rules_index..].trim_start());
return updated.trim().to_string() + "\n";
}
let mut updated = existing.trim_end().to_string();
if !updated.is_empty() {
updated.push_str("\n\n");
}
updated.push_str(&managed_block);
updated.push('\n');
updated
}
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
format!(
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
markdown_body.trim()
)
}
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
let temp_path = path.with_extension("md.tmp");
fs::write(&temp_path, content)
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
fs::rename(&temp_path, path)
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
Ok(())
}
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
let home = dirs::home_dir()
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
let original =
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
let updated = upsert_managed_agent_memory_block(
original,
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
);
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
let reply_rules_pos = updated.find("## 回复规则").unwrap();
assert!(managed_pos < reply_rules_pos);
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
assert!(updated.contains("用户在做AI产品"));
assert!(updated.contains("偏好简洁表达"));
}
#[test]
fn test_upsert_managed_agent_memory_block_replaces_existing_block() {
let original = format!(
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n"
);
let updated = upsert_managed_agent_memory_block(&original, "new");
assert!(updated.contains("new"));
assert!(!updated.contains("old"));
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1);
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1);
}
#[test]
fn test_upsert_managed_agent_memory_block_trims_summary_body() {
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n");
assert!(updated.contains("\n\nsummary\n"));
assert!(!updated.contains("\n\nsummary\n\n\n"));
}
}

View File

@ -16,20 +16,19 @@ use crate::tools::{
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
use super::execution::{
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
select_provider_config, should_display_message_to_user,
};
use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary};
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器否则不要调用任何定时任务管理工具像“每小时”、“每天”、“cron”、“定时”等词只应视为任务背景不应再解释为新的建任务请求。";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryMaintenanceCategory {
@ -147,63 +146,6 @@ fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|| normalized.contains("timeout")
}
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
format!(
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
markdown_body.trim()
)
}
fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
let managed_block = render_managed_agent_memory_block(markdown_body);
if let (Some(start), Some(end)) = (
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
) {
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
let mut updated = String::new();
updated.push_str(existing[..start].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[end..].trim_start());
return updated.trim().to_string() + "\n";
}
if let Some(reply_rules_index) = existing.find("## 回复规则") {
let mut updated = String::new();
updated.push_str(existing[..reply_rules_index].trim_end());
updated.push_str("\n\n");
updated.push_str(&managed_block);
updated.push_str("\n\n");
updated.push_str(existing[reply_rules_index..].trim_start());
return updated.trim().to_string() + "\n";
}
let mut updated = existing.trim_end().to_string();
if !updated.is_empty() {
updated.push_str("\n\n");
}
updated.push_str(&managed_block);
updated.push('\n');
updated
}
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
let temp_path = path.with_extension("md.tmp");
fs::write(&temp_path, content)
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
fs::rename(&temp_path, path)
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
Ok(())
}
fn strip_json_code_fence(content: &str) -> &str {
let trimmed = content.trim();
if let Some(rest) = trimmed.strip_prefix("```json") {
@ -683,7 +625,7 @@ impl Session {
.unwrap_or(false)
}
fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
self.latest_user_message(chat_id)
.map(|current| {
current.id == message.id
@ -694,7 +636,7 @@ impl Session {
.unwrap_or(false)
}
fn stale_result_diagnostics(
pub(crate) fn stale_result_diagnostics(
&self,
chat_id: &str,
) -> (Option<&str>, Option<String>, bool, usize) {
@ -843,33 +785,6 @@ impl Session {
}
}
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
let path = agent_prompt_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
}
if !path.exists() {
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
}
let content = fs::read_to_string(&path)
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
let trimmed = content.trim();
if trimmed.is_empty() {
return Ok(None);
}
Ok(Some(trimmed.to_string()))
}
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
let home = dirs::home_dir()
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
}
/// SessionManager 管理所有 Session按 channel_name 路由
#[derive(Clone)]
pub struct SessionManager {
@ -1119,16 +1034,7 @@ impl SessionManager {
&self,
markdown_body: &str,
) -> Result<(), AgentError> {
let path = agent_prompt_path()?;
let existing = if path.exists() {
fs::read_to_string(&path).map_err(|err| {
AgentError::Other(format!("read agent prompt file error: {}", err))
})?
} else {
DEFAULT_AGENT_PROMPT.to_string()
};
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
write_agent_prompt(&path, &updated)
upsert_managed_agent_memory_summary(markdown_body)
}
#[cfg_attr(not(test), allow(dead_code))]
@ -1524,51 +1430,24 @@ impl SessionManager {
let result = agent.process(history).await?;
let mut should_schedule_compaction = false;
let response = {
let finalized_result = {
let mut session_guard = session.lock().await;
if !session_guard.matches_current_user_turn(chat_id, &user_message) {
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
session_guard.stale_result_diagnostics(chat_id);
tracing::warn!(
channel = %channel_name,
chat_id = %chat_id,
user_message_id = %user_message.id,
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
"Skipping stale agent result because a newer user message is already present"
);
Vec::new()
} else {
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
session_guard
.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
should_schedule_compaction = true;
result
.emitted_messages
.iter()
.filter(|message| {
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
&& should_display_message_to_user(self.show_tool_results, message)
})
.flat_map(|message| {
OutboundMessage::from_chat_message(
channel_name,
chat_id,
None,
&HashMap::new(),
message,
)
})
.collect::<Vec<_>>()
}
let metadata = HashMap::new();
AgentExecutionService::new(self.show_tool_results).finalize_result(
&mut session_guard,
FinalizeAgentResultRequest {
channel_name,
chat_id,
user_message: &user_message,
result,
metadata: &metadata,
suppress_live_tool_calls: live_emitter.is_some(),
execution_kind: "message",
},
)?
};
if should_schedule_compaction {
if finalized_result.should_schedule_compaction {
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
{
@ -1580,11 +1459,11 @@ impl SessionManager {
tracing::debug!(
channel = %channel_name,
chat_id = %chat_id,
outbound_count = response.len(),
outbound_count = finalized_result.outbound_messages.len(),
"Agent response sequence received"
);
Ok(response)
Ok(finalized_result.outbound_messages)
}
pub async fn run_scheduled_agent_task(
@ -1649,49 +1528,23 @@ impl SessionManager {
let result = agent.process(history).await?;
let mut should_schedule_compaction = false;
let response = {
let finalized_result = {
let mut session_guard = session.lock().await;
if !session_guard.matches_current_user_turn(chat_id, &user_message) {
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
session_guard.stale_result_diagnostics(chat_id);
tracing::warn!(
channel = %channel_name,
chat_id = %chat_id,
user_message_id = %user_message.id,
latest_user_id,
latest_user_preview,
compression_in_flight,
history_len,
"Skipping stale scheduled agent result because a newer user message is already present"
);
Vec::new()
} else {
session_guard
.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
should_schedule_compaction = true;
result
.emitted_messages
.iter()
.filter(|message| {
should_display_message_to_user(self.show_tool_results, message)
})
.flat_map(|message| {
OutboundMessage::from_chat_message(
channel_name,
chat_id,
None,
&options.metadata,
message,
)
})
.collect::<Vec<_>>()
}
AgentExecutionService::new(self.show_tool_results).finalize_result(
&mut session_guard,
FinalizeAgentResultRequest {
channel_name,
chat_id,
user_message: &user_message,
result,
metadata: &options.metadata,
suppress_live_tool_calls: false,
execution_kind: "scheduled_task",
},
)?
};
if should_schedule_compaction {
if finalized_result.should_schedule_compaction {
if let Err(error) =
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
{
@ -1699,7 +1552,7 @@ impl SessionManager {
}
}
Ok(response)
Ok(finalized_result.outbound_messages)
}
/// 清除指定 session 的所有历史
@ -1712,47 +1565,6 @@ impl SessionManager {
}
}
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
if message.role != "tool" {
return true;
}
show_tool_results
|| matches!(
message
.tool_state
.as_ref()
.unwrap_or(&crate::bus::message::ToolMessageState::Completed),
crate::bus::message::ToolMessageState::PendingUserAction
)
}
fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
match system_prompt
.map(str::trim)
.filter(|value| !value.is_empty())
{
Some(system_prompt) => format!(
"{}\n\n任务专属要求:{}",
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
),
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
}
}
fn select_provider_config(
default_provider_config: &LLMProviderConfig,
provider_configs: &HashMap<String, LLMProviderConfig>,
agent_name: Option<&str>,
) -> Result<LLMProviderConfig, AgentError> {
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
None | Some("default") => Ok(default_provider_config.clone()),
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
}),
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -1787,38 +1599,6 @@ mod tests {
}
}
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: name.to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
model_id: model_id.to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_select_provider_config_uses_named_agent_override() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::from([(
"planner".to_string(),
test_provider_config_named("planner-provider", "planner-model"),
)]);
let selected =
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
assert_eq!(selected.name, "planner-provider");
assert_eq!(selected.model_id, "planner-model");
}
#[test]
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
@ -1944,17 +1724,6 @@ mod tests {
assert!(session.matches_current_user_turn("chat-1", &second));
}
#[test]
fn test_select_provider_config_falls_back_to_default() {
let default_provider = test_provider_config_named("default-provider", "default-model");
let provider_configs = HashMap::new();
let selected =
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
assert_eq!(selected.name, "default-provider");
assert_eq!(selected.model_id, "default-model");
}
async fn start_mock_openai_server() -> String {
async fn handle(Json(body): Json<Value>) -> Json<Value> {
let model = body
@ -2983,21 +2752,4 @@ mod tests {
assert_eq!(plan.preferences[0].content, "偏好简洁表达");
assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
}
#[test]
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
let original =
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
let updated = upsert_managed_agent_memory_block(
original,
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
);
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
let reply_rules_pos = updated.find("## 回复规则").unwrap();
assert!(managed_pos < reply_rules_pos);
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
assert!(updated.contains("用户在做AI产品"));
assert!(updated.contains("偏好简洁表达"));
}
}