feat: 更新 /save 命令,支持保存全部消息到指定路径;修改命令处理逻辑以包含新参数

This commit is contained in:
oudecheng 2026-05-14 10:07:58 +08:00
parent b17ddd7556
commit 102a4a63c5
7 changed files with 439 additions and 55 deletions

View File

@ -43,14 +43,27 @@ impl InputAdapter for CliInputAdapter {
// 解析 /save 命令
if trimmed == "/save" {
return Ok(Some(Command::SaveSession { filepath: None }));
return Ok(Some(Command::SaveSession {
filepath: None,
include_all: false,
}));
}
if let Some(path) = trimmed.strip_prefix("/save ") {
let path = path.trim();
return Ok(Some(Command::SaveSession {
filepath: Some(path.to_string()),
}));
if let Some(args) = trimmed.strip_prefix("/save ") {
let args = args.trim();
// 解析参数:可能是 "all"、路径、或 "all 路径"
let (include_all, filepath) = if args == "all" {
// /save all - 保存全部消息
(true, None)
} else if args.starts_with("all ") {
// /save all <filepath> - 保存全部消息到指定路径
let path = args[4..].trim();
(true, Some(path.to_string()))
} else {
// /save <filepath> - 保存活跃消息到指定路径
(false, Some(args.to_string()))
};
return Ok(Some(Command::SaveSession { filepath, include_all }));
}
// 不是命令,返回 None
@ -192,7 +205,7 @@ mod tests {
assert!(result.is_some());
let cmd = result.unwrap();
assert!(matches!(cmd, Command::SaveSession { filepath: None }));
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: false }));
}
#[test]
@ -207,7 +220,38 @@ mod tests {
assert!(matches!(
cmd,
Command::SaveSession {
filepath: Some(ref p)
filepath: Some(ref p),
include_all: false,
} if p == "./debug/session.md"
));
}
#[test]
fn test_cli_input_adapter_save_all() {
let adapter = CliInputAdapter::new();
let ctx = AdapterContext::new("test");
let result = adapter.try_parse("/save all", ctx).unwrap();
assert!(result.is_some());
let cmd = result.unwrap();
assert!(matches!(cmd, Command::SaveSession { filepath: None, include_all: true }));
}
#[test]
fn test_cli_input_adapter_save_all_with_path() {
let adapter = CliInputAdapter::new();
let ctx = AdapterContext::new("test");
let result = adapter.try_parse("/save all ./debug/session.md", ctx).unwrap();
assert!(result.is_some());
let cmd = result.unwrap();
assert!(matches!(
cmd,
Command::SaveSession {
filepath: Some(ref p),
include_all: true,
} if p == "./debug/session.md"
));
}

View File

@ -48,14 +48,15 @@ pub trait InChatCommandHandler: Send + Sync {
/// * `session_manager` - 会话管理器(用于获取 session
///
/// # Returns
/// * `Ok(())` - 命令执行成功
/// * `Ok(Some(msg))` - 命令执行成功,返回要发送给用户的消息
/// * `Ok(None)` - 命令执行成功,无需发送消息
/// * `Err(AgentError)` - 命令执行失败
async fn handle(
&self,
cmd: Command,
inbound: &InboundMessage,
session_manager: &SessionManager,
) -> Result<(), AgentError>;
) -> Result<Option<String>, AgentError>;
}
/// 命令路由器
@ -163,25 +164,25 @@ impl InChatCommandRouter {
/// * `session_manager` - 会话管理器
///
/// # Returns
/// * `Ok(true)` - 命令被处理
/// * `Ok(false)` - 没有合适的处理器
/// * `Ok(Some(msg))` - 命令被处理,返回成功消息
/// * `Ok(None)` - 没有合适的处理器
/// * `Err(AgentError)` - 执行失败
pub async fn dispatch(
&self,
cmd: Command,
inbound: &InboundMessage,
session_manager: &SessionManager,
) -> Result<bool, AgentError> {
) -> Result<Option<String>, AgentError> {
// 查找能处理此命令的处理器
for handler in &self.handlers {
if handler.can_handle(&cmd) {
handler.handle(cmd, inbound, session_manager).await?;
return Ok(true);
let result = handler.handle(cmd, inbound, session_manager).await?;
return Ok(result);
}
}
// 没有找到合适的处理器
Ok(false)
Ok(None)
}
}

View File

@ -6,10 +6,8 @@ use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command;
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider;
use crate::gateway::session::SessionManager;
use crate::storage::{SessionRecord, SessionStore};
use crate::agent::AgentError;
use crate::bus::OutboundMessage;
use async_trait::async_trait;
use chrono::{Local, TimeZone};
use std::path::PathBuf;
@ -20,6 +18,7 @@ use std::sync::Arc;
/// # Arguments
/// * `session_id` - 会话ID
/// * `filepath` - 可选的文件路径
/// * `include_all` - 是否包含 cutoff 之前的所有消息
/// * `store` - 会话存储
/// * `provider_config` - LLM提供者配置
///
@ -28,6 +27,7 @@ use std::sync::Arc;
pub async fn save_session_to_file(
session_id: &str,
filepath: Option<String>,
include_all: bool,
store: &SessionStore,
provider_config: &LLMProviderConfig,
) -> Result<PathBuf, String> {
@ -37,10 +37,16 @@ pub async fn save_session_to_file(
.map_err(|e| format!("Failed to get session: {}", e))?
.ok_or_else(|| "Session not found".to_string())?;
// 获取所有消息(包括历史)
let messages = store
// 根据 include_all 决定加载消息范围
let messages = if include_all {
store
.load_all_messages(session_id)
.map_err(|e| format!("Failed to load messages: {}", e))?;
.map_err(|e| format!("Failed to load messages: {}", e))?
} else {
store
.load_messages(session_id)
.map_err(|e| format!("Failed to load messages: {}", e))?
};
// 计算用户消息数(用于系统提示词构建)
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
@ -109,8 +115,8 @@ impl CommandHandler for SaveSessionCommandHandler {
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
match cmd {
Command::SaveSession { filepath } => {
handle_save_session(self, filepath, ctx).await
Command::SaveSession { filepath, include_all } => {
handle_save_session(self, filepath, include_all, ctx).await
}
_ => unreachable!(),
}
@ -121,6 +127,7 @@ impl CommandHandler for SaveSessionCommandHandler {
async fn handle_save_session(
handler: &SaveSessionCommandHandler,
filepath: Option<String>,
include_all: bool,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
let session_id = ctx
@ -132,16 +139,23 @@ async fn handle_save_session(
let output_path = save_session_to_file(
session_id,
filepath,
include_all,
&*handler.store,
&handler.provider_config,
)
.await
.map_err(|e| CommandError::new("SAVE_ERROR", e))?;
// 获取消息数量用于返回
let message_count = handler
// 根据 include_all 获取消息数量
let message_count = if include_all {
handler
.store
.load_all_messages(session_id)
} else {
handler
.store
.load_messages(session_id)
}
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?
.len();
@ -377,9 +391,9 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
cmd: Command,
inbound: &InboundMessage,
session_manager: &crate::gateway::session::SessionManager,
) -> Result<(), AgentError> {
let Command::SaveSession { filepath } = cmd else {
return Ok(());
) -> Result<Option<String>, AgentError> {
let Command::SaveSession { filepath, include_all } = cmd else {
return Ok(None);
};
// 通过 session_manager 获取 session
@ -387,7 +401,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
Some(s) => s,
None => {
tracing::error!("Session not found for channel: {}", inbound.channel);
return Ok(());
return Ok(Some("Session not found".to_string()));
}
};
@ -398,21 +412,23 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
let result = save_session_to_file(
&session_id,
filepath,
include_all,
&*self.store,
&self.provider_config,
)
.await;
// 结果通过返回 Ok(()) 表示成功
// 实际输出由调用者通过消息总线发送
// 返回成功或失败消息
match result {
Ok(output_path) => {
tracing::info!("Session saved to: {}", output_path.display());
Ok(())
let msg = format!("Session saved to: {}", output_path.display());
tracing::info!("{}", msg);
Ok(Some(msg))
}
Err(error) => {
tracing::error!("Failed to save session: {}", error);
Ok(())
let msg = format!("Failed to save session: {}", error);
tracing::error!("{}", msg);
Ok(Some(msg))
}
}
}
@ -527,7 +543,8 @@ mod tests {
let store = Arc::new(SessionStore::in_memory().unwrap());
let handler = SaveSessionCommandHandler::new(store, test_config());
assert!(handler.can_handle(&Command::SaveSession { filepath: None }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true }));
assert!(!handler.can_handle(&Command::CreateSession { title: None }));
}
}

View File

@ -14,7 +14,10 @@ pub enum Command {
/// 创建新会话
CreateSession { title: Option<String> },
/// 保存会话内容到 Markdown 文件
SaveSession { filepath: Option<String> },
SaveSession {
filepath: Option<String>,
include_all: bool,
},
}
impl Command {

View File

@ -4,7 +4,7 @@ use tokio::sync::Semaphore;
use crate::agent::AgentError;
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::command::handler::{InChatCommandHandler, InChatCommandRouter};
use crate::command::handler::InChatCommandRouter;
use crate::command::Command;
use crate::config::LLMProviderConfig;
@ -94,10 +94,29 @@ impl InboundProcessor {
// 尝试解析为命令
if let Some(cmd) = parse_in_chat_command(&inbound.content) {
// 使用命令路由器处理
let handled = self.command_router.dispatch(cmd, &inbound, &self.session_manager).await?;
if handled {
match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? {
Some(response_msg) => {
// 发送命令执行结果给用户
if let Err(error) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
inbound.chat_id.clone(),
response_msg,
None,
inbound.forwarded_metadata.clone(),
))
.await
{
tracing::error!(error = %error, "Failed to publish command response");
}
return Ok(());
}
None => {
// 命令已处理但没有返回消息
return Ok(());
}
}
}
// 普通消息进入 AgentLoop
@ -156,7 +175,10 @@ impl InboundProcessor {
/// 解析聊天中的命令
///
/// 支持格式:
/// - `/save [filepath]` - 保存会话
/// - `/save` - 保存活跃会话消息(到 cutoff
/// - `/save all` - 保存全部会话消息(包括 cutoff 之前)
/// - `/save <filepath>` - 保存活跃消息到指定路径
/// - `/save all <filepath>` - 保存全部消息到指定路径
///
/// 返回 Some(Command) 如果是命令
/// 返回 None 如果不是命令
@ -164,13 +186,25 @@ fn parse_in_chat_command(content: &str) -> Option<Command> {
let trimmed = content.trim();
if trimmed.starts_with("/save") {
let path = trimmed[5..].trim();
let filepath = if path.is_empty() {
None
let args = trimmed[5..].trim();
// 解析参数
let (include_all, filepath) = if args.is_empty() {
// /save 无参数 - 只保存活跃消息
(false, None)
} else if args == "all" {
// /save all - 保存全部消息
(true, None)
} else if args.starts_with("all ") {
// /save all <filepath> - 保存全部消息到指定路径
let path = args[4..].trim();
(true, Some(path.to_string()))
} else {
Some(path.to_string())
// /save <filepath> - 保存活跃消息到指定路径
(false, Some(args.to_string()))
};
Some(Command::SaveSession { filepath })
Some(Command::SaveSession { filepath, include_all })
} else {
None
}

View File

@ -366,7 +366,7 @@ async fn handle_inbound(
};
// 构建命令
let cmd = crate::command::Command::SaveSession { filepath };
let cmd = crate::command::Command::SaveSession { filepath, include_all: true };
let cmd_ctx = CommandContext::new("websocket")
.with_session_id(&target_session_id);

View File

@ -31,7 +31,7 @@ impl Tool for SchedulerManageTool {
}
fn description(&self) -> &str {
"Manage repository-backed scheduled jobs. Supports actions: list, get, put, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. \
"Manage repository-backed scheduled jobs. Supports actions: list, get, put, update, delete, pause, resume. Jobs are persisted by the configured scheduler job repository and executed by the scheduler runtime. \
\
When creating agent_task or silent_agent_task jobs, keep prompt/system_prompt focused on the work to perform; do not restate execution times unless the task logic truly depends on them, because the trigger already controls timing. For cron schedules, standard cron syntax is supported: use 1-5 for Monday-Friday, 0 or 7 for Sunday. \
\
@ -63,7 +63,7 @@ impl Tool for SchedulerManageTool {
"properties": {
"action": {
"type": "string",
"enum": ["list", "get", "put", "delete", "pause", "resume"]
"enum": ["list", "get", "put", "update", "delete", "pause", "resume"]
},
"id": {
"type": "string",
@ -179,6 +179,16 @@ impl Tool for SchedulerManageTool {
let saved = self.jobs.upsert_scheduler_job(&input)?;
record_to_json(&saved)
}
"update" => {
let id = require_str(&args, "id")?;
let record = self
.jobs
.get_scheduler_job(id)?
.ok_or_else(|| anyhow::anyhow!("scheduler job '{}' not found", id))?;
let input = build_update_upsert(context, &args, &self.known_agents, &record)?;
let saved = self.jobs.upsert_scheduler_job(&input)?;
record_to_json(&saved)
}
_ => return Ok(error_result("Unsupported action")),
};
@ -257,6 +267,72 @@ fn build_upsert(
})
}
fn build_update_upsert(
context: &crate::tools::ToolContext,
args: &serde_json::Value,
known_agents: &HashSet<String>,
existing: &crate::storage::SchedulerJobRecord,
) -> anyhow::Result<SchedulerJobUpsert> {
let mut upsert = record_to_upsert(existing);
if let Some(schedule_value) = args.get("schedule") {
let schedule: SchedulerSchedule = serde_json::from_value(schedule_value.clone())?;
schedule.validate(&upsert.id)?;
upsert.schedule = serde_json::to_value(&schedule)?;
let (interval_secs, startup_delay_secs) = match &schedule {
SchedulerSchedule::Interval {
seconds,
startup_delay_secs,
} => (*seconds as i64, *startup_delay_secs as i64),
_ => (0, 0),
};
upsert.interval_secs = interval_secs;
upsert.startup_delay_secs = startup_delay_secs;
upsert.next_fire_at = None;
}
if args.get("target").is_some() {
upsert.target = enrich_target_from_context(
args.get("target").cloned().unwrap_or_else(|| json!({})),
context,
);
}
if let Some(payload) = args.get("payload") {
upsert.payload = payload.clone();
}
if let Some(enabled) = args.get("enabled").and_then(|value| value.as_bool()) {
upsert.enabled = enabled;
upsert.state = if enabled {
SchedulerJobState::Scheduled
} else {
SchedulerJobState::Paused
};
if !enabled {
upsert.paused_at = Some(current_timestamp());
} else {
upsert.paused_at = None;
upsert.completed_at = None;
}
upsert.next_fire_at = None;
}
if args.get("max_runs").is_some() {
upsert.max_runs = args.get("max_runs").and_then(|value| value.as_i64());
}
if upsert.kind == "agent_task" || upsert.kind == "silent_agent_task" {
validate_agent_task_payload(&upsert.payload, known_agents)?;
validate_target_fields(&upsert.target, &["channel", "chat_id"], &upsert.kind)?;
} else if upsert.kind == "outbound_message" {
validate_outbound_message_payload(&upsert.payload)?;
validate_target_fields(&upsert.target, &["channel", "chat_id"], "outbound_message")?;
}
Ok(upsert)
}
fn enrich_target_from_context(
target: serde_json::Value,
context: &crate::tools::ToolContext,
@ -713,4 +789,213 @@ mod tests {
assert!(payload_description.contains("每天9点"));
assert!(payload_description.contains("每小时"));
}
#[tokio::test]
async fn test_scheduler_manage_update_partial_fields() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store.clone(), HashSet::new());
// First, create a job
let put_result = tool
.execute(json!({
"action": "put",
"id": "test_update_job",
"kind": "outbound_message",
"schedule": {
"type": "interval",
"seconds": 60
},
"target": {
"channel": "test-channel",
"chat_id": "oc_demo"
},
"payload": {
"content": "original message"
},
"max_runs": 10
}))
.await
.unwrap();
assert!(put_result.success);
// Update only payload
let update_result = tool
.execute(json!({
"action": "update",
"id": "test_update_job",
"payload": {
"content": "updated message"
}
}))
.await
.unwrap();
assert!(update_result.success);
assert!(update_result.output.contains("updated message"));
assert!(update_result.output.contains("test_update_job"));
// Verify other fields preserved
let get_result = tool
.execute(json!({
"action": "get",
"id": "test_update_job"
}))
.await
.unwrap();
assert!(get_result.success);
assert!(get_result.output.contains("interval"));
assert!(get_result.output.contains("test-channel"));
assert!(get_result.output.contains("max_runs\": 10"));
}
#[tokio::test]
async fn test_scheduler_manage_update_schedule() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store.clone(), HashSet::new());
// Create job
let _ = tool
.execute(json!({
"action": "put",
"id": "test_update_schedule",
"kind": "outbound_message",
"schedule": {
"type": "interval",
"seconds": 60
},
"target": {
"channel": "test",
"chat_id": "oc_demo"
},
"payload": { "content": "ping" }
}))
.await
.unwrap();
// Update schedule
let update_result = tool
.execute(json!({
"action": "update",
"id": "test_update_schedule",
"schedule": {
"type": "cron",
"expression": "0 9 * * *"
}
}))
.await
.unwrap();
assert!(update_result.success);
assert!(update_result.output.contains("cron"));
assert!(update_result.output.contains("0 9 * * *"));
}
#[tokio::test]
async fn test_scheduler_manage_update_enabled() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store.clone(), HashSet::new());
// Create enabled job
let _ = tool
.execute(json!({
"action": "put",
"id": "test_update_enabled",
"kind": "outbound_message",
"schedule": { "type": "interval", "seconds": 60 },
"target": { "channel": "test", "chat_id": "oc_demo" },
"payload": { "content": "ping" },
"enabled": true
}))
.await
.unwrap();
// Disable it
let update_result = tool
.execute(json!({
"action": "update",
"id": "test_update_enabled",
"enabled": false
}))
.await
.unwrap();
assert!(update_result.success);
assert!(update_result.output.contains("\"enabled\": false"));
assert!(update_result.output.contains("paused"));
// Re-enable it
let update_result = tool
.execute(json!({
"action": "update",
"id": "test_update_enabled",
"enabled": true
}))
.await
.unwrap();
assert!(update_result.success);
assert!(update_result.output.contains("\"enabled\": true"));
assert!(update_result.output.contains("scheduled"));
}
#[tokio::test]
async fn test_scheduler_manage_update_job_not_found() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store, HashSet::new());
let result = tool
.execute(json!({
"action": "update",
"id": "nonexistent_job",
"payload": { "content": "new" }
}))
.await;
assert!(result.is_err());
let error = result.err().unwrap().to_string();
assert!(error.contains("scheduler job 'nonexistent_job' not found"));
}
#[tokio::test]
async fn test_scheduler_manage_update_preserves_agent_task_agent() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = SchedulerManageTool::new(store.clone(), HashSet::from(["planner".to_string()]));
// Create agent_task
let _ = tool
.execute(json!({
"action": "put",
"id": "test_update_agent",
"kind": "agent_task",
"schedule": { "type": "cron", "expression": "0 9 * * *" },
"target": { "channel": "test", "chat_id": "oc_demo" },
"payload": {
"prompt": "original task",
"agent": "planner"
}
}))
.await
.unwrap();
// Update only prompt
let update_result = tool
.execute(json!({
"action": "update",
"id": "test_update_agent",
"payload": {
"prompt": "updated task",
"agent": "planner"
}
}))
.await
.unwrap();
assert!(update_result.success);
// Verify agent preserved (when explicitly provided)
let get_result = tool
.execute(json!({
"action": "get",
"id": "test_update_agent"
}))
.await
.unwrap();
assert!(get_result.output.contains("planner"));
assert!(get_result.output.contains("updated task"));
}
}