feat: 添加执行服务和提示管理功能,重构相关模块以优化代码结构
This commit is contained in:
parent
73dab09bfe
commit
33f5a4cbd2
218
src/gateway/execution.rs
Normal file
218
src/gateway/execution.rs
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::agent::{AgentError, AgentProcessResult};
|
||||||
|
use crate::bus::message::ToolMessageState;
|
||||||
|
use crate::bus::{ChatMessage, OutboundMessage};
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
|
||||||
|
use super::session::Session;
|
||||||
|
|
||||||
|
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
||||||
|
|
||||||
|
pub(crate) fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
|
||||||
|
match system_prompt
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
{
|
||||||
|
Some(system_prompt) => format!(
|
||||||
|
"{}\n\n任务专属要求:{}",
|
||||||
|
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
|
||||||
|
),
|
||||||
|
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn select_provider_config(
|
||||||
|
default_provider_config: &LLMProviderConfig,
|
||||||
|
provider_configs: &HashMap<String, LLMProviderConfig>,
|
||||||
|
agent_name: Option<&str>,
|
||||||
|
) -> Result<LLMProviderConfig, AgentError> {
|
||||||
|
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
||||||
|
None | Some("default") => Ok(default_provider_config.clone()),
|
||||||
|
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
||||||
|
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct AgentExecutionService {
|
||||||
|
show_tool_results: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct FinalizeAgentResultRequest<'a> {
|
||||||
|
pub(crate) channel_name: &'a str,
|
||||||
|
pub(crate) chat_id: &'a str,
|
||||||
|
pub(crate) user_message: &'a ChatMessage,
|
||||||
|
pub(crate) result: AgentProcessResult,
|
||||||
|
pub(crate) metadata: &'a HashMap<String, String>,
|
||||||
|
pub(crate) suppress_live_tool_calls: bool,
|
||||||
|
pub(crate) execution_kind: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct FinalizedAgentResult {
|
||||||
|
pub(crate) outbound_messages: Vec<OutboundMessage>,
|
||||||
|
pub(crate) should_schedule_compaction: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentExecutionService {
|
||||||
|
pub(crate) fn new(show_tool_results: bool) -> Self {
|
||||||
|
Self { show_tool_results }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn finalize_result(
|
||||||
|
&self,
|
||||||
|
session: &mut Session,
|
||||||
|
request: FinalizeAgentResultRequest<'_>,
|
||||||
|
) -> Result<FinalizedAgentResult, AgentError> {
|
||||||
|
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
|
||||||
|
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||||
|
session.stale_result_diagnostics(request.chat_id);
|
||||||
|
tracing::warn!(
|
||||||
|
channel = %request.channel_name,
|
||||||
|
chat_id = %request.chat_id,
|
||||||
|
user_message_id = %request.user_message.id,
|
||||||
|
latest_user_id,
|
||||||
|
latest_user_preview,
|
||||||
|
compression_in_flight,
|
||||||
|
history_len,
|
||||||
|
execution_kind = %request.execution_kind,
|
||||||
|
"Skipping stale agent result because a newer user message is already present"
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok(FinalizedAgentResult {
|
||||||
|
outbound_messages: Vec::new(),
|
||||||
|
should_schedule_compaction: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
session
|
||||||
|
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
|
||||||
|
|
||||||
|
let outbound_messages = request
|
||||||
|
.result
|
||||||
|
.emitted_messages
|
||||||
|
.iter()
|
||||||
|
.filter(|message| {
|
||||||
|
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
|
||||||
|
&& should_display_message_to_user(self.show_tool_results, message)
|
||||||
|
})
|
||||||
|
.flat_map(|message| {
|
||||||
|
OutboundMessage::from_chat_message(
|
||||||
|
request.channel_name,
|
||||||
|
request.chat_id,
|
||||||
|
None,
|
||||||
|
request.metadata,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(FinalizedAgentResult {
|
||||||
|
outbound_messages,
|
||||||
|
should_schedule_compaction: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn should_display_message_to_user(
|
||||||
|
show_tool_results: bool,
|
||||||
|
message: &ChatMessage,
|
||||||
|
) -> bool {
|
||||||
|
if message.role != "tool" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
show_tool_results
|
||||||
|
|| matches!(
|
||||||
|
message
|
||||||
|
.tool_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&ToolMessageState::Completed),
|
||||||
|
ToolMessageState::PendingUserAction
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::bus::ChatMessage;
|
||||||
|
|
||||||
|
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
||||||
|
LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
|
model_id: model_id.to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
tool_result_max_chars: 20_000,
|
||||||
|
context_tool_result_trim_chars: 20_000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_provider_config_uses_named_agent_override() {
|
||||||
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
|
let provider_configs = HashMap::from([(
|
||||||
|
"planner".to_string(),
|
||||||
|
test_provider_config_named("planner-provider", "planner-model"),
|
||||||
|
)]);
|
||||||
|
|
||||||
|
let selected =
|
||||||
|
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
||||||
|
assert_eq!(selected.name, "planner-provider");
|
||||||
|
assert_eq!(selected.model_id, "planner-model");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_provider_config_falls_back_to_default() {
|
||||||
|
let default_provider = test_provider_config_named("default-provider", "default-model");
|
||||||
|
let provider_configs = HashMap::new();
|
||||||
|
|
||||||
|
let selected =
|
||||||
|
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
||||||
|
assert_eq!(selected.name, "default-provider");
|
||||||
|
assert_eq!(selected.model_id, "default-model");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_compose_scheduled_task_system_prompt_appends_task_specific_prompt() {
|
||||||
|
let prompt = compose_scheduled_task_system_prompt(Some(" 只汇报异常 "));
|
||||||
|
|
||||||
|
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||||
|
assert!(prompt.contains("任务专属要求:只汇报异常"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_compose_scheduled_task_system_prompt_ignores_blank_override() {
|
||||||
|
let prompt = compose_scheduled_task_system_prompt(Some(" "));
|
||||||
|
|
||||||
|
assert!(prompt.contains("当前输入来自一次已经触发的定时任务执行"));
|
||||||
|
assert!(!prompt.contains("任务专属要求"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_display_message_to_user_keeps_pending_tool_action_visible() {
|
||||||
|
let message = ChatMessage::tool_with_state(
|
||||||
|
"call-1",
|
||||||
|
"approval",
|
||||||
|
"需要用户确认",
|
||||||
|
ToolMessageState::PendingUserAction,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(should_display_message_to_user(false, &message));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_display_message_to_user_hides_completed_tool_when_disabled() {
|
||||||
|
let message = ChatMessage::tool("call-1", "calculator", "2");
|
||||||
|
|
||||||
|
assert!(!should_display_message_to_user(false, &message));
|
||||||
|
assert!(should_display_message_to_user(true, &message));
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,5 +1,7 @@
|
|||||||
|
pub mod execution;
|
||||||
pub mod http;
|
pub mod http;
|
||||||
pub mod processor;
|
pub mod processor;
|
||||||
|
pub mod prompt;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
|
|||||||
149
src/gateway/prompt.rs
Normal file
149
src/gateway/prompt.rs
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use crate::agent::AgentError;
|
||||||
|
|
||||||
|
pub(crate) const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
||||||
|
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
|
||||||
|
pub(crate) const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
|
||||||
|
pub(crate) const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
|
||||||
|
|
||||||
|
pub(crate) fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
||||||
|
let path = agent_prompt_path()?;
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)
|
||||||
|
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !path.exists() {
|
||||||
|
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = fs::read_to_string(&path)
|
||||||
|
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
||||||
|
let trimmed = content.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(trimmed.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn upsert_managed_agent_memory_summary(markdown_body: &str) -> Result<(), AgentError> {
|
||||||
|
let path = agent_prompt_path()?;
|
||||||
|
let existing = if path.exists() {
|
||||||
|
fs::read_to_string(&path)
|
||||||
|
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?
|
||||||
|
} else {
|
||||||
|
DEFAULT_AGENT_PROMPT.to_string()
|
||||||
|
};
|
||||||
|
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
|
||||||
|
write_agent_prompt(&path, &updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
|
||||||
|
let managed_block = render_managed_agent_memory_block(markdown_body);
|
||||||
|
|
||||||
|
if let (Some(start), Some(end)) = (
|
||||||
|
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
|
||||||
|
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
|
||||||
|
) {
|
||||||
|
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
|
||||||
|
let mut updated = String::new();
|
||||||
|
updated.push_str(existing[..start].trim_end());
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(&managed_block);
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(existing[end..].trim_start());
|
||||||
|
return updated.trim().to_string() + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(reply_rules_index) = existing.find("## 回复规则") {
|
||||||
|
let mut updated = String::new();
|
||||||
|
updated.push_str(existing[..reply_rules_index].trim_end());
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(&managed_block);
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
updated.push_str(existing[reply_rules_index..].trim_start());
|
||||||
|
return updated.trim().to_string() + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut updated = existing.trim_end().to_string();
|
||||||
|
if !updated.is_empty() {
|
||||||
|
updated.push_str("\n\n");
|
||||||
|
}
|
||||||
|
updated.push_str(&managed_block);
|
||||||
|
updated.push('\n');
|
||||||
|
updated
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
|
||||||
|
markdown_body.trim()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)
|
||||||
|
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let temp_path = path.with_extension("md.tmp");
|
||||||
|
fs::write(&temp_path, content)
|
||||||
|
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
|
||||||
|
fs::rename(&temp_path, path)
|
||||||
|
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn agent_prompt_path() -> Result<PathBuf, AgentError> {
|
||||||
|
let home = dirs::home_dir()
|
||||||
|
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
||||||
|
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
|
||||||
|
let original =
|
||||||
|
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
|
||||||
|
let updated = upsert_managed_agent_memory_block(
|
||||||
|
original,
|
||||||
|
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
|
||||||
|
);
|
||||||
|
|
||||||
|
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
|
||||||
|
let reply_rules_pos = updated.find("## 回复规则").unwrap();
|
||||||
|
assert!(managed_pos < reply_rules_pos);
|
||||||
|
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
|
||||||
|
assert!(updated.contains("用户在做AI产品"));
|
||||||
|
assert!(updated.contains("偏好简洁表达"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upsert_managed_agent_memory_block_replaces_existing_block() {
|
||||||
|
let original = format!(
|
||||||
|
"# PicoBot\n\n{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\nold\n{MANAGED_AGENT_MEMORY_BLOCK_END}\n\n## 回复规则\n- 简洁。\n"
|
||||||
|
);
|
||||||
|
|
||||||
|
let updated = upsert_managed_agent_memory_block(&original, "new");
|
||||||
|
|
||||||
|
assert!(updated.contains("new"));
|
||||||
|
assert!(!updated.contains("old"));
|
||||||
|
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_START).count(), 1);
|
||||||
|
assert_eq!(updated.matches(MANAGED_AGENT_MEMORY_BLOCK_END).count(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_upsert_managed_agent_memory_block_trims_summary_body() {
|
||||||
|
let updated = upsert_managed_agent_memory_block("# PicoBot\n", "\n\nsummary\n\n");
|
||||||
|
|
||||||
|
assert!(updated.contains("\n\nsummary\n"));
|
||||||
|
assert!(!updated.contains("\n\nsummary\n\n\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -16,20 +16,19 @@ use crate::tools::{
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::fs;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
const DEFAULT_AGENT_PROMPT: &str = include_str!("default_agent_prompt.md");
|
use super::execution::{
|
||||||
const MANAGED_AGENT_MEMORY_BLOCK_START: &str = "<!-- PICOBOT_MANAGED_MEMORY:START -->";
|
AgentExecutionService, FinalizeAgentResultRequest, compose_scheduled_task_system_prompt,
|
||||||
const MANAGED_AGENT_MEMORY_BLOCK_END: &str = "<!-- PICOBOT_MANAGED_MEMORY:END -->";
|
select_provider_config, should_display_message_to_user,
|
||||||
const MANAGED_AGENT_MEMORY_TITLE: &str = "## 用户记忆摘要";
|
};
|
||||||
|
use super::prompt::{load_agent_prompt, upsert_managed_agent_memory_summary};
|
||||||
|
|
||||||
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
const MEMORY_MAINTENANCE_SYSTEM_PROMPT: &str = include_str!("memory_maintenance_system_prompt.md");
|
||||||
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
const MEMORY_MAINTENANCE_RETRY_DELAYS_MS: &[u64] = &[1_000, 3_000];
|
||||||
const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。";
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
enum MemoryMaintenanceCategory {
|
enum MemoryMaintenanceCategory {
|
||||||
@ -147,63 +146,6 @@ fn is_recoverable_maintenance_llm_error(error: &str) -> bool {
|
|||||||
|| normalized.contains("timeout")
|
|| normalized.contains("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_managed_agent_memory_block(markdown_body: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"{MANAGED_AGENT_MEMORY_BLOCK_START}\n{MANAGED_AGENT_MEMORY_TITLE}\n\n{}\n{MANAGED_AGENT_MEMORY_BLOCK_END}",
|
|
||||||
markdown_body.trim()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn upsert_managed_agent_memory_block(existing: &str, markdown_body: &str) -> String {
|
|
||||||
let managed_block = render_managed_agent_memory_block(markdown_body);
|
|
||||||
|
|
||||||
if let (Some(start), Some(end)) = (
|
|
||||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_START),
|
|
||||||
existing.find(MANAGED_AGENT_MEMORY_BLOCK_END),
|
|
||||||
) {
|
|
||||||
let end = end + MANAGED_AGENT_MEMORY_BLOCK_END.len();
|
|
||||||
let mut updated = String::new();
|
|
||||||
updated.push_str(existing[..start].trim_end());
|
|
||||||
updated.push_str("\n\n");
|
|
||||||
updated.push_str(&managed_block);
|
|
||||||
updated.push_str("\n\n");
|
|
||||||
updated.push_str(existing[end..].trim_start());
|
|
||||||
return updated.trim().to_string() + "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(reply_rules_index) = existing.find("## 回复规则") {
|
|
||||||
let mut updated = String::new();
|
|
||||||
updated.push_str(existing[..reply_rules_index].trim_end());
|
|
||||||
updated.push_str("\n\n");
|
|
||||||
updated.push_str(&managed_block);
|
|
||||||
updated.push_str("\n\n");
|
|
||||||
updated.push_str(existing[reply_rules_index..].trim_start());
|
|
||||||
return updated.trim().to_string() + "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut updated = existing.trim_end().to_string();
|
|
||||||
if !updated.is_empty() {
|
|
||||||
updated.push_str("\n\n");
|
|
||||||
}
|
|
||||||
updated.push_str(&managed_block);
|
|
||||||
updated.push('\n');
|
|
||||||
updated
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write_agent_prompt(path: &Path, content: &str) -> Result<(), AgentError> {
|
|
||||||
if let Some(parent) = path.parent() {
|
|
||||||
fs::create_dir_all(parent)
|
|
||||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let temp_path = path.with_extension("md.tmp");
|
|
||||||
fs::write(&temp_path, content)
|
|
||||||
.map_err(|err| AgentError::Other(format!("write agent prompt temp file error: {}", err)))?;
|
|
||||||
fs::rename(&temp_path, path)
|
|
||||||
.map_err(|err| AgentError::Other(format!("replace agent prompt file error: {}", err)))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn strip_json_code_fence(content: &str) -> &str {
|
fn strip_json_code_fence(content: &str) -> &str {
|
||||||
let trimmed = content.trim();
|
let trimmed = content.trim();
|
||||||
if let Some(rest) = trimmed.strip_prefix("```json") {
|
if let Some(rest) = trimmed.strip_prefix("```json") {
|
||||||
@ -683,7 +625,7 @@ impl Session {
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
pub(crate) fn matches_current_user_turn(&self, chat_id: &str, message: &ChatMessage) -> bool {
|
||||||
self.latest_user_message(chat_id)
|
self.latest_user_message(chat_id)
|
||||||
.map(|current| {
|
.map(|current| {
|
||||||
current.id == message.id
|
current.id == message.id
|
||||||
@ -694,7 +636,7 @@ impl Session {
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn stale_result_diagnostics(
|
pub(crate) fn stale_result_diagnostics(
|
||||||
&self,
|
&self,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
) -> (Option<&str>, Option<String>, bool, usize) {
|
) -> (Option<&str>, Option<String>, bool, usize) {
|
||||||
@ -843,33 +785,6 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_agent_prompt() -> Result<Option<String>, AgentError> {
|
|
||||||
let path = agent_prompt_path()?;
|
|
||||||
if let Some(parent) = path.parent() {
|
|
||||||
fs::create_dir_all(parent)
|
|
||||||
.map_err(|err| AgentError::Other(format!("create agent prompt dir error: {}", err)))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if !path.exists() {
|
|
||||||
write_agent_prompt(&path, DEFAULT_AGENT_PROMPT)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let content = fs::read_to_string(&path)
|
|
||||||
.map_err(|err| AgentError::Other(format!("read agent prompt file error: {}", err)))?;
|
|
||||||
let trimmed = content.trim();
|
|
||||||
if trimmed.is_empty() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(trimmed.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn agent_prompt_path() -> Result<std::path::PathBuf, AgentError> {
|
|
||||||
let home = dirs::home_dir()
|
|
||||||
.ok_or_else(|| AgentError::Other("home directory not found".to_string()))?;
|
|
||||||
Ok(home.join(".picobot").join("agent").join("AGENT.md"))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SessionManager {
|
pub struct SessionManager {
|
||||||
@ -1119,16 +1034,7 @@ impl SessionManager {
|
|||||||
&self,
|
&self,
|
||||||
markdown_body: &str,
|
markdown_body: &str,
|
||||||
) -> Result<(), AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
let path = agent_prompt_path()?;
|
upsert_managed_agent_memory_summary(markdown_body)
|
||||||
let existing = if path.exists() {
|
|
||||||
fs::read_to_string(&path).map_err(|err| {
|
|
||||||
AgentError::Other(format!("read agent prompt file error: {}", err))
|
|
||||||
})?
|
|
||||||
} else {
|
|
||||||
DEFAULT_AGENT_PROMPT.to_string()
|
|
||||||
};
|
|
||||||
let updated = upsert_managed_agent_memory_block(&existing, markdown_body);
|
|
||||||
write_agent_prompt(&path, &updated)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(not(test), allow(dead_code))]
|
#[cfg_attr(not(test), allow(dead_code))]
|
||||||
@ -1524,51 +1430,24 @@ impl SessionManager {
|
|||||||
|
|
||||||
let result = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
let mut should_schedule_compaction = false;
|
let finalized_result = {
|
||||||
let response = {
|
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
let metadata = HashMap::new();
|
||||||
if !session_guard.matches_current_user_turn(chat_id, &user_message) {
|
AgentExecutionService::new(self.show_tool_results).finalize_result(
|
||||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
&mut session_guard,
|
||||||
session_guard.stale_result_diagnostics(chat_id);
|
FinalizeAgentResultRequest {
|
||||||
tracing::warn!(
|
|
||||||
channel = %channel_name,
|
|
||||||
chat_id = %chat_id,
|
|
||||||
user_message_id = %user_message.id,
|
|
||||||
latest_user_id,
|
|
||||||
latest_user_preview,
|
|
||||||
compression_in_flight,
|
|
||||||
history_len,
|
|
||||||
"Skipping stale agent result because a newer user message is already present"
|
|
||||||
);
|
|
||||||
Vec::new()
|
|
||||||
} else {
|
|
||||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
|
||||||
session_guard
|
|
||||||
.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
|
||||||
should_schedule_compaction = true;
|
|
||||||
|
|
||||||
result
|
|
||||||
.emitted_messages
|
|
||||||
.iter()
|
|
||||||
.filter(|message| {
|
|
||||||
(!message.is_assistant_tool_call_message() || live_emitter.is_none())
|
|
||||||
&& should_display_message_to_user(self.show_tool_results, message)
|
|
||||||
})
|
|
||||||
.flat_map(|message| {
|
|
||||||
OutboundMessage::from_chat_message(
|
|
||||||
channel_name,
|
channel_name,
|
||||||
chat_id,
|
chat_id,
|
||||||
None,
|
user_message: &user_message,
|
||||||
&HashMap::new(),
|
result,
|
||||||
message,
|
metadata: &metadata,
|
||||||
)
|
suppress_live_tool_calls: live_emitter.is_some(),
|
||||||
})
|
execution_kind: "message",
|
||||||
.collect::<Vec<_>>()
|
},
|
||||||
}
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
if should_schedule_compaction {
|
if finalized_result.should_schedule_compaction {
|
||||||
if let Err(error) =
|
if let Err(error) =
|
||||||
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
|
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
|
||||||
{
|
{
|
||||||
@ -1580,11 +1459,11 @@ impl SessionManager {
|
|||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
channel = %channel_name,
|
channel = %channel_name,
|
||||||
chat_id = %chat_id,
|
chat_id = %chat_id,
|
||||||
outbound_count = response.len(),
|
outbound_count = finalized_result.outbound_messages.len(),
|
||||||
"Agent response sequence received"
|
"Agent response sequence received"
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(response)
|
Ok(finalized_result.outbound_messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_scheduled_agent_task(
|
pub async fn run_scheduled_agent_task(
|
||||||
@ -1649,49 +1528,23 @@ impl SessionManager {
|
|||||||
|
|
||||||
let result = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
let mut should_schedule_compaction = false;
|
let finalized_result = {
|
||||||
let response = {
|
|
||||||
let mut session_guard = session.lock().await;
|
let mut session_guard = session.lock().await;
|
||||||
|
AgentExecutionService::new(self.show_tool_results).finalize_result(
|
||||||
if !session_guard.matches_current_user_turn(chat_id, &user_message) {
|
&mut session_guard,
|
||||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
FinalizeAgentResultRequest {
|
||||||
session_guard.stale_result_diagnostics(chat_id);
|
|
||||||
tracing::warn!(
|
|
||||||
channel = %channel_name,
|
|
||||||
chat_id = %chat_id,
|
|
||||||
user_message_id = %user_message.id,
|
|
||||||
latest_user_id,
|
|
||||||
latest_user_preview,
|
|
||||||
compression_in_flight,
|
|
||||||
history_len,
|
|
||||||
"Skipping stale scheduled agent result because a newer user message is already present"
|
|
||||||
);
|
|
||||||
Vec::new()
|
|
||||||
} else {
|
|
||||||
session_guard
|
|
||||||
.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
|
||||||
should_schedule_compaction = true;
|
|
||||||
|
|
||||||
result
|
|
||||||
.emitted_messages
|
|
||||||
.iter()
|
|
||||||
.filter(|message| {
|
|
||||||
should_display_message_to_user(self.show_tool_results, message)
|
|
||||||
})
|
|
||||||
.flat_map(|message| {
|
|
||||||
OutboundMessage::from_chat_message(
|
|
||||||
channel_name,
|
channel_name,
|
||||||
chat_id,
|
chat_id,
|
||||||
None,
|
user_message: &user_message,
|
||||||
&options.metadata,
|
result,
|
||||||
message,
|
metadata: &options.metadata,
|
||||||
)
|
suppress_live_tool_calls: false,
|
||||||
})
|
execution_kind: "scheduled_task",
|
||||||
.collect::<Vec<_>>()
|
},
|
||||||
}
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
if should_schedule_compaction {
|
if finalized_result.should_schedule_compaction {
|
||||||
if let Err(error) =
|
if let Err(error) =
|
||||||
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
|
schedule_background_history_compaction(session.clone(), chat_id.to_string()).await
|
||||||
{
|
{
|
||||||
@ -1699,7 +1552,7 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(response)
|
Ok(finalized_result.outbound_messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 清除指定 session 的所有历史
|
/// 清除指定 session 的所有历史
|
||||||
@ -1712,47 +1565,6 @@ impl SessionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_display_message_to_user(show_tool_results: bool, message: &ChatMessage) -> bool {
|
|
||||||
if message.role != "tool" {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
show_tool_results
|
|
||||||
|| matches!(
|
|
||||||
message
|
|
||||||
.tool_state
|
|
||||||
.as_ref()
|
|
||||||
.unwrap_or(&crate::bus::message::ToolMessageState::Completed),
|
|
||||||
crate::bus::message::ToolMessageState::PendingUserAction
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compose_scheduled_task_system_prompt(system_prompt: Option<&str>) -> String {
|
|
||||||
match system_prompt
|
|
||||||
.map(str::trim)
|
|
||||||
.filter(|value| !value.is_empty())
|
|
||||||
{
|
|
||||||
Some(system_prompt) => format!(
|
|
||||||
"{}\n\n任务专属要求:{}",
|
|
||||||
SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT, system_prompt
|
|
||||||
),
|
|
||||||
None => SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn select_provider_config(
|
|
||||||
default_provider_config: &LLMProviderConfig,
|
|
||||||
provider_configs: &HashMap<String, LLMProviderConfig>,
|
|
||||||
agent_name: Option<&str>,
|
|
||||||
) -> Result<LLMProviderConfig, AgentError> {
|
|
||||||
match agent_name.map(str::trim).filter(|value| !value.is_empty()) {
|
|
||||||
None | Some("default") => Ok(default_provider_config.clone()),
|
|
||||||
Some(agent_name) => provider_configs.get(agent_name).cloned().ok_or_else(|| {
|
|
||||||
AgentError::Other(format!("Scheduled agent '{}' not found", agent_name))
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -1787,38 +1599,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_provider_config_named(name: &str, model_id: &str) -> LLMProviderConfig {
|
|
||||||
LLMProviderConfig {
|
|
||||||
provider_type: "openai".to_string(),
|
|
||||||
name: name.to_string(),
|
|
||||||
base_url: "http://localhost".to_string(),
|
|
||||||
api_key: "test-key".to_string(),
|
|
||||||
extra_headers: HashMap::new(),
|
|
||||||
llm_timeout_secs: 120,
|
|
||||||
model_id: model_id.to_string(),
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(32),
|
|
||||||
model_extra: HashMap::new(),
|
|
||||||
max_tool_iterations: 1,
|
|
||||||
tool_result_max_chars: 20_000,
|
|
||||||
context_tool_result_trim_chars: 20_000,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_select_provider_config_uses_named_agent_override() {
|
|
||||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
|
||||||
let provider_configs = HashMap::from([(
|
|
||||||
"planner".to_string(),
|
|
||||||
test_provider_config_named("planner-provider", "planner-model"),
|
|
||||||
)]);
|
|
||||||
|
|
||||||
let selected =
|
|
||||||
select_provider_config(&default_provider, &provider_configs, Some("planner")).unwrap();
|
|
||||||
assert_eq!(selected.name, "planner-provider");
|
|
||||||
assert_eq!(selected.model_id, "planner-model");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
fn test_enrich_user_content_with_media_refs_appends_tagged_json() {
|
||||||
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
let media_refs = vec!["/tmp/a.png".to_string(), "/tmp/b.pdf".to_string()];
|
||||||
@ -1944,17 +1724,6 @@ mod tests {
|
|||||||
assert!(session.matches_current_user_turn("chat-1", &second));
|
assert!(session.matches_current_user_turn("chat-1", &second));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_select_provider_config_falls_back_to_default() {
|
|
||||||
let default_provider = test_provider_config_named("default-provider", "default-model");
|
|
||||||
let provider_configs = HashMap::new();
|
|
||||||
|
|
||||||
let selected =
|
|
||||||
select_provider_config(&default_provider, &provider_configs, Some("default")).unwrap();
|
|
||||||
assert_eq!(selected.name, "default-provider");
|
|
||||||
assert_eq!(selected.model_id, "default-model");
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start_mock_openai_server() -> String {
|
async fn start_mock_openai_server() -> String {
|
||||||
async fn handle(Json(body): Json<Value>) -> Json<Value> {
|
async fn handle(Json(body): Json<Value>) -> Json<Value> {
|
||||||
let model = body
|
let model = body
|
||||||
@ -2983,21 +2752,4 @@ mod tests {
|
|||||||
assert_eq!(plan.preferences[0].content, "偏好简洁表达");
|
assert_eq!(plan.preferences[0].content, "偏好简洁表达");
|
||||||
assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
|
assert_eq!(plan.behavior_patterns[0].content, "习惯先问方案再要代码");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_upsert_managed_agent_memory_block_inserts_before_reply_rules() {
|
|
||||||
let original =
|
|
||||||
"# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot。\n\n## 回复规则\n- 使用中文回复。\n";
|
|
||||||
let updated = upsert_managed_agent_memory_block(
|
|
||||||
original,
|
|
||||||
"### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达",
|
|
||||||
);
|
|
||||||
|
|
||||||
let managed_pos = updated.find(MANAGED_AGENT_MEMORY_BLOCK_START).unwrap();
|
|
||||||
let reply_rules_pos = updated.find("## 回复规则").unwrap();
|
|
||||||
assert!(managed_pos < reply_rules_pos);
|
|
||||||
assert!(updated.contains(MANAGED_AGENT_MEMORY_TITLE));
|
|
||||||
assert!(updated.contains("用户在做AI产品"));
|
|
||||||
assert!(updated.contains("偏好简洁表达"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user