Refactor command handling and input adapters
- Removed the `format_session_list` function and related session handling from the client module. - Simplified command output in the client by removing session-related commands. - Introduced `ChannelInputAdapter` for parsing channel commands like `/new` and `/save`. - Updated WebSocket handling to process commands via the new command system. - Removed deprecated in-chat command handling from the gateway. - Adjusted tests to reflect changes in command serialization and session handling. - Enhanced session cleanup and job scheduling in the configuration module.
This commit is contained in:
parent
b33350c410
commit
b77fc93d71
@ -10,14 +10,8 @@ pub enum InputEvent {
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum InputCommand {
|
||||
Exit,
|
||||
Clear,
|
||||
New(Option<String>),
|
||||
Save(Option<String>),
|
||||
Sessions,
|
||||
Use(String),
|
||||
Rename(String),
|
||||
Archive,
|
||||
Delete,
|
||||
}
|
||||
|
||||
pub struct InputHandler {
|
||||
@ -74,14 +68,8 @@ impl InputHandler {
|
||||
|
||||
match command {
|
||||
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
||||
"/clear" => Some(InputCommand::Clear),
|
||||
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
|
||||
"/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))),
|
||||
"/sessions" => Some(InputCommand::Sessions),
|
||||
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
|
||||
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
|
||||
"/archive" => Some(InputCommand::Archive),
|
||||
"/delete" => Some(InputCommand::Delete),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -120,10 +108,6 @@ mod tests {
|
||||
handler.handle_special_commands("/quit"),
|
||||
Some(InputCommand::Exit)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/clear"),
|
||||
Some(InputCommand::Clear)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/new"),
|
||||
Some(InputCommand::New(None))
|
||||
@ -140,27 +124,6 @@ mod tests {
|
||||
handler.handle_special_commands("/save ./debug/session.md"),
|
||||
Some(InputCommand::Save(Some("./debug/session.md".to_string())))
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/sessions"),
|
||||
Some(InputCommand::Sessions)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/use abc123"),
|
||||
Some(InputCommand::Use("abc123".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/rename project alpha"),
|
||||
Some(InputCommand::Rename("project alpha".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/archive"),
|
||||
Some(InputCommand::Archive)
|
||||
);
|
||||
assert_eq!(
|
||||
handler.handle_special_commands("/delete"),
|
||||
Some(InputCommand::Delete)
|
||||
);
|
||||
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
||||
assert_eq!(handler.handle_special_commands("/use"), None);
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,36 +8,6 @@ use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
||||
|
||||
fn format_session_list(
|
||||
sessions: &[crate::protocol::SessionSummary],
|
||||
current_session_id: Option<&str>,
|
||||
) -> String {
|
||||
if sessions.is_empty() {
|
||||
return "No sessions found.".to_string();
|
||||
}
|
||||
|
||||
let mut lines = Vec::with_capacity(sessions.len() + 1);
|
||||
lines.push("Sessions:".to_string());
|
||||
for session in sessions {
|
||||
let marker = if current_session_id == Some(session.session_id.as_str()) {
|
||||
"*"
|
||||
} else {
|
||||
"-"
|
||||
};
|
||||
let archived = if session.archived_at.is_some() {
|
||||
" [archived]"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
lines.push(format!(
|
||||
"{} {} | {} | {} messages{}",
|
||||
marker, session.session_id, session.title, session.message_count, archived,
|
||||
));
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
||||
serde_json::from_str(raw)
|
||||
}
|
||||
@ -54,7 +24,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let mut input = InputHandler::new();
|
||||
let mut current_session_id: Option<String> = None;
|
||||
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
||||
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?;
|
||||
|
||||
// Main loop: poll both stdin and WebSocket
|
||||
loop {
|
||||
@ -91,29 +61,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
current_session_id = Some(session_id.clone());
|
||||
input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?;
|
||||
}
|
||||
WsOutbound::SessionList { sessions, current_session_id: listed_current } => {
|
||||
let display = format_session_list(&sessions, listed_current.as_deref());
|
||||
input.write_output(&format!("{}\n", display)).await?;
|
||||
}
|
||||
WsOutbound::SessionLoaded { session_id, title, message_count } => {
|
||||
current_session_id = Some(session_id.clone());
|
||||
input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?;
|
||||
}
|
||||
WsOutbound::SessionRenamed { session_id, title } => {
|
||||
input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?;
|
||||
}
|
||||
WsOutbound::SessionArchived { session_id } => {
|
||||
input.write_output(&format!("Archived session: {}\n", session_id)).await?;
|
||||
}
|
||||
WsOutbound::SessionDeleted { session_id } => {
|
||||
if current_session_id.as_deref() == Some(session_id.as_str()) {
|
||||
current_session_id = None;
|
||||
}
|
||||
input.write_output(&format!("Deleted session: {}\n", session_id)).await?;
|
||||
}
|
||||
WsOutbound::HistoryCleared { session_id } => {
|
||||
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
|
||||
}
|
||||
WsOutbound::SessionSaved { session_id, filepath } => {
|
||||
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
||||
}
|
||||
@ -138,39 +85,25 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
input.write_output("Goodbye!").await?;
|
||||
break;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Clear) => {
|
||||
let inbound = WsInbound::ClearHistory {
|
||||
chat_id: None,
|
||||
session_id: current_session_id.clone(),
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::New(title)) => {
|
||||
// 使用新的命令层:通过 CliInputAdapter 构建 Command
|
||||
// 使用 CliInputAdapter 构建 Command
|
||||
let adapter = CliInputAdapter::new();
|
||||
let ctx = AdapterContext::new("cli")
|
||||
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
||||
|
||||
// 构建输入字符串
|
||||
let input = match title {
|
||||
let input_str = match title {
|
||||
Some(t) => format!("/new {}", t),
|
||||
None => "/new".to_string(),
|
||||
};
|
||||
|
||||
// 解析为 Command
|
||||
match adapter.try_parse(&input, ctx) {
|
||||
match adapter.try_parse(&input_str, ctx) {
|
||||
Ok(Some(command)) => {
|
||||
// 序列化为 JSON 通过 WebSocket 发送
|
||||
// 序列化为 JSON
|
||||
let json = serde_json::to_string(&command).unwrap_or_default();
|
||||
let inbound = WsInbound::UserInput {
|
||||
content: json,
|
||||
channel: None,
|
||||
chat_id: current_session_id.clone(),
|
||||
sender_id: None,
|
||||
};
|
||||
// 通过 Command 消息发送
|
||||
let inbound = WsInbound::Command { payload: json };
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
@ -184,62 +117,40 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Sessions) => {
|
||||
let inbound = WsInbound::ListSessions {
|
||||
include_archived: true,
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Use(session_id)) => {
|
||||
let inbound = WsInbound::LoadSession { session_id };
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Rename(title)) => {
|
||||
let inbound = WsInbound::RenameSession {
|
||||
session_id: current_session_id.clone(),
|
||||
title,
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Archive) => {
|
||||
let inbound = WsInbound::ArchiveSession {
|
||||
session_id: current_session_id.clone(),
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Delete) => {
|
||||
let inbound = WsInbound::DeleteSession {
|
||||
session_id: current_session_id.clone(),
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Command(InputCommand::Save(filepath)) => {
|
||||
let inbound = WsInbound::SaveSession {
|
||||
filepath,
|
||||
session_id: current_session_id.clone(),
|
||||
// 使用 CliInputAdapter 构建 Command
|
||||
let adapter = CliInputAdapter::new();
|
||||
let ctx = AdapterContext::new("cli")
|
||||
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
||||
|
||||
// 构建输入字符串
|
||||
let input_str = match filepath {
|
||||
Some(p) => format!("/save {}", p),
|
||||
None => "/save".to_string(),
|
||||
};
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
|
||||
// 解析为 Command
|
||||
match adapter.try_parse(&input_str, ctx) {
|
||||
Ok(Some(command)) => {
|
||||
// 序列化为 JSON
|
||||
let json = serde_json::to_string(&command).unwrap_or_default();
|
||||
// 通过 Command 消息发送
|
||||
let inbound = WsInbound::Command { payload: json };
|
||||
if let Ok(text) = serialize_inbound(&inbound) {
|
||||
let _ = sender.send(Message::Text(text.into())).await;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::warn!("Failed to parse /save command");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Error parsing /save command");
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
InputEvent::Message(msg) => {
|
||||
let inbound = WsInbound::UserInput {
|
||||
let inbound = WsInbound::Message {
|
||||
content: msg.content,
|
||||
channel: None,
|
||||
chat_id: current_session_id.clone(),
|
||||
|
||||
187
src/command/adapters/channel.rs
Normal file
187
src/command/adapters/channel.rs
Normal file
@ -0,0 +1,187 @@
|
||||
use crate::command::adapter::{AdapterError, InputAdapter};
|
||||
use crate::command::context::AdapterContext;
|
||||
use crate::command::Command;
|
||||
|
||||
/// Channel 输入适配器
|
||||
///
|
||||
/// 将 Channel 消息中的文本命令(如 "/new", "/save")转换为 Command
|
||||
pub struct ChannelInputAdapter;
|
||||
|
||||
impl ChannelInputAdapter {
|
||||
/// 创建新的 Channel 输入适配器
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChannelInputAdapter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl InputAdapter for ChannelInputAdapter {
|
||||
fn try_parse(
|
||||
&self,
|
||||
input: &str,
|
||||
_ctx: AdapterContext,
|
||||
) -> Result<Option<Command>, AdapterError> {
|
||||
let trimmed = input.trim();
|
||||
|
||||
// 解析 /new 命令
|
||||
if trimmed == "/new" {
|
||||
return Ok(Some(Command::CreateSession { title: None }));
|
||||
}
|
||||
|
||||
if let Some(title) = trimmed.strip_prefix("/new ") {
|
||||
let title = title.trim();
|
||||
return Ok(Some(Command::CreateSession {
|
||||
title: Some(title.to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// 解析 /save 命令
|
||||
if trimmed == "/save" {
|
||||
return Ok(Some(Command::SaveSession {
|
||||
filepath: None,
|
||||
include_all: false,
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(args) = trimmed.strip_prefix("/save ") {
|
||||
let args = args.trim();
|
||||
// 解析参数:可能是 "all"、路径、或 "all 路径"
|
||||
let (include_all, filepath) = if args == "all" {
|
||||
// /save all - 保存全部消息
|
||||
(true, None)
|
||||
} else if args.starts_with("all ") {
|
||||
// /save all <filepath> - 保存全部消息到指定路径
|
||||
let path = args[4..].trim();
|
||||
(true, Some(path.to_string()))
|
||||
} else {
|
||||
// /save <filepath> - 保存活跃消息到指定路径
|
||||
(false, Some(args.to_string()))
|
||||
};
|
||||
return Ok(Some(Command::SaveSession { filepath, include_all }));
|
||||
}
|
||||
|
||||
// 不是命令,返回 None
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_new_without_title() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("/new", ctx).unwrap();
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(cmd, Command::CreateSession { title: None }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_new_with_title() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("/new planning session", ctx).unwrap();
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(
|
||||
cmd,
|
||||
Command::CreateSession {
|
||||
title: Some(ref t)
|
||||
} if t == "planning session"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_save_without_path() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("/save", ctx).unwrap();
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(
|
||||
cmd,
|
||||
Command::SaveSession {
|
||||
filepath: None,
|
||||
include_all: false,
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_save_with_path() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("/save ./session.md", ctx).unwrap();
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(
|
||||
cmd,
|
||||
Command::SaveSession {
|
||||
filepath: Some(ref p),
|
||||
include_all: false,
|
||||
} if p == "./session.md"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_save_all() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("/save all", ctx).unwrap();
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(
|
||||
cmd,
|
||||
Command::SaveSession {
|
||||
filepath: None,
|
||||
include_all: true,
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_save_all_with_path() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("/save all ./session.md", ctx).unwrap();
|
||||
|
||||
assert!(result.is_some());
|
||||
let cmd = result.unwrap();
|
||||
assert!(matches!(
|
||||
cmd,
|
||||
Command::SaveSession {
|
||||
filepath: Some(ref p),
|
||||
include_all: true,
|
||||
} if p == "./session.md"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_adapter_not_command() {
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = AdapterContext::new("channel");
|
||||
|
||||
let result = adapter.try_parse("hello world", ctx).unwrap();
|
||||
|
||||
assert!(result.is_none());
|
||||
}
|
||||
}
|
||||
@ -1,2 +1,3 @@
|
||||
pub mod channel;
|
||||
pub mod cli;
|
||||
pub mod websocket;
|
||||
|
||||
@ -1472,7 +1472,7 @@ mod tests {
|
||||
assert!(config.scheduler.jobs.is_empty());
|
||||
|
||||
let effective_jobs = config.scheduler.effective_jobs(&config.time);
|
||||
assert_eq!(effective_jobs.len(), 1);
|
||||
assert_eq!(effective_jobs.len(), 2);
|
||||
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
|
||||
assert_eq!(effective_jobs[0].kind, SchedulerJobKind::InternalEvent);
|
||||
assert_eq!(
|
||||
@ -1481,6 +1481,8 @@ mod tests {
|
||||
expression: "0 */4 * * *".to_string(),
|
||||
}
|
||||
);
|
||||
// 第二个内置作业是会话清理
|
||||
assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -1516,7 +1518,8 @@ mod tests {
|
||||
let effective_jobs = scheduler.effective_jobs(&TimeConfig {
|
||||
timezone: "Asia/Shanghai".to_string(),
|
||||
});
|
||||
assert_eq!(effective_jobs.len(), 2);
|
||||
assert_eq!(effective_jobs.len(), 3); // 2个内置 + 1个自定义
|
||||
// 第一个作业:内存维护(被覆盖为禁用)
|
||||
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
|
||||
assert!(!effective_jobs[0].enabled);
|
||||
assert_eq!(
|
||||
@ -1525,7 +1528,11 @@ mod tests {
|
||||
expression: "15 2 * * *".to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(effective_jobs[1].id, "custom.reminder");
|
||||
// 第二个作业:会话清理(保持默认)
|
||||
assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
|
||||
assert!(effective_jobs[1].enabled);
|
||||
// 第三个作业:自定义提醒
|
||||
assert_eq!(effective_jobs[2].id, "custom.reminder");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -1,165 +1,2 @@
|
||||
use crate::agent::AgentError;
|
||||
|
||||
use super::session::Session;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum InChatCommand {
|
||||
FreshConversation,
|
||||
}
|
||||
|
||||
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
||||
match content.trim() {
|
||||
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_in_chat_command(
|
||||
session: &mut Session,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<Option<String>, AgentError> {
|
||||
match parse_in_chat_command(content) {
|
||||
Some(InChatCommand::FreshConversation) => {
|
||||
session.reset_chat_context(chat_id)?;
|
||||
Ok(Some("Started a fresh conversation.".to_string()))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::SessionStore;
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
const TEST_CHANNEL: &str = "test-channel";
|
||||
|
||||
fn test_provider_config() -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
provider_type: "openai".to_string(),
|
||||
name: "test".to_string(),
|
||||
base_url: "http://localhost".to_string(),
|
||||
api_key: "test-key".to_string(),
|
||||
extra_headers: HashMap::new(),
|
||||
llm_timeout_secs: 120,
|
||||
memory_maintenance_timeout_secs: 600,
|
||||
model_id: "test-model".to_string(),
|
||||
temperature: Some(0.0),
|
||||
max_tokens: Some(32),
|
||||
context_window_tokens: None,
|
||||
model_extra: HashMap::new(),
|
||||
max_tool_iterations: 1,
|
||||
tool_result_max_chars: 20_000,
|
||||
context_tool_result_trim_chars: 20_000,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_in_chat_command_aliases() {
|
||||
assert_eq!(
|
||||
parse_in_chat_command("/new"),
|
||||
Some(InChatCommand::FreshConversation)
|
||||
);
|
||||
assert_eq!(
|
||||
parse_in_chat_command(" /reset \n"),
|
||||
Some(InChatCommand::FreshConversation)
|
||||
);
|
||||
assert_eq!(parse_in_chat_command("/new planning"), None);
|
||||
assert_eq!(parse_in_chat_command("please /reset"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_in_chat_command_resets_active_history_only() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(ToolRegistry::new());
|
||||
let mut session = Session::new(
|
||||
TEST_CHANNEL.to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store.clone(),
|
||||
100,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response, "Started a fresh conversation.");
|
||||
assert!(session.get_history("chat-1").unwrap().is_empty());
|
||||
assert!(
|
||||
store
|
||||
.load_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.is_empty()
|
||||
);
|
||||
assert_eq!(
|
||||
store
|
||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
||||
.unwrap()
|
||||
.len(),
|
||||
// 新设计:系统提示词不再持久化,只有 1 条用户消息
|
||||
1,
|
||||
);
|
||||
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
// 新设计:系统提示词不再持久化到历史记录
|
||||
assert_eq!(history.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
||||
let skills = Arc::new(SkillRuntime::default());
|
||||
let tools = Arc::new(ToolRegistry::new());
|
||||
let mut session = Session::new(
|
||||
TEST_CHANNEL.to_string(),
|
||||
test_provider_config(),
|
||||
user_tx,
|
||||
tools,
|
||||
skills,
|
||||
store,
|
||||
100,
|
||||
Some(4),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
session.ensure_persistent_session("chat-1").unwrap();
|
||||
session.ensure_chat_loaded("chat-1").unwrap();
|
||||
session
|
||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
||||
.unwrap();
|
||||
|
||||
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
||||
session
|
||||
.ensure_agent_prompt_before_user_message("chat-1")
|
||||
.unwrap();
|
||||
|
||||
// 新设计:系统提示词不再持久化到历史记录
|
||||
let history = session.get_history("chat-1").unwrap();
|
||||
assert_eq!(history.len(), 0);
|
||||
}
|
||||
}
|
||||
// 此文件已废弃,InChatCommand 功能已合并到 Command 系统
|
||||
// 保留文件以避免破坏现有 import,但内容为空
|
||||
|
||||
@ -7,7 +7,6 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL
|
||||
use crate::config::LLMProviderConfig;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::command::handle_in_chat_command;
|
||||
use super::compaction::schedule_background_history_compaction;
|
||||
use super::message_prepare::enrich_user_content_with_media_refs;
|
||||
use super::session::Session;
|
||||
@ -138,18 +137,6 @@ impl AgentExecutionService {
|
||||
session_guard.ensure_persistent_session(request.chat_id)?;
|
||||
session_guard.ensure_chat_loaded(request.chat_id)?;
|
||||
|
||||
if let Some(command_response) =
|
||||
handle_in_chat_command(&mut session_guard, request.chat_id, request.content)?
|
||||
{
|
||||
return Ok(vec![OutboundMessage::assistant(
|
||||
request.channel_name.to_string(),
|
||||
request.chat_id.to_string(),
|
||||
command_response,
|
||||
None,
|
||||
HashMap::new(),
|
||||
)]);
|
||||
}
|
||||
|
||||
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
|
||||
|
||||
let media_refs: Vec<String> = request
|
||||
|
||||
@ -4,8 +4,11 @@ use tokio::sync::Semaphore;
|
||||
|
||||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||||
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
|
||||
use crate::command::handler::InChatCommandRouter;
|
||||
use crate::command::Command;
|
||||
use crate::command::adapter::InputAdapter;
|
||||
use crate::command::adapters::channel::ChannelInputAdapter;
|
||||
use crate::command::handler::CommandRouter;
|
||||
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||
use crate::command::handlers::session::SessionCommandHandler;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||
use crate::skills::SkillPromptProvider;
|
||||
@ -18,7 +21,7 @@ pub struct InboundProcessor {
|
||||
session_manager: SessionManager,
|
||||
semaphore: Arc<Semaphore>,
|
||||
_provider_config: LLMProviderConfig,
|
||||
command_router: Arc<InChatCommandRouter>,
|
||||
command_router: Arc<CommandRouter>,
|
||||
}
|
||||
|
||||
impl InboundProcessor {
|
||||
@ -29,7 +32,11 @@ impl InboundProcessor {
|
||||
provider_config: LLMProviderConfig,
|
||||
) -> Self {
|
||||
// 创建命令路由器并注册处理器
|
||||
let mut command_router = InChatCommandRouter::new();
|
||||
let mut command_router = CommandRouter::new();
|
||||
|
||||
// 注册 Session 处理器
|
||||
let cli_sessions = session_manager.cli_sessions();
|
||||
command_router.register(Box::new(SessionCommandHandler::new(cli_sessions)));
|
||||
|
||||
// 注册 save_session 处理器
|
||||
let store = session_manager.store();
|
||||
@ -43,7 +50,7 @@ impl InboundProcessor {
|
||||
)),
|
||||
Box::new(SkillPromptProvider::new(skills)),
|
||||
]));
|
||||
command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new(
|
||||
command_router.register(Box::new(SaveSessionCommandHandler::new(
|
||||
store,
|
||||
system_prompt_provider,
|
||||
)));
|
||||
@ -103,18 +110,28 @@ impl InboundProcessor {
|
||||
}
|
||||
|
||||
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
|
||||
// 尝试解析为命令
|
||||
if let Some(cmd) = parse_in_chat_command(&inbound.content) {
|
||||
// 使用 ChannelInputAdapter 尝试解析命令
|
||||
let adapter = ChannelInputAdapter::new();
|
||||
let ctx = crate::command::context::AdapterContext::new(&inbound.channel)
|
||||
.with_session_id(&inbound.chat_id);
|
||||
|
||||
if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) {
|
||||
// 使用命令路由器处理
|
||||
match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? {
|
||||
Some(response_msg) => {
|
||||
// 发送命令执行结果给用户
|
||||
let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel)
|
||||
.with_session_id(&inbound.chat_id);
|
||||
|
||||
let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||
|
||||
// 发送响应给用户
|
||||
if response.success {
|
||||
// 提取响应消息
|
||||
for msg in &response.messages {
|
||||
if let Err(error) = self
|
||||
.bus
|
||||
.publish_outbound(OutboundMessage::assistant(
|
||||
inbound.channel.clone(),
|
||||
inbound.chat_id.clone(),
|
||||
response_msg,
|
||||
msg.content.clone(),
|
||||
None,
|
||||
inbound.forwarded_metadata.clone(),
|
||||
))
|
||||
@ -122,13 +139,23 @@ impl InboundProcessor {
|
||||
{
|
||||
tracing::error!(error = %error, "Failed to publish command response");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
None => {
|
||||
// 命令已处理但没有返回消息
|
||||
return Ok(());
|
||||
} else if let Some(error) = response.error {
|
||||
if let Err(e) = self
|
||||
.bus
|
||||
.publish_outbound(OutboundMessage::assistant(
|
||||
inbound.channel.clone(),
|
||||
inbound.chat_id.clone(),
|
||||
format!("Error [{}]: {}", error.code, error.message),
|
||||
None,
|
||||
inbound.forwarded_metadata.clone(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = %e, "Failed to publish error response");
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 普通消息进入 AgentLoop
|
||||
@ -183,41 +210,3 @@ impl InboundProcessor {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 解析聊天中的命令
|
||||
///
|
||||
/// 支持格式:
|
||||
/// - `/save` - 保存活跃会话消息(到 cutoff)
|
||||
/// - `/save all` - 保存全部会话消息(包括 cutoff 之前)
|
||||
/// - `/save <filepath>` - 保存活跃消息到指定路径
|
||||
/// - `/save all <filepath>` - 保存全部消息到指定路径
|
||||
///
|
||||
/// 返回 Some(Command) 如果是命令
|
||||
/// 返回 None 如果不是命令
|
||||
fn parse_in_chat_command(content: &str) -> Option<Command> {
|
||||
let trimmed = content.trim();
|
||||
|
||||
if trimmed.starts_with("/save") {
|
||||
let args = trimmed[5..].trim();
|
||||
|
||||
// 解析参数
|
||||
let (include_all, filepath) = if args.is_empty() {
|
||||
// /save 无参数 - 只保存活跃消息
|
||||
(false, None)
|
||||
} else if args == "all" {
|
||||
// /save all - 保存全部消息
|
||||
(true, None)
|
||||
} else if args.starts_with("all ") {
|
||||
// /save all <filepath> - 保存全部消息到指定路径
|
||||
let path = args[4..].trim();
|
||||
(true, Some(path.to_string()))
|
||||
} else {
|
||||
// /save <filepath> - 保存活跃消息到指定路径
|
||||
(false, Some(args.to_string()))
|
||||
};
|
||||
|
||||
Some(Command::SaveSession { filepath, include_all })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
use super::GatewayState;
|
||||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||||
use crate::bus::InboundMessage;
|
||||
use crate::command::adapter::OutputAdapter;
|
||||
use crate::command::adapter::{InputAdapter, OutputAdapter};
|
||||
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
||||
use crate::command::context::CommandContext;
|
||||
use crate::command::handler::CommandRouter;
|
||||
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||
use crate::command::handlers::session::SessionCommandHandler;
|
||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||
use crate::skills::SkillPromptProvider;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||
@ -125,17 +125,6 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||||
}
|
||||
|
||||
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
||||
SessionSummary {
|
||||
session_id: record.id,
|
||||
title: record.title,
|
||||
channel_name: record.channel_name,
|
||||
chat_id: record.chat_id,
|
||||
message_count: record.message_count,
|
||||
last_active_at: record.last_active_at,
|
||||
archived_at: record.archived_at,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_inbound(
|
||||
state: &Arc<GatewayState>,
|
||||
@ -143,9 +132,9 @@ async fn handle_inbound(
|
||||
runtime_session_id: &str,
|
||||
current_session_id: &mut String,
|
||||
inbound: WsInbound,
|
||||
) -> Result<(), crate::agent::AgentError> {
|
||||
) -> Result<(), AgentError> {
|
||||
match inbound {
|
||||
WsInbound::UserInput {
|
||||
WsInbound::Message {
|
||||
content,
|
||||
chat_id,
|
||||
sender_id,
|
||||
@ -181,53 +170,73 @@ async fn handle_inbound(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::ClearHistory {
|
||||
session_id,
|
||||
chat_id,
|
||||
} => {
|
||||
let target = session_id
|
||||
.or(chat_id)
|
||||
.unwrap_or_else(|| current_session_id.clone());
|
||||
state
|
||||
.session_manager
|
||||
.cli_sessions()
|
||||
.clear_messages(&target)?;
|
||||
|
||||
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
||||
session.lock().await.remove_history(&target);
|
||||
}
|
||||
|
||||
let _ = sender
|
||||
.send(WsOutbound::HistoryCleared { session_id: target })
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::CreateSession { title } => {
|
||||
// 使用新的命令层处理
|
||||
let _input_adapter = WebSocketInputAdapter::new();
|
||||
WsInbound::Command { payload } => {
|
||||
// 使用 Command 系统处理命令
|
||||
let input_adapter = WebSocketInputAdapter::new();
|
||||
let output_adapter = WebSocketOutputAdapter::new();
|
||||
let cli_sessions = state.session_manager.cli_sessions();
|
||||
let handler = SessionCommandHandler::new(cli_sessions);
|
||||
let router = {
|
||||
let mut r = CommandRouter::new();
|
||||
r.register(Box::new(handler));
|
||||
r
|
||||
|
||||
// 解析命令
|
||||
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
|
||||
.with_session_id(current_session_id.as_str());
|
||||
|
||||
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
|
||||
Ok(Some(cmd)) => cmd,
|
||||
Ok(None) => {
|
||||
// 不是命令,返回错误
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "INVALID_COMMAND".to_string(),
|
||||
message: "Invalid command payload".to_string(),
|
||||
})
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "PARSE_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// 构建命令
|
||||
let cmd = crate::command::Command::CreateSession { title };
|
||||
// 创建命令路由器
|
||||
let cli_sessions = state.session_manager.cli_sessions();
|
||||
let store = state.session_manager.store();
|
||||
let skills = state.session_manager.skills();
|
||||
let provider_config = state.config.get_provider_config("default")
|
||||
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||||
let prompt_repository = state.session_manager.store().clone();
|
||||
|
||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||
Box::new(AgentPromptProvider::new(
|
||||
0,
|
||||
provider_config.clone(),
|
||||
prompt_repository.clone(),
|
||||
)),
|
||||
Box::new(SkillPromptProvider::new(skills)),
|
||||
]));
|
||||
|
||||
let mut router = CommandRouter::new();
|
||||
router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone())));
|
||||
router.register(Box::new(SaveSessionCommandHandler::new(
|
||||
store,
|
||||
system_prompt_provider,
|
||||
)));
|
||||
|
||||
// 构建命令上下文
|
||||
let cmd_ctx = CommandContext::new("websocket")
|
||||
.with_session_id(current_session_id.as_str());
|
||||
|
||||
// 执行命令
|
||||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||
|
||||
// 适配输出
|
||||
let outbounds = output_adapter.adapt(response);
|
||||
|
||||
// 处理响应
|
||||
for msg in outbounds {
|
||||
if let WsOutbound::SessionCreated { session_id, title: _ } = &msg {
|
||||
if response.success {
|
||||
// 更新当前会话 ID(如果是创建会话)
|
||||
if let Some(session_id) = response.metadata.get("session_id") {
|
||||
*current_session_id = session_id.clone();
|
||||
state
|
||||
.channel_manager
|
||||
@ -239,178 +248,14 @@ async fn handle_inbound(
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// 适配并发送响应
|
||||
let outbounds = output_adapter.adapt(response);
|
||||
for msg in outbounds {
|
||||
let _ = sender.send(msg).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::ListSessions { include_archived } => {
|
||||
let records = state
|
||||
.session_manager
|
||||
.cli_sessions()
|
||||
.list(include_archived)?;
|
||||
let summaries = records.into_iter().map(to_session_summary).collect();
|
||||
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionList {
|
||||
sessions: summaries,
|
||||
current_session_id: Some(current_session_id.clone()),
|
||||
})
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::LoadSession { session_id } => {
|
||||
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: "SESSION_NOT_FOUND".to_string(),
|
||||
message: format!("Session not found: {}", session_id),
|
||||
})
|
||||
.await;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
*current_session_id = record.id.clone();
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
record.id.clone(),
|
||||
runtime_session_id.to_string(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionLoaded {
|
||||
session_id: record.id,
|
||||
title: record.title,
|
||||
message_count: record.message_count,
|
||||
})
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::RenameSession { session_id, title } => {
|
||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||
state
|
||||
.session_manager
|
||||
.cli_sessions()
|
||||
.rename(&target, &title)?;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionRenamed {
|
||||
session_id: target,
|
||||
title,
|
||||
})
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::ArchiveSession { session_id } => {
|
||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||
state.session_manager.cli_sessions().archive(&target)?;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionArchived { session_id: target })
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::DeleteSession { session_id } => {
|
||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||
state.session_manager.cli_sessions().delete(&target)?;
|
||||
|
||||
let replacement = if target == *current_session_id {
|
||||
Some(state.session_manager.cli_sessions().create(None)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
||||
session.lock().await.remove_history(&target);
|
||||
}
|
||||
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionDeleted {
|
||||
session_id: target.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(record) = replacement {
|
||||
*current_session_id = record.id.clone();
|
||||
state
|
||||
.channel_manager
|
||||
.cli_channel()
|
||||
.register_connection(
|
||||
record.id.clone(),
|
||||
runtime_session_id.to_string(),
|
||||
sender.clone(),
|
||||
)
|
||||
.await;
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionCreated {
|
||||
session_id: record.id,
|
||||
title: record.title,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::SaveSession { filepath, session_id } => {
|
||||
let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone());
|
||||
|
||||
// 获取所需依赖
|
||||
let store = state.session_manager.store();
|
||||
let skills = state.session_manager.skills();
|
||||
let provider_config = state.config.get_provider_config("default")
|
||||
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||||
let prompt_repository = state.session_manager.store().clone();
|
||||
|
||||
// 构建组合系统提示词提供者(与运行时一致)
|
||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||
Box::new(AgentPromptProvider::new(
|
||||
0, // save_session 不需要 reinject 逻辑
|
||||
provider_config.clone(),
|
||||
prompt_repository,
|
||||
)),
|
||||
Box::new(SkillPromptProvider::new(skills)),
|
||||
]));
|
||||
|
||||
// 构建处理器
|
||||
let handler = SaveSessionCommandHandler::new(store, system_prompt_provider);
|
||||
let router = {
|
||||
let mut r = CommandRouter::new();
|
||||
r.register(Box::new(handler));
|
||||
r
|
||||
};
|
||||
|
||||
// 构建命令
|
||||
let cmd = crate::command::Command::SaveSession { filepath, include_all: true };
|
||||
let cmd_ctx = CommandContext::new("websocket")
|
||||
.with_session_id(&target_session_id);
|
||||
|
||||
// 执行命令
|
||||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||
|
||||
// 处理响应
|
||||
if response.success {
|
||||
let filepath = response
|
||||
.metadata
|
||||
.get("filepath")
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let _ = sender
|
||||
.send(WsOutbound::SessionSaved {
|
||||
session_id: target_session_id,
|
||||
filepath,
|
||||
})
|
||||
.await;
|
||||
} else {
|
||||
let error = response.error.unwrap_or_else(|| {
|
||||
crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error")
|
||||
});
|
||||
let _ = sender
|
||||
.send(WsOutbound::Error {
|
||||
code: error.code,
|
||||
message: error.message,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
WsInbound::Ping => {
|
||||
|
||||
@ -17,8 +17,9 @@ pub struct SessionSummary {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsInbound {
|
||||
#[serde(rename = "user_input")]
|
||||
UserInput {
|
||||
/// 普通用户消息
|
||||
#[serde(rename = "message")]
|
||||
Message {
|
||||
content: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
channel: Option<String>,
|
||||
@ -27,48 +28,9 @@ pub enum WsInbound {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
sender_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "clear_history")]
|
||||
ClearHistory {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
chat_id: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "create_session")]
|
||||
CreateSession {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
title: Option<String>,
|
||||
},
|
||||
#[serde(rename = "list_sessions")]
|
||||
ListSessions {
|
||||
#[serde(default)]
|
||||
include_archived: bool,
|
||||
},
|
||||
#[serde(rename = "load_session")]
|
||||
LoadSession { session_id: String },
|
||||
#[serde(rename = "rename_session")]
|
||||
RenameSession {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
title: String,
|
||||
},
|
||||
#[serde(rename = "archive_session")]
|
||||
ArchiveSession {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "delete_session")]
|
||||
DeleteSession {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "save_session")]
|
||||
SaveSession {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
filepath: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
/// 命令(JSON 格式)
|
||||
#[serde(rename = "command")]
|
||||
Command { payload: String },
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
}
|
||||
@ -114,26 +76,6 @@ pub enum WsOutbound {
|
||||
SessionEstablished { session_id: String },
|
||||
#[serde(rename = "session_created")]
|
||||
SessionCreated { session_id: String, title: String },
|
||||
#[serde(rename = "session_list")]
|
||||
SessionList {
|
||||
sessions: Vec<SessionSummary>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
current_session_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "session_loaded")]
|
||||
SessionLoaded {
|
||||
session_id: String,
|
||||
title: String,
|
||||
message_count: i64,
|
||||
},
|
||||
#[serde(rename = "session_renamed")]
|
||||
SessionRenamed { session_id: String, title: String },
|
||||
#[serde(rename = "session_archived")]
|
||||
SessionArchived { session_id: String },
|
||||
#[serde(rename = "session_deleted")]
|
||||
SessionDeleted { session_id: String },
|
||||
#[serde(rename = "history_cleared")]
|
||||
HistoryCleared { session_id: String },
|
||||
#[serde(rename = "session_saved")]
|
||||
SessionSaved { session_id: String, filepath: String },
|
||||
#[serde(rename = "pong")]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
||||
use picobot::protocol::{WsInbound, WsOutbound};
|
||||
use picobot::providers::{ChatCompletionRequest, Message};
|
||||
|
||||
/// Test that message with special characters is properly escaped
|
||||
@ -53,70 +53,65 @@ fn test_chat_request_serialization() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_inbound_serialization() {
|
||||
let msg = WsInbound::CreateSession {
|
||||
title: Some("demo".to_string()),
|
||||
fn test_command_inbound_serialization() {
|
||||
// Command is now sent as payload in WsInbound::Command
|
||||
let command_json = r#"{"type":"create_session","title":"demo"}"#;
|
||||
let msg = WsInbound::Command {
|
||||
payload: command_json.to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""type":"create_session""#));
|
||||
assert!(json.contains(r#""title":"demo""#));
|
||||
assert!(json.contains(r#""type":"command""#));
|
||||
assert!(json.contains(r#""payload":""#));
|
||||
assert!(json.contains(r#"create_session"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_inbound_serialization() {
|
||||
let msg = WsInbound::Message {
|
||||
content: "Hello world".to_string(),
|
||||
channel: None,
|
||||
chat_id: Some("session-1".to_string()),
|
||||
sender_id: Some("user-1".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""type":"message""#));
|
||||
assert!(json.contains(r#""content":"Hello world""#));
|
||||
assert!(json.contains(r#""chat_id":"session-1""#));
|
||||
|
||||
let decoded: WsInbound = serde_json::from_str(&json).unwrap();
|
||||
match decoded {
|
||||
WsInbound::CreateSession { title } => {
|
||||
assert_eq!(title.as_deref(), Some("demo"));
|
||||
WsInbound::Message { content, chat_id, .. } => {
|
||||
assert_eq!(content, "Hello world");
|
||||
assert_eq!(chat_id.as_deref(), Some("session-1"));
|
||||
}
|
||||
other => panic!("unexpected decoded variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_list_outbound_serialization() {
|
||||
let msg = WsOutbound::SessionList {
|
||||
sessions: vec![SessionSummary {
|
||||
session_id: "session-1".to_string(),
|
||||
title: "demo".to_string(),
|
||||
channel_name: "cli".to_string(),
|
||||
chat_id: "session-1".to_string(),
|
||||
message_count: 2,
|
||||
last_active_at: 123,
|
||||
archived_at: None,
|
||||
}],
|
||||
current_session_id: Some("session-1".to_string()),
|
||||
fn test_session_created_outbound_serialization() {
|
||||
let msg = WsOutbound::SessionCreated {
|
||||
session_id: "session-1".to_string(),
|
||||
title: "demo".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""type":"session_list""#));
|
||||
assert!(json.contains(r#""type":"session_created""#));
|
||||
assert!(json.contains(r#""session_id":"session-1""#));
|
||||
assert!(json.contains(r#""message_count":2"#));
|
||||
assert!(json.contains(r#""title":"demo""#));
|
||||
|
||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||
match decoded {
|
||||
WsOutbound::SessionList {
|
||||
sessions,
|
||||
current_session_id,
|
||||
} => {
|
||||
assert_eq!(sessions.len(), 1);
|
||||
assert_eq!(sessions[0].title, "demo");
|
||||
assert_eq!(current_session_id.as_deref(), Some("session-1"));
|
||||
WsOutbound::SessionCreated { session_id, title } => {
|
||||
assert_eq!(session_id, "session-1");
|
||||
assert_eq!(title, "demo");
|
||||
}
|
||||
other => panic!("unexpected decoded variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_history_with_session_id_serialization() {
|
||||
let msg = WsInbound::ClearHistory {
|
||||
chat_id: None,
|
||||
session_id: Some("session-1".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""type":"clear_history""#));
|
||||
assert!(json.contains(r#""session_id":"session-1""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_outbound_serialization() {
|
||||
let msg = WsOutbound::ToolCall {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user