feat: 添加执行服务和提示管理功能,重构相关模块以优化代码结构
This commit is contained in:
parent
73dab09bfe
commit
33f5a4cbd2
218
src/gateway/execution.rs
Normal file
218
src/gateway/execution.rs
Normal file
@ -0,0 +1,218 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::agent::{AgentError, AgentProcessResult};
|
||||
use crate::bus::message::ToolMessageState;
|
||||
use crate::bus::{ChatMessage, OutboundMessage};
|
||||
use crate::config::LLMProviderConfig;
|
||||
|
||||
use super::session::Session;
|
||||
|
||||
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||
|
||||
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
|
||||
match system_prompt
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
Some(system_prompt) => format!(
|
||||
"{}\n\n任务专属要求:{}",
|
||||
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
|
||||
),
|
||||
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn select_provider_config(
|
||||
default_provider_config: &LLMProviderConfig,
|
||||
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||
agent_name: Option<&str>,
|
||||
) -> Result<LLMProviderConfig, AgentError> {
|
||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
None | Some("default") => Ok(default_provider_config.clone()),
|
||||
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
||||
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AgentExecutionService {
|
||||
show_tool_results: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizeAgentResultRequest<'a> {
|
||||
pub(crate) channel_name: &'a str,
|
||||
pub(crate) chat_id: &'a str,
|
||||
pub(crate) user_message: &'a ChatMessage,
|
||||
pub(crate) result: AgentProcessResult,
|
||||
pub(crate) metadata: &'a HashMap<String, String>,
|
||||
pub(crate) suppress_live_tool_calls: bool,
|
||||
pub(crate) execution_kind: &'a str,
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizedAgentResult {
|
||||
pub(crate) outbound_messages: Vec<OutboundMessage>,
|
||||
pub(crate) should_schedule_compaction: bool,
|
||||
}
|
||||
|
||||
impl AgentExecutionService {
|
||||
pub(crate) fn new(show_tool_results: bool) -> Self {
|
||||
Self { show_tool_results }
|
||||
}
|
||||
|
||||
pub(crate) fn finalize_result(
|
||||
&self,
|
||||
session: &mut Session,
|
||||
request: FinalizeAgentResultRequest<'_>,
|
||||
) -> Result<FinalizedAgentResult, AgentError> {
|
||||
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
|
||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||
session.stale_result_diagnostics(request.chat_id);
|
||||
tracing::warn!(
|
||||
channel = %request.channel_name,
|
||||
chat_id = %request.chat_id,
|
||||
user_message_id = %request.user_message.id,
|
||||
latest_user_id,
|
||||
latest_user_preview,
|
||||
compression_in_flight,
|
||||
history_len,
|
||||
execution_kind = %request.execution_kind,
|
||||
"Skipping stale agent result because a newer user message is already present"
|
||||
);
|
||||
|
||||
return Ok(FinalizedAgentResult {
|
||||
outbound_messages: Vec::new(),
|
||||
should_schedule_compaction: false,
|
||||
});
|
||||
}
|
||||
|
||||
session
|
||||
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
|
||||
|
||||
let outbound_messages = request
|
||||
.result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| {
|
||||
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
|
||||
&& should_display_message_to_user(self.show_tool_results, message)
|
||||
})
|
||||
.flat_map(|message| {
|
||||
OutboundMessage::from_chat_message(
|
||||
request.channel_name,
|
||||
request.chat_id,
|
||||
None,
|
||||
request.metadata,
|
||||
message,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(FinalizedAgentResult {
|
||||
outbound_messages,
|
||||
should_schedule_compaction: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_display_message_to_user(
|
||||
show_tool_results: bool,
|
||||
message: &ChatMessage,
|
||||
) -> bool {
|
||||
if message.role != "tool" {
|
||||
return true;
|
||||
}
|
||||
|
||||
show_tool_results
|
||||
|| matches!(
|
||||
message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&ToolMessageState::Completed),
|
||||
ToolMessageState::PendingUserAction
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: name.to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_uses_named_agent_override() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::from([(
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]);
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||
assert_eq!(selected.name, "planner-provider");
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::new();
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
assert_eq!(selected.model_id, "default-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
|
||||
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
|
||||
|
||||
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||
assert!(prompt.contains("任务专属要求:只汇报异常"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
|
||||
let prompt = compose_scheduled_task_system_prompt(Some(" "));
|
||||
|
||||
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||
assert!(!prompt.contains("任务专属要求"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
|
||||
let message = ChatMessage::tool_with_state(
|
||||
"call-1",
|
||||
"approval",
|
||||
"需要用户确认",
|
||||
ToolMessageState::PendingUserAction,
|
||||
);
|
||||
|
||||
assert!(should_display_message_to_user(false, &message));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
|
||||
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||
|
||||
assert!(!should_display_message_to_user(false, &message));
|
||||
assert!(should_display_message_to_user(true, &message));
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,7 @@
|
||||
pub mod execution;
|
||||
pub mod http;
|
||||
pub mod processor;
|
||||
pub mod prompt;
|
||||
pub mod session;
|
||||
pub mod ws;
|
||||
|
||||
|
||||
149
src/gateway/prompt.rs
Normal file
149
src/gateway/prompt.rs
Normal file
@ -0,0 +1,149 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::agent::AgentError;
|
||||
|
||||
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
|
||||
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
|
||||
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
|
||||
|
||||
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
if !path.exists() {
|
||||
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&path)
|
||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(trimmed.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
let existing = if path.exists() {
|
||||
fs::read_to_string(&path)
|
||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
|
||||
} else {
|
||||
DEFAULT_AGENT_PROMPT.to_string()
|
||||
};
|
||||
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
|
||||
write_agent_prompt(&path, &updated)
|
||||
}
|
||||
|
||||
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
|
||||
let managed_block = render_managed_agent_memory_block(markdown_body);
|
||||
|
||||
if let (Some(start), Some(end)) = (
|
||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
|
||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
|
||||
) {
|
||||
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
|
||||
let mut updated = String::new();
|
||||
updated.push_str(existing[..start].trim_end());
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(&managed_block);
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(existing[end..].trim_start());
|
||||
return updated.trim().to_string() + "\n";
|
||||
}
|
||||
|
||||
if let Some(reply_rules_index) = existing.find("## 回复规则") {
|
||||
let mut updated = String::new();
|
||||
updated.push_str(existing[..reply_rules_index].trim_end());
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(&managed_block);
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(existing[reply_rules_index..].trim_start());
|
||||
return updated.trim().to_string() + "\n";
|
||||
}
|
||||
|
||||
let mut updated = existing.trim_end().to_string();
|
||||
if !updated.is_empty() {
|
||||
updated.push_str("\n\n");
|
||||
}
|
||||
updated.push_str(&managed_block);
|
||||
updated.push('\n');
|
||||
updated
|
||||
}
|
||||
|
||||
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
|
||||
format!(
|
||||
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
|
||||
markdown_body.trim()
|
||||
)
|
||||
}
|
||||
|
||||
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
let temp_path = path.with_extension("md.tmp");
|
||||
fs::write(&temp_path, content)
|
||||
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
|
||||
fs::rename(&temp_path, path)
|
||||
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
|
||||
let home = dirs::home_dir()
|
||||
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
||||
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
|
||||
let original =
|
||||
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
|
||||
let updated = upsert_managed_agent_memory_block(
|
||||
original,
|
||||
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
|
||||
);
|
||||
|
||||
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
|
||||
let reply_rules_pos = updated.find("## 回复规则").unwrap();
|
||||
assert!(managed_pos < reply_rules_pos);
|
||||
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
|
||||
assert!(updated.contains("用户在做AI产品"));
|
||||
assert!(updated.contains("偏好简洁表达"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_replaces_existing_block() {
|
||||
let original = format!(
|
||||
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n"
|
||||
);
|
||||
|
||||
let updated = upsert_managed_agent_memory_block(&original, "new");
|
||||
|
||||
assert!(updated.contains("new"));
|
||||
assert!(!updated.contains("old"));
|
||||
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1);
|
||||
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_trims_summary_body() {
|
||||
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n");
|
||||
|
||||
assert!(updated.contains("\n\nsummary\n"));
|
||||
assert!(!updated.contains("\n\nsummary\n\n\n"));
|
||||
}
|
||||
}
|
||||
@ -16,20 +16,19 @@ use crate::tools::{
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
|
||||
const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||
const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
|
||||
const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
|
||||
const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
|
||||
use super::execution::{
|
||||
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
|
||||
select_provider_config, should_display_message_to_user,
|
||||
};
|
||||
use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary};
|
||||
|
||||
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum MemoryMaintenanceCategory {
|
||||
@ -147,63 +146,6 @@ fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|
||||
|| normalized.contains("timeout")
|
||||
}
|
||||
|
||||
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
|
||||
format!(
|
||||
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
|
||||
markdown_body.trim()
|
||||
)
|
||||
}
|
||||
|
||||
fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
|
||||
let managed_block = render_managed_agent_memory_block(markdown_body);
|
||||
|
||||
if let (Some(start), Some(end)) = (
|
||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
|
||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
|
||||
) {
|
||||
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
|
||||
let mut updated = String::new();
|
||||
updated.push_str(existing[..start].trim_end());
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(&managed_block);
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(existing[end..].trim_start());
|
||||
return updated.trim().to_string() + "\n";
|
||||
}
|
||||
|
||||
if let Some(reply_rules_index) = existing.find("## 回复规则") {
|
||||
let mut updated = String::new();
|
||||
updated.push_str(existing[..reply_rules_index].trim_end());
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(&managed_block);
|
||||
updated.push_str("\n\n");
|
||||
updated.push_str(existing[reply_rules_index..].trim_start());
|
||||
return updated.trim().to_string() + "\n";
|
||||
}
|
||||
|
||||
let mut updated = existing.trim_end().to_string();
|
||||
if !updated.is_empty() {
|
||||
updated.push_str("\n\n");
|
||||
}
|
||||
updated.push_str(&managed_block);
|
||||
updated.push('\n');
|
||||
updated
|
||||
}
|
||||
|
||||
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
let temp_path = path.with_extension("md.tmp");
|
||||
fs::write(&temp_path, content)
|
||||
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
|
||||
fs::rename(&temp_path, path)
|
||||
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn strip_json_code_fence(content: &str) -> &str {
|
||||
let trimmed = content.trim();
|
||||
if let Some(rest) = trimmed.strip_prefix("```json") {
|
||||
@ -683,7 +625,7 @@ impl Session {
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||||
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||||
self.latest_user_message(chat_id)
|
||||
.map(|current| {
|
||||
current.id == message.id
|
||||
@ -694,7 +636,7 @@ impl Session {
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn stale_result_diagnostics(
|
||||
pub(crate) fn stale_result_diagnostics(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
) -> (Option<&str>, Option<String>, bool, usize) {
|
||||
@ -843,33 +785,6 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||
}
|
||||
|
||||
if !path.exists() {
|
||||
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&path)
|
||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(trimmed.to_string()))
|
||||
}
|
||||
|
||||
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
|
||||
let home = dirs::home_dir()
|
||||
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
||||
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
||||
}
|
||||
|
||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
@ -1119,16 +1034,7 @@ impl SessionManager {
|
||||
&self,
|
||||
markdown_body: &str,
|
||||
) -> Result<(), AgentError> {
|
||||
let path = agent_prompt_path()?;
|
||||
let existing = if path.exists() {
|
||||
fs::read_to_string(&path).map_err(|err| {
|
||||
AgentError::Other(format!("read agent prompt file error: {}", err))
|
||||
})?
|
||||
} else {
|
||||
DEFAULT_AGENT_PROMPT.to_string()
|
||||
};
|
||||
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
|
||||
write_agent_prompt(&path, &updated)
|
||||
upsert_managed_agent_memory_summary(markdown_body)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
@ -1524,51 +1430,24 @@ impl SessionManager {
|
||||
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
let mut should_schedule_compaction = false;
|
||||
let response = {
|
||||
let finalized_result = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
if !session_guard.matches_current_user_turn(chat_id, &user_message) {
|
||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||
session_guard.stale_result_diagnostics(chat_id);
|
||||
tracing::warn!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
user_message_id = %user_message.id,
|
||||
latest_user_id,
|
||||
latest_user_preview,
|
||||
compression_in_flight,
|
||||
history_len,
|
||||
"Skipping stale agent result because a newer user message is already present"
|
||||
);
|
||||
Vec::new()
|
||||
} else {
|
||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||
session_guard
|
||||
.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
should_schedule_compaction = true;
|
||||
|
||||
result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| {
|
||||
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
|
||||
&& should_display_message_to_user(self.show_tool_results, message)
|
||||
})
|
||||
.flat_map(|message| {
|
||||
OutboundMessage::from_chat_message(
|
||||
let metadata = HashMap::new();
|
||||
AgentExecutionService::new(self.show_tool_results).finalize_result(
|
||||
&mut session_guard,
|
||||
FinalizeAgentResultRequest {
|
||||
channel_name,
|
||||
chat_id,
|
||||
None,
|
||||
&HashMap::new(),
|
||||
message,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
user_message: &user_message,
|
||||
result,
|
||||
metadata: &metadata,
|
||||
suppress_live_tool_calls: live_emitter.is_some(),
|
||||
execution_kind: "message",
|
||||
},
|
||||
)?
|
||||
};
|
||||
|
||||
if should_schedule_compaction {
|
||||
if finalized_result.should_schedule_compaction {
|
||||
if let Err(error) =
|
||||
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
|
||||
{
|
||||
@ -1580,11 +1459,11 @@ impl SessionManager {
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
outbound_count = response.len(),
|
||||
outbound_count = finalized_result.outbound_messages.len(),
|
||||
"Agent response sequence received"
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
Ok(finalized_result.outbound_messages)
|
||||
}
|
||||
|
||||
pub async fn run_scheduled_agent_task(
|
||||
@ -1649,49 +1528,23 @@ impl SessionManager {
|
||||
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
let mut should_schedule_compaction = false;
|
||||
let response = {
|
||||
let finalized_result = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
if !session_guard.matches_current_user_turn(chat_id, &user_message) {
|
||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||
session_guard.stale_result_diagnostics(chat_id);
|
||||
tracing::warn!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
user_message_id = %user_message.id,
|
||||
latest_user_id,
|
||||
latest_user_preview,
|
||||
compression_in_flight,
|
||||
history_len,
|
||||
"Skipping stale scheduled agent result because a newer user message is already present"
|
||||
);
|
||||
Vec::new()
|
||||
} else {
|
||||
session_guard
|
||||
.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||
should_schedule_compaction = true;
|
||||
|
||||
result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| {
|
||||
should_display_message_to_user(self.show_tool_results, message)
|
||||
})
|
||||
.flat_map(|message| {
|
||||
OutboundMessage::from_chat_message(
|
||||
AgentExecutionService::new(self.show_tool_results).finalize_result(
|
||||
&mut session_guard,
|
||||
FinalizeAgentResultRequest {
|
||||
channel_name,
|
||||
chat_id,
|
||||
None,
|
||||
&options.metadata,
|
||||
message,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
user_message: &user_message,
|
||||
result,
|
||||
metadata: &options.metadata,
|
||||
suppress_live_tool_calls: false,
|
||||
execution_kind: "scheduled_task",
|
||||
},
|
||||
)?
|
||||
};
|
||||
|
||||
if should_schedule_compaction {
|
||||
if finalized_result.should_schedule_compaction {
|
||||
if let Err(error) =
|
||||
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
|
||||
{
|
||||
@ -1699,7 +1552,7 @@ impl SessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
Ok(finalized_result.outbound_messages)
|
||||
}
|
||||
|
||||
/// 清除指定 session 的所有历史
|
||||
@ -1712,47 +1565,6 @@ impl SessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
||||
if message.role != "tool" {
|
||||
return true;
|
||||
}
|
||||
|
||||
show_tool_results
|
||||
|| matches!(
|
||||
message
|
||||
.tool_state
|
||||
.as_ref()
|
||||
.unwrap_or(&crate::bus::message::ToolMessageState::Completed),
|
||||
crate::bus::message::ToolMessageState::PendingUserAction
|
||||
)
|
||||
}
|
||||
|
||||
fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
|
||||
match system_prompt
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
Some(system_prompt) => format!(
|
||||
"{}\n\n任务专属要求:{}",
|
||||
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
|
||||
),
|
||||
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn select_provider_config(
|
||||
default_provider_config: &LLMProviderConfig,
|
||||
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||
agent_name: Option<&str>,
|
||||
) -> Result<LLMProviderConfig, AgentError> {
|
||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||
None | Some("default") => Ok(default_provider_config.clone()),
|
||||
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
||||
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -1787,38 +1599,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: name.to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
model_id: model_id.to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_uses_named_agent_override() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::from([(
|
||||
"planner".to_string(),
|
||||
test_provider_config_named("planner-provider", "planner-model"),
|
||||
)]);
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||
assert_eq!(selected.name, "planner-provider");
|
||||
assert_eq!(selected.model_id, "planner-model");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
||||
@ -1944,17 +1724,6 @@ mod tests {
|
||||
assert!(session.matches_current_user_turn("chat-1", &second));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_provider_config_falls_back_to_default() {
|
||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||
let provider_configs = HashMap::new();
|
||||
|
||||
let selected =
|
||||
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||
assert_eq!(selected.name, "default-provider");
|
||||
assert_eq!(selected.model_id, "default-model");
|
||||
}
|
||||
|
||||
async fn start_mock_openai_server() -> String {
|
||||
async fn handle(Json(body): Json<Value>) -> Json<Value> {
|
||||
let model = body
|
||||
@ -2983,21 +2752,4 @@ mod tests {
|
||||
assert_eq!(plan.preferences[0].content, "偏好简洁表达");
|
||||
assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
|
||||
let original =
|
||||
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
|
||||
let updated = upsert_managed_agent_memory_block(
|
||||
original,
|
||||
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
|
||||
);
|
||||
|
||||
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
|
||||
let reply_rules_pos = updated.find("## 回复规则").unwrap();
|
||||
assert!(managed_pos < reply_rules_pos);
|
||||
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
|
||||
assert!(updated.contains("用户在做AI产品"));
|
||||
assert!(updated.contains("偏好简洁表达"));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user