feat: 添加 /save 命令以保存会话内容到 Markdown 文件;实现 SaveSessionCommandHandler 处理逻辑
This commit is contained in:
parent
73faaa95e4
commit
35b9c42d07
@ -12,6 +12,7 @@ pub enum InputCommand {
|
|||||||
Exit,
|
Exit,
|
||||||
Clear,
|
Clear,
|
||||||
New(Option<String>),
|
New(Option<String>),
|
||||||
|
Save(Option<String>),
|
||||||
Sessions,
|
Sessions,
|
||||||
Use(String),
|
Use(String),
|
||||||
Rename(String),
|
Rename(String),
|
||||||
@ -75,6 +76,7 @@ impl InputHandler {
|
|||||||
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
||||||
"/clear" => Some(InputCommand::Clear),
|
"/clear" => Some(InputCommand::Clear),
|
||||||
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
|
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
|
||||||
|
"/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))),
|
||||||
"/sessions" => Some(InputCommand::Sessions),
|
"/sessions" => Some(InputCommand::Sessions),
|
||||||
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
|
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
|
||||||
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
|
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
|
||||||
@ -130,6 +132,14 @@ mod tests {
|
|||||||
handler.handle_special_commands("/new planning"),
|
handler.handle_special_commands("/new planning"),
|
||||||
Some(InputCommand::New(Some("planning".to_string())))
|
Some(InputCommand::New(Some("planning".to_string())))
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/save"),
|
||||||
|
Some(InputCommand::Save(None))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/save ./debug/session.md"),
|
||||||
|
Some(InputCommand::Save(Some("./debug/session.md".to_string())))
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
handler.handle_special_commands("/sessions"),
|
handler.handle_special_commands("/sessions"),
|
||||||
Some(InputCommand::Sessions)
|
Some(InputCommand::Sessions)
|
||||||
|
|||||||
@ -54,7 +54,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let mut input = InputHandler::new();
|
let mut input = InputHandler::new();
|
||||||
let mut current_session_id: Option<String> = None;
|
let mut current_session_id: Option<String> = None;
|
||||||
input.write_output("picobot CLI - Commands: /new [title], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
||||||
|
|
||||||
// Main loop: poll both stdin and WebSocket
|
// Main loop: poll both stdin and WebSocket
|
||||||
loop {
|
loop {
|
||||||
@ -114,6 +114,9 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
WsOutbound::HistoryCleared { session_id } => {
|
WsOutbound::HistoryCleared { session_id } => {
|
||||||
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
|
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
|
||||||
}
|
}
|
||||||
|
WsOutbound::SessionSaved { session_id, filepath } => {
|
||||||
|
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -225,6 +228,16 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
InputEvent::Command(InputCommand::Save(filepath)) => {
|
||||||
|
let inbound = WsInbound::SaveSession {
|
||||||
|
filepath,
|
||||||
|
session_id: current_session_id.clone(),
|
||||||
|
};
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
InputEvent::Message(msg) => {
|
InputEvent::Message(msg) => {
|
||||||
let inbound = WsInbound::UserInput {
|
let inbound = WsInbound::UserInput {
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
|
|||||||
@ -41,6 +41,18 @@ impl InputAdapter for CliInputAdapter {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析 /save 命令
|
||||||
|
if trimmed == "/save" {
|
||||||
|
return Ok(Some(Command::SaveSession { filepath: None }));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path) = trimmed.strip_prefix("/save ") {
|
||||||
|
let path = path.trim();
|
||||||
|
return Ok(Some(Command::SaveSession {
|
||||||
|
filepath: Some(path.to_string()),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
// 不是命令,返回 None
|
// 不是命令,返回 None
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
@ -170,4 +182,46 @@ mod tests {
|
|||||||
assert!(output.contains("Error [TEST_ERROR]"));
|
assert!(output.contains("Error [TEST_ERROR]"));
|
||||||
assert!(output.contains("something failed"));
|
assert!(output.contains("something failed"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cli_input_adapter_save_without_path() {
|
||||||
|
let adapter = CliInputAdapter::new();
|
||||||
|
let ctx = AdapterContext::new("test");
|
||||||
|
|
||||||
|
let result = adapter.try_parse("/save", ctx).unwrap();
|
||||||
|
|
||||||
|
assert!(result.is_some());
|
||||||
|
let cmd = result.unwrap();
|
||||||
|
assert!(matches!(cmd, Command::SaveSession { filepath: None }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cli_input_adapter_save_with_path() {
|
||||||
|
let adapter = CliInputAdapter::new();
|
||||||
|
let ctx = AdapterContext::new("test");
|
||||||
|
|
||||||
|
let result = adapter.try_parse("/save ./debug/session.md", ctx).unwrap();
|
||||||
|
|
||||||
|
assert!(result.is_some());
|
||||||
|
let cmd = result.unwrap();
|
||||||
|
assert!(matches!(
|
||||||
|
cmd,
|
||||||
|
Command::SaveSession {
|
||||||
|
filepath: Some(ref p)
|
||||||
|
} if p == "./debug/session.md"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cli_output_adapter_save_success() {
|
||||||
|
let adapter = CliOutputAdapter::new();
|
||||||
|
let request_id = uuid::Uuid::new_v4();
|
||||||
|
let response = CommandResponse::success(request_id)
|
||||||
|
.with_message(MessageKind::Notification, "Session saved to: session.md")
|
||||||
|
.with_metadata("filepath", "session.md");
|
||||||
|
|
||||||
|
let output = adapter.adapt(response);
|
||||||
|
|
||||||
|
assert!(output.contains("Session saved to: session.md"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1 +1,2 @@
|
|||||||
|
pub mod save_session;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
|||||||
433
src/command/handlers/save_session.rs
Normal file
433
src/command/handlers/save_session.rs
Normal file
@ -0,0 +1,433 @@
|
|||||||
|
use crate::agent::{SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||||||
|
use crate::command::context::CommandContext;
|
||||||
|
use crate::command::handler::CommandHandler;
|
||||||
|
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||||||
|
use crate::command::Command;
|
||||||
|
use crate::config::LLMProviderConfig;
|
||||||
|
use crate::gateway::agent_prompt_provider::SimpleAgentPromptProvider;
|
||||||
|
use crate::storage::{SessionRecord, SessionStore};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::{Local, TimeZone};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// 保存会话命令处理器
|
||||||
|
///
|
||||||
|
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
|
||||||
|
pub struct SaveSessionCommandHandler {
|
||||||
|
store: Arc<SessionStore>,
|
||||||
|
provider_config: LLMProviderConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SaveSessionCommandHandler {
|
||||||
|
/// 创建新的保存会话命令处理器
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `store` - 会话存储
|
||||||
|
/// * `provider_config` - LLM 提供者配置(用于构建系统提示词)
|
||||||
|
pub fn new(store: Arc<SessionStore>, provider_config: LLMProviderConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
provider_config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从会话记录获取存储(用于测试)
|
||||||
|
#[cfg(test)]
|
||||||
|
fn store(&self) -> &Arc<SessionStore> {
|
||||||
|
&self.store
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CommandHandler for SaveSessionCommandHandler {
|
||||||
|
fn can_handle(&self, cmd: &Command) -> bool {
|
||||||
|
matches!(cmd, Command::SaveSession { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle(
|
||||||
|
&self,
|
||||||
|
cmd: Command,
|
||||||
|
ctx: CommandContext,
|
||||||
|
) -> Result<CommandResponse, CommandError> {
|
||||||
|
match cmd {
|
||||||
|
Command::SaveSession { filepath } => {
|
||||||
|
handle_save_session(self, filepath, ctx).await
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 处理保存会话命令
|
||||||
|
async fn handle_save_session(
|
||||||
|
handler: &SaveSessionCommandHandler,
|
||||||
|
filepath: Option<String>,
|
||||||
|
ctx: CommandContext,
|
||||||
|
) -> Result<CommandResponse, CommandError> {
|
||||||
|
let session_id = ctx
|
||||||
|
.session_id
|
||||||
|
.as_deref()
|
||||||
|
.ok_or_else(|| CommandError::new("NO_SESSION", "No active session".to_string()))?;
|
||||||
|
|
||||||
|
// 获取会话记录
|
||||||
|
let record = handler
|
||||||
|
.store
|
||||||
|
.get_session(session_id)
|
||||||
|
.map_err(|e| CommandError::new("SESSION_ERROR", e.to_string()))?
|
||||||
|
.ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", "Session not found".to_string()))?;
|
||||||
|
|
||||||
|
// 获取所有消息(包括历史)
|
||||||
|
let messages = handler
|
||||||
|
.store
|
||||||
|
.load_all_messages(session_id)
|
||||||
|
.map_err(|e| CommandError::new("LOAD_MESSAGES_ERROR", e.to_string()))?;
|
||||||
|
|
||||||
|
// 计算用户消息数(用于系统提示词构建)
|
||||||
|
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
|
||||||
|
|
||||||
|
// 构建系统提示词
|
||||||
|
let system_prompt = build_system_prompt(&handler.provider_config, &record, user_message_count);
|
||||||
|
|
||||||
|
// 生成 Markdown 内容
|
||||||
|
let markdown = generate_markdown(&record, &system_prompt, &messages);
|
||||||
|
|
||||||
|
// 确定输出路径
|
||||||
|
let output_path = resolve_filepath(filepath, &record);
|
||||||
|
|
||||||
|
// 创建父目录
|
||||||
|
if let Some(parent) = output_path.parent() {
|
||||||
|
if !parent.as_os_str().is_empty() && !parent.exists() {
|
||||||
|
std::fs::create_dir_all(parent).map_err(|e| {
|
||||||
|
CommandError::new(
|
||||||
|
"CREATE_DIR_ERROR",
|
||||||
|
format!("Failed to create directory: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入文件
|
||||||
|
std::fs::write(&output_path, markdown).map_err(|e| {
|
||||||
|
CommandError::new(
|
||||||
|
"WRITE_FILE_ERROR",
|
||||||
|
format!("Failed to write file: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(CommandResponse::success(ctx.request_id)
|
||||||
|
.with_message(
|
||||||
|
MessageKind::Notification,
|
||||||
|
&format!("Session saved to: {}", output_path.display()),
|
||||||
|
)
|
||||||
|
.with_metadata("filepath", output_path.to_string_lossy().as_ref())
|
||||||
|
.with_metadata("message_count", &messages.len().to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构建系统提示词
|
||||||
|
fn build_system_prompt(
|
||||||
|
provider_config: &LLMProviderConfig,
|
||||||
|
record: &SessionRecord,
|
||||||
|
user_message_count: usize,
|
||||||
|
) -> Option<SystemPrompt> {
|
||||||
|
let provider = SimpleAgentPromptProvider::new(provider_config.clone());
|
||||||
|
let context = SystemPromptContext {
|
||||||
|
session_id: Some(record.id.clone()),
|
||||||
|
chat_id: record.chat_id.clone(),
|
||||||
|
user_message_count,
|
||||||
|
};
|
||||||
|
provider.build(&context)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 生成 Markdown 内容
|
||||||
|
fn generate_markdown(
|
||||||
|
record: &SessionRecord,
|
||||||
|
system_prompt: &Option<SystemPrompt>,
|
||||||
|
messages: &[crate::bus::ChatMessage],
|
||||||
|
) -> String {
|
||||||
|
let mut output = String::new();
|
||||||
|
|
||||||
|
// YAML frontmatter
|
||||||
|
output.push_str("---\n");
|
||||||
|
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
|
||||||
|
output.push_str(&format!("session_id: {}\n", record.id));
|
||||||
|
output.push_str(&format!("channel: {}\n", record.channel_name));
|
||||||
|
output.push_str(&format!("chat_id: {}\n", record.chat_id));
|
||||||
|
output.push_str(&format!(
|
||||||
|
"created_at: {}\n",
|
||||||
|
format_timestamp(record.created_at)
|
||||||
|
));
|
||||||
|
output.push_str(&format!(
|
||||||
|
"updated_at: {}\n",
|
||||||
|
format_timestamp(record.updated_at)
|
||||||
|
));
|
||||||
|
output.push_str(&format!(
|
||||||
|
"last_active_at: {}\n",
|
||||||
|
format_timestamp(record.last_active_at)
|
||||||
|
));
|
||||||
|
output.push_str(&format!("message_count: {}\n", messages.len()));
|
||||||
|
output.push_str("---\n\n");
|
||||||
|
|
||||||
|
// 系统提示词
|
||||||
|
output.push_str("# System Prompt\n\n");
|
||||||
|
if let Some(prompt) = system_prompt {
|
||||||
|
output.push_str("```\n");
|
||||||
|
output.push_str(&prompt.content);
|
||||||
|
output.push_str("\n```\n\n");
|
||||||
|
} else {
|
||||||
|
output.push_str("*No system prompt available*\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 消息历史
|
||||||
|
output.push_str("# Message History\n\n");
|
||||||
|
|
||||||
|
for (idx, msg) in messages.iter().enumerate() {
|
||||||
|
output.push_str(&format!("## Message {}\n\n", idx + 1));
|
||||||
|
output.push_str(&format!("**Role:** {}\n\n", msg.role));
|
||||||
|
output.push_str(&format!("**ID:** {}\n\n", msg.id));
|
||||||
|
output.push_str(&format!(
|
||||||
|
"**Time:** {}\n\n",
|
||||||
|
format_timestamp(msg.timestamp)
|
||||||
|
));
|
||||||
|
|
||||||
|
if let Some(ref ctx) = msg.system_context {
|
||||||
|
output.push_str(&format!("**System Context:** `{}`\n\n", ctx));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref tool_name) = msg.tool_name {
|
||||||
|
output.push_str(&format!("**Tool Name:** `{}`\n\n", tool_name));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref tool_call_id) = msg.tool_call_id {
|
||||||
|
output.push_str(&format!("**Tool Call ID:** `{}`\n\n", tool_call_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref reasoning) = msg.reasoning_content {
|
||||||
|
output.push_str("### Reasoning\n\n");
|
||||||
|
output.push_str("```\n");
|
||||||
|
output.push_str(reasoning);
|
||||||
|
output.push_str("\n```\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content
|
||||||
|
output.push_str("### Content\n\n");
|
||||||
|
if msg.content.is_empty() {
|
||||||
|
output.push_str("*empty*\n\n");
|
||||||
|
} else {
|
||||||
|
output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool calls
|
||||||
|
if let Some(ref calls) = msg.tool_calls {
|
||||||
|
if !calls.is_empty() {
|
||||||
|
output.push_str("### Tool Calls\n\n");
|
||||||
|
for call in calls {
|
||||||
|
output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
|
||||||
|
output.push_str(" ```json\n");
|
||||||
|
let args_json = serde_json::to_string_pretty(&call.arguments)
|
||||||
|
.unwrap_or_else(|_| call.arguments.to_string());
|
||||||
|
for line in args_json.lines() {
|
||||||
|
output.push_str(&format!(" {}\n", line));
|
||||||
|
}
|
||||||
|
output.push_str(" ```\n");
|
||||||
|
}
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Media refs
|
||||||
|
if !msg.media_refs.is_empty() {
|
||||||
|
output.push_str("### Media References\n\n");
|
||||||
|
for media_ref in &msg.media_refs {
|
||||||
|
output.push_str(&format!("- `{}`\n", media_ref));
|
||||||
|
}
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
output.push_str("---\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 格式化消息内容
|
||||||
|
///
|
||||||
|
/// 如果内容包含特殊字符,使用代码块包装
|
||||||
|
fn format_message_content(content: &str) -> String {
|
||||||
|
// 如果内容包含代码块标记或表格标记,使用原始格式
|
||||||
|
if content.contains("```") || content.contains("| ") {
|
||||||
|
format!("```\n{}\n```", content)
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 转义 YAML 字符串
|
||||||
|
fn escape_yaml_string(s: &str) -> String {
|
||||||
|
if s.contains('\n') || s.contains('"') || s.contains(':') || s.starts_with(' ') {
|
||||||
|
// 使用双引号包裹并转义内部的双引号
|
||||||
|
format!("\"{}\"", s.replace('"', "\\\""))
|
||||||
|
} else {
|
||||||
|
s.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 格式化时间戳
|
||||||
|
fn format_timestamp(ts: i64) -> String {
|
||||||
|
Local
|
||||||
|
.timestamp_millis_opt(ts)
|
||||||
|
.single()
|
||||||
|
.map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
|
||||||
|
.unwrap_or_else(|| format!("{}", ts))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 解析文件路径
|
||||||
|
///
|
||||||
|
/// 如果未提供路径,自动生成基于会话标题和时间戳的文件名
|
||||||
|
fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> PathBuf {
|
||||||
|
match filepath {
|
||||||
|
Some(path) => PathBuf::from(path),
|
||||||
|
None => {
|
||||||
|
// 生成安全标题(替换特殊字符)
|
||||||
|
let safe_title = record
|
||||||
|
.title
|
||||||
|
.replace(' ', "_")
|
||||||
|
.replace('/', "_")
|
||||||
|
.replace('\\', "_")
|
||||||
|
.replace(':', "_")
|
||||||
|
.replace('<', "_")
|
||||||
|
.replace('>', "_")
|
||||||
|
.replace('|', "_")
|
||||||
|
.replace('?', "_")
|
||||||
|
.replace('*', "_")
|
||||||
|
.replace('"', "_");
|
||||||
|
|
||||||
|
// 使用标题或 session_id 作为文件名
|
||||||
|
let base_name = if safe_title.is_empty() {
|
||||||
|
format!("session_{}", &record.id[..8.min(record.id.len())])
|
||||||
|
} else {
|
||||||
|
safe_title
|
||||||
|
};
|
||||||
|
|
||||||
|
// 添加时间戳
|
||||||
|
let timestamp = Local::now().format("%Y%m%d_%H%M%S");
|
||||||
|
let filename = format!("{}_{}.md", base_name, timestamp);
|
||||||
|
|
||||||
|
PathBuf::from(filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::storage::{SessionRecord, SessionStore};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn test_config() -> LLMProviderConfig {
|
||||||
|
LLMProviderConfig {
|
||||||
|
provider_type: "openai".to_string(),
|
||||||
|
name: "test".to_string(),
|
||||||
|
base_url: "http://localhost".to_string(),
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
extra_headers: HashMap::new(),
|
||||||
|
llm_timeout_secs: 120,
|
||||||
|
memory_maintenance_timeout_secs: 600,
|
||||||
|
model_id: "test-model".to_string(),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(32),
|
||||||
|
context_window_tokens: None,
|
||||||
|
tool_result_max_chars: 20_000,
|
||||||
|
context_tool_result_trim_chars: 20_000,
|
||||||
|
model_extra: HashMap::new(),
|
||||||
|
max_tool_iterations: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_test_record(id: &str, title: &str) -> SessionRecord {
|
||||||
|
SessionRecord {
|
||||||
|
id: id.to_string(),
|
||||||
|
title: title.to_string(),
|
||||||
|
channel_name: "cli".to_string(),
|
||||||
|
chat_id: id.to_string(),
|
||||||
|
summary: None,
|
||||||
|
created_at: 1705312800000, // 2024-01-15 10:00:00
|
||||||
|
updated_at: 1705316400000, // 2024-01-15 11:00:00
|
||||||
|
last_active_at: 1705316400000,
|
||||||
|
archived_at: None,
|
||||||
|
deleted_at: None,
|
||||||
|
message_count: 0,
|
||||||
|
reset_cutoff_seq: 0,
|
||||||
|
user_turn_count: 0,
|
||||||
|
agent_prompt_reinjection_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_filepath_with_custom_path() {
|
||||||
|
let record = create_test_record("test-123", "My Session");
|
||||||
|
let path = resolve_filepath(Some("/custom/path/file.md".to_string()), &record);
|
||||||
|
assert_eq!(path, PathBuf::from("/custom/path/file.md"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_filepath_generates_filename_with_title() {
|
||||||
|
let record = create_test_record("test-123", "My Session");
|
||||||
|
let path = resolve_filepath(None, &record);
|
||||||
|
let filename = path.file_name().unwrap().to_str().unwrap();
|
||||||
|
|
||||||
|
assert!(filename.starts_with("My_Session_"));
|
||||||
|
assert!(filename.ends_with(".md"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_filepath_generates_filename_with_id_when_title_empty() {
|
||||||
|
let record = create_test_record("abc12345-xyz", "");
|
||||||
|
let path = resolve_filepath(None, &record);
|
||||||
|
let filename = path.file_name().unwrap().to_str().unwrap();
|
||||||
|
|
||||||
|
assert!(filename.starts_with("session_abc123"));
|
||||||
|
assert!(filename.ends_with(".md"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_escape_yaml_string() {
|
||||||
|
assert_eq!(escape_yaml_string("simple"), "simple");
|
||||||
|
assert_eq!(escape_yaml_string("with: colon"), "\"with: colon\"");
|
||||||
|
assert_eq!(escape_yaml_string("with \"quote\""), "\"with \\\"quote\\\"\"");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_format_message_content() {
|
||||||
|
assert_eq!(format_message_content("hello"), "hello");
|
||||||
|
assert_eq!(
|
||||||
|
format_message_content("```code```"),
|
||||||
|
"```\n```code```\n```"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_markdown_structure() {
|
||||||
|
let record = create_test_record("test-123", "Test Session");
|
||||||
|
let messages = vec![crate::bus::ChatMessage::system("System prompt here")];
|
||||||
|
|
||||||
|
let markdown = generate_markdown(&record, &None, &messages);
|
||||||
|
|
||||||
|
assert!(markdown.contains("---"));
|
||||||
|
assert!(markdown.contains("title:"));
|
||||||
|
assert!(markdown.contains("session_id: test-123"));
|
||||||
|
assert!(markdown.contains("# System Prompt"));
|
||||||
|
assert!(markdown.contains("# Message History"));
|
||||||
|
assert!(markdown.contains("## Message 1"));
|
||||||
|
assert!(markdown.contains("**Role:** system"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_can_handle() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let handler = SaveSessionCommandHandler::new(store, test_config());
|
||||||
|
|
||||||
|
assert!(handler.can_handle(&Command::SaveSession { filepath: None }));
|
||||||
|
assert!(!handler.can_handle(&Command::CreateSession { title: None }));
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -35,6 +35,7 @@ impl CommandHandler for SessionCommandHandler {
|
|||||||
) -> Result<CommandResponse, CommandError> {
|
) -> Result<CommandResponse, CommandError> {
|
||||||
match cmd {
|
match cmd {
|
||||||
Command::CreateSession { title } => handle_create_session(self, title, ctx).await,
|
Command::CreateSession { title } => handle_create_session(self, title, ctx).await,
|
||||||
|
Command::SaveSession { .. } => unreachable!("SaveSession should be handled by SaveSessionCommandHandler"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,8 +11,10 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub enum Command {
|
pub enum Command {
|
||||||
/// 目前仅实现 /new 命令
|
/// 创建新会话
|
||||||
CreateSession { title: Option<String> },
|
CreateSession { title: Option<String> },
|
||||||
|
/// 保存会话内容到 Markdown 文件
|
||||||
|
SaveSession { filepath: Option<String> },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Command {
|
impl Command {
|
||||||
@ -20,6 +22,7 @@ impl Command {
|
|||||||
pub fn name(&self) -> &'static str {
|
pub fn name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Command::CreateSession { .. } => "create_session",
|
Command::CreateSession { .. } => "create_session",
|
||||||
|
Command::SaveSession { .. } => "save_session",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,7 @@ use crate::command::adapter::OutputAdapter;
|
|||||||
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
||||||
use crate::command::context::CommandContext;
|
use crate::command::context::CommandContext;
|
||||||
use crate::command::handler::CommandRouter;
|
use crate::command::handler::CommandRouter;
|
||||||
|
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||||
use crate::command::handlers::session::SessionCommandHandler;
|
use crate::command::handlers::session::SessionCommandHandler;
|
||||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
@ -348,6 +349,56 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
WsInbound::SaveSession { filepath, session_id } => {
|
||||||
|
let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||||
|
|
||||||
|
// 获取所需依赖
|
||||||
|
let store = state.session_manager.store();
|
||||||
|
let provider_config = state.config.get_provider_config("default")
|
||||||
|
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||||||
|
|
||||||
|
// 构建处理器
|
||||||
|
let handler = SaveSessionCommandHandler::new(store, provider_config);
|
||||||
|
let router = {
|
||||||
|
let mut r = CommandRouter::new();
|
||||||
|
r.register(Box::new(handler));
|
||||||
|
r
|
||||||
|
};
|
||||||
|
|
||||||
|
// 构建命令
|
||||||
|
let cmd = crate::command::Command::SaveSession { filepath };
|
||||||
|
let cmd_ctx = CommandContext::new("websocket")
|
||||||
|
.with_session_id(&target_session_id);
|
||||||
|
|
||||||
|
// 执行命令
|
||||||
|
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||||
|
|
||||||
|
// 处理响应
|
||||||
|
if response.success {
|
||||||
|
let filepath = response
|
||||||
|
.metadata
|
||||||
|
.get("filepath")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default();
|
||||||
|
let _ = sender
|
||||||
|
.send(WsOutbound::SessionSaved {
|
||||||
|
session_id: target_session_id,
|
||||||
|
filepath,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
} else {
|
||||||
|
let error = response.error.unwrap_or_else(|| {
|
||||||
|
crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error")
|
||||||
|
});
|
||||||
|
let _ = sender
|
||||||
|
.send(WsOutbound::Error {
|
||||||
|
code: error.code,
|
||||||
|
message: error.message,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
WsInbound::Ping => {
|
WsInbound::Ping => {
|
||||||
let _ = sender.send(WsOutbound::Pong).await;
|
let _ = sender.send(WsOutbound::Pong).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@ -62,6 +62,13 @@ pub enum WsInbound {
|
|||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
},
|
},
|
||||||
|
#[serde(rename = "save_session")]
|
||||||
|
SaveSession {
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
filepath: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
session_id: Option<String>,
|
||||||
|
},
|
||||||
#[serde(rename = "ping")]
|
#[serde(rename = "ping")]
|
||||||
Ping,
|
Ping,
|
||||||
}
|
}
|
||||||
@ -127,6 +134,8 @@ pub enum WsOutbound {
|
|||||||
SessionDeleted { session_id: String },
|
SessionDeleted { session_id: String },
|
||||||
#[serde(rename = "history_cleared")]
|
#[serde(rename = "history_cleared")]
|
||||||
HistoryCleared { session_id: String },
|
HistoryCleared { session_id: String },
|
||||||
|
#[serde(rename = "session_saved")]
|
||||||
|
SessionSaved { session_id: String, filepath: String },
|
||||||
#[serde(rename = "pong")]
|
#[serde(rename = "pong")]
|
||||||
Pong,
|
Pong,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user