This commit is contained in:
oudecheng 2026-05-15 08:23:56 +08:00
commit 054cb718de
20 changed files with 641 additions and 705 deletions

View File

@ -10,14 +10,10 @@ pub enum InputEvent {
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputCommand {
Exit,
Clear,
New(Option<String>),
Save(Option<String>),
Sessions,
Use(String),
Rename(String),
Archive,
Delete,
}
pub struct InputHandler {
@ -74,14 +70,10 @@ impl InputHandler {
match command {
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
"/clear" => Some(InputCommand::Clear),
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
"/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))),
"/sessions" => Some(InputCommand::Sessions),
"/sessions" | "/list" => Some(InputCommand::Sessions),
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
"/archive" => Some(InputCommand::Archive),
"/delete" => Some(InputCommand::Delete),
_ => None,
}
}
@ -120,10 +112,6 @@ mod tests {
handler.handle_special_commands("/quit"),
Some(InputCommand::Exit)
);
assert_eq!(
handler.handle_special_commands("/clear"),
Some(InputCommand::Clear)
);
assert_eq!(
handler.handle_special_commands("/new"),
Some(InputCommand::New(None))
@ -140,6 +128,10 @@ mod tests {
handler.handle_special_commands("/save ./debug/session.md"),
Some(InputCommand::Save(Some("./debug/session.md".to_string())))
);
assert_eq!(
handler.handle_special_commands("/list"),
Some(InputCommand::Sessions)
);
assert_eq!(
handler.handle_special_commands("/sessions"),
Some(InputCommand::Sessions)
@ -148,18 +140,6 @@ mod tests {
handler.handle_special_commands("/use abc123"),
Some(InputCommand::Use("abc123".to_string()))
);
assert_eq!(
handler.handle_special_commands("/rename project alpha"),
Some(InputCommand::Rename("project alpha".to_string()))
);
assert_eq!(
handler.handle_special_commands("/archive"),
Some(InputCommand::Archive)
);
assert_eq!(
handler.handle_special_commands("/delete"),
Some(InputCommand::Delete)
);
assert_eq!(handler.handle_special_commands("/unknown"), None);
assert_eq!(handler.handle_special_commands("/use"), None);
}

View File

@ -8,36 +8,6 @@ use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::cli::{InputCommand, InputEvent, InputHandler};
fn format_session_list(
sessions: &[crate::protocol::SessionSummary],
current_session_id: Option<&str>,
) -> String {
if sessions.is_empty() {
return "No sessions found.".to_string();
}
let mut lines = Vec::with_capacity(sessions.len() + 1);
lines.push("Sessions:".to_string());
for session in sessions {
let marker = if current_session_id == Some(session.session_id.as_str()) {
"*"
} else {
"-"
};
let archived = if session.archived_at.is_some() {
" [archived]"
} else {
""
};
lines.push(format!(
"{} {} | {} | {} messages{}",
marker, session.session_id, session.title, session.message_count, archived,
));
}
lines.join("\n")
}
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
serde_json::from_str(raw)
}
@ -54,7 +24,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
let mut input = InputHandler::new();
let mut current_session_id: Option<String> = None;
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?;
// Main loop: poll both stdin and WebSocket
loop {
@ -91,29 +61,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
current_session_id = Some(session_id.clone());
input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?;
}
WsOutbound::SessionList { sessions, current_session_id: listed_current } => {
let display = format_session_list(&sessions, listed_current.as_deref());
input.write_output(&format!("{}\n", display)).await?;
}
WsOutbound::SessionLoaded { session_id, title, message_count } => {
current_session_id = Some(session_id.clone());
input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?;
}
WsOutbound::SessionRenamed { session_id, title } => {
input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?;
}
WsOutbound::SessionArchived { session_id } => {
input.write_output(&format!("Archived session: {}\n", session_id)).await?;
}
WsOutbound::SessionDeleted { session_id } => {
if current_session_id.as_deref() == Some(session_id.as_str()) {
current_session_id = None;
}
input.write_output(&format!("Deleted session: {}\n", session_id)).await?;
}
WsOutbound::HistoryCleared { session_id } => {
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
}
WsOutbound::SessionSaved { session_id, filepath } => {
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
}
@ -138,39 +85,25 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
input.write_output("Goodbye!").await?;
break;
}
InputEvent::Command(InputCommand::Clear) => {
let inbound = WsInbound::ClearHistory {
chat_id: None,
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::New(title)) => {
// 使用新的命令层:通过 CliInputAdapter 构建 Command
// 使用 CliInputAdapter 构建 Command
let adapter = CliInputAdapter::new();
let ctx = AdapterContext::new("cli")
.with_session_id(current_session_id.as_deref().unwrap_or(""));
// 构建输入字符串
let input = match title {
let input_str = match title {
Some(t) => format!("/new {}", t),
None => "/new".to_string(),
};
// 解析为 Command
match adapter.try_parse(&input, ctx) {
match adapter.try_parse(&input_str, ctx) {
Ok(Some(command)) => {
// 序列化为 JSON 通过 WebSocket 发送
// 序列化为 JSON
let json = serde_json::to_string(&command).unwrap_or_default();
let inbound = WsInbound::UserInput {
content: json,
channel: None,
chat_id: current_session_id.clone(),
sender_id: None,
};
// 通过 Command 消息发送
let inbound = WsInbound::Command { payload: json };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
@ -184,62 +117,97 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
}
continue;
}
InputEvent::Command(InputCommand::Sessions) => {
let inbound = WsInbound::ListSessions {
include_archived: true,
InputEvent::Command(InputCommand::Save(filepath)) => {
// 使用 CliInputAdapter 构建 Command
let adapter = CliInputAdapter::new();
let ctx = AdapterContext::new("cli")
.with_session_id(current_session_id.as_deref().unwrap_or(""));
// 构建输入字符串
let input_str = match filepath {
Some(p) => format!("/save {}", p),
None => "/save".to_string(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
// 解析为 Command
match adapter.try_parse(&input_str, ctx) {
Ok(Some(command)) => {
// 序列化为 JSON
let json = serde_json::to_string(&command).unwrap_or_default();
// 通过 Command 消息发送
let inbound = WsInbound::Command { payload: json };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
}
Ok(None) => {
tracing::warn!("Failed to parse /save command");
}
Err(e) => {
tracing::error!(error = %e, "Error parsing /save command");
}
}
continue;
}
InputEvent::Command(InputCommand::Sessions) => {
// 使用 CliInputAdapter 构建 Command
let adapter = CliInputAdapter::new();
let ctx = AdapterContext::new("cli")
.with_session_id(current_session_id.as_deref().unwrap_or(""));
// 解析为 Command
match adapter.try_parse("/list", ctx) {
Ok(Some(command)) => {
// 序列化为 JSON
let json = serde_json::to_string(&command).unwrap_or_default();
// 通过 Command 消息发送
let inbound = WsInbound::Command { payload: json };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
}
Ok(None) => {
tracing::warn!("Failed to parse /list command");
}
Err(e) => {
tracing::error!(error = %e, "Error parsing /list command");
}
}
continue;
}
InputEvent::Command(InputCommand::Use(session_id)) => {
let inbound = WsInbound::LoadSession { session_id };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Rename(title)) => {
let inbound = WsInbound::RenameSession {
session_id: current_session_id.clone(),
title,
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Archive) => {
let inbound = WsInbound::ArchiveSession {
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Delete) => {
let inbound = WsInbound::DeleteSession {
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
continue;
}
InputEvent::Command(InputCommand::Save(filepath)) => {
let inbound = WsInbound::SaveSession {
filepath,
session_id: current_session_id.clone(),
};
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
// 使用 CliInputAdapter 构建 Command
let adapter = CliInputAdapter::new();
let ctx = AdapterContext::new("cli")
.with_session_id(current_session_id.as_deref().unwrap_or(""));
// 构建输入字符串
let input_str = format!("/use {}", session_id);
// 解析为 Command
match adapter.try_parse(&input_str, ctx) {
Ok(Some(command)) => {
// 序列化为 JSON
let json = serde_json::to_string(&command).unwrap_or_default();
// 通过 Command 消息发送
let inbound = WsInbound::Command { payload: json };
if let Ok(text) = serialize_inbound(&inbound) {
let _ = sender.send(Message::Text(text.into())).await;
}
// 更新当前会话 ID
current_session_id = Some(session_id.clone());
}
Ok(None) => {
tracing::warn!("Failed to parse /use command");
}
Err(e) => {
tracing::error!(error = %e, "Error parsing /use command");
}
}
continue;
}
InputEvent::Message(msg) => {
let inbound = WsInbound::UserInput {
let inbound = WsInbound::Message {
content: msg.content,
channel: None,
chat_id: current_session_id.clone(),

View File

@ -0,0 +1,92 @@
use crate::command::adapter::{AdapterError, InputAdapter};
use crate::command::context::AdapterContext;
use crate::command::Command;
/// Channel 输入适配器
///
/// 将 Channel 消息中的文本命令(如 "/new", "/save")转换为 Command
pub struct ChannelInputAdapter;
impl ChannelInputAdapter {
/// 创建新的 Channel 输入适配器
pub fn new() -> Self {
Self
}
}
impl Default for ChannelInputAdapter {
fn default() -> Self {
Self::new()
}
}
impl InputAdapter for ChannelInputAdapter {
fn try_parse(
&self,
input: &str,
_ctx: AdapterContext,
) -> Result<Option<Command>, AdapterError> {
let trimmed = input.trim();
// 解析 /new 命令
if trimmed == "/new" {
return Ok(Some(Command::CreateSession { title: None }));
}
if let Some(title) = trimmed.strip_prefix("/new ") {
let title = title.trim();
return Ok(Some(Command::CreateSession {
title: Some(title.to_string()),
}));
}
// 解析 /save 命令
if trimmed == "/save" {
return Ok(Some(Command::SaveSession {
filepath: None,
include_all: false,
}));
}
if let Some(args) = trimmed.strip_prefix("/save ") {
let args = args.trim();
// 解析参数:可能是 "all"、路径、或 "all 路径"
let (include_all, filepath) = if args == "all" {
// /save all - 保存全部消息
(true, None)
} else if args.starts_with("all ") {
// /save all <filepath> - 保存全部消息到指定路径
let path = args[4..].trim();
(true, Some(path.to_string()))
} else {
// /save <filepath> - 保存活跃消息到指定路径
(false, Some(args.to_string()))
};
return Ok(Some(Command::SaveSession { filepath, include_all }));
}
// 解析 /list 命令
if trimmed == "/list" {
return Ok(Some(Command::ListSessions {
include_archived: false,
}));
}
if trimmed == "/list all" {
return Ok(Some(Command::ListSessions {
include_archived: true,
}));
}
// 解析 /use 命令
if let Some(session_id) = trimmed.strip_prefix("/use ") {
let session_id = session_id.trim();
return Ok(Some(Command::LoadSession {
session_id: session_id.to_string(),
}));
}
// 不是命令,返回 None
Ok(None)
}
}

View File

@ -66,6 +66,27 @@ impl InputAdapter for CliInputAdapter {
return Ok(Some(Command::SaveSession { filepath, include_all }));
}
// 解析 /list 命令
if trimmed == "/list" {
return Ok(Some(Command::ListSessions {
include_archived: false,
}));
}
if trimmed == "/list all" {
return Ok(Some(Command::ListSessions {
include_archived: true,
}));
}
// 解析 /use 命令
if let Some(session_id) = trimmed.strip_prefix("/use ") {
let session_id = session_id.trim();
return Ok(Some(Command::LoadSession {
session_id: session_id.to_string(),
}));
}
// 不是命令,返回 None
Ok(None)
}

View File

@ -1,2 +1,3 @@
pub mod channel;
pub mod cli;
pub mod websocket;

View File

@ -12,18 +12,21 @@ pub struct CommandContext {
pub chat_id: Option<String>,
/// 发送者ID
pub sender_id: String,
/// 通道名称(如 "cli", "feishu", "wechat"
pub channel_name: String,
/// 额外元数据
pub metadata: HashMap<String, String>,
}
impl CommandContext {
/// 创建新的命令上下文
pub fn new(sender_id: impl Into<String>) -> Self {
pub fn new(sender_id: impl Into<String>, channel_name: impl Into<String>) -> Self {
Self {
request_id: Uuid::new_v4(),
session_id: None,
chat_id: None,
sender_id: sender_id.into(),
channel_name: channel_name.into(),
metadata: HashMap::new(),
}
}

View File

@ -238,7 +238,7 @@ mod tests {
router.register(Box::new(TestHandler));
router.register(Box::new(NoOpHandler));
let ctx = CommandContext::new("test");
let ctx = CommandContext::new("test", "test");
let cmd = Command::CreateSession { title: None };
let result = router.dispatch(cmd, ctx).await;
@ -252,7 +252,7 @@ mod tests {
async fn test_router_no_handler() {
let router = CommandRouter::new();
let ctx = CommandContext::new("test");
let ctx = CommandContext::new("test", "test");
let cmd = Command::CreateSession { title: None };
let result = router.dispatch(cmd, ctx).await;

View File

@ -1,2 +1,3 @@
pub mod save_session;
pub mod session;
pub mod session_query;

View File

@ -36,6 +36,7 @@ impl CommandHandler for SessionCommandHandler {
match cmd {
Command::CreateSession { title } => handle_create_session(self, title, ctx).await,
Command::SaveSession { .. } => unreachable!("SaveSession should be handled by SaveSessionCommandHandler"),
_ => unreachable!("Other commands should be handled by other handlers"),
}
}
}
@ -48,7 +49,7 @@ async fn handle_create_session(
) -> Result<CommandResponse, CommandError> {
let record = handler
.cli_sessions
.create(title.as_deref())
.create_with_channel(&ctx.channel_name, title.as_deref())
.map_err(|e| CommandError::new("CREATE_SESSION_ERROR", e.to_string()))?;
Ok(CommandResponse::success(ctx.request_id)
@ -74,7 +75,7 @@ mod tests {
async fn test_create_session_with_title() {
let service = create_test_service();
let handler = SessionCommandHandler::new(service);
let ctx = CommandContext::new("test");
let ctx = CommandContext::new("test", "test");
let cmd = Command::CreateSession {
title: Some("my session".to_string()),
};
@ -93,7 +94,7 @@ mod tests {
async fn test_create_session_without_title() {
let service = create_test_service();
let handler = SessionCommandHandler::new(service);
let ctx = CommandContext::new("test");
let ctx = CommandContext::new("test", "test");
let cmd = Command::CreateSession { title: None };
let result = handler.handle(cmd, ctx).await;

View File

@ -0,0 +1,202 @@
use crate::command::context::CommandContext;
use crate::command::handler::CommandHandler;
use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command;
use crate::gateway::cli_session::CliSessionService;
use crate::protocol::SessionSummary;
use async_trait::async_trait;
/// 会话查询命令处理器
///
/// 处理 ListSessions 和 LoadSession 命令
pub struct SessionQueryCommandHandler {
cli_sessions: CliSessionService,
}
impl SessionQueryCommandHandler {
/// 创建新的会话查询命令处理器
pub fn new(cli_sessions: CliSessionService) -> Self {
Self { cli_sessions }
}
}
#[async_trait]
impl CommandHandler for SessionQueryCommandHandler {
fn can_handle(&self, cmd: &Command) -> bool {
matches!(cmd, Command::ListSessions { .. } | Command::LoadSession { .. })
}
async fn handle(
&self,
cmd: Command,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
match cmd {
Command::ListSessions { include_archived } => {
handle_list_sessions(self, include_archived, ctx).await
}
Command::LoadSession { session_id } => {
handle_load_session(self, session_id, ctx).await
}
_ => unreachable!(),
}
}
}
/// 处理列出会话命令
async fn handle_list_sessions(
handler: &SessionQueryCommandHandler,
include_archived: bool,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
let records = handler
.cli_sessions
.list(include_archived)
.map_err(|e| CommandError::new("LIST_SESSIONS_ERROR", e.to_string()))?;
let summaries: Vec<SessionSummary> = records
.into_iter()
.map(|r| SessionSummary {
session_id: r.id,
title: r.title,
channel_name: r.channel_name,
chat_id: r.chat_id,
message_count: r.message_count,
last_active_at: r.last_active_at,
archived_at: r.archived_at,
})
.collect();
// 将会话列表序列化为 JSON 存储在 metadata 中
let sessions_json =
serde_json::to_string(&summaries).map_err(|e| CommandError::new("SERIALIZE_ERROR", e.to_string()))?;
// 构建可读的会话列表消息
let message = if summaries.is_empty() {
"No sessions found.".to_string()
} else {
let mut lines = vec![format!("Found {} session(s):", summaries.len())];
for summary in &summaries {
let archived_info = summary
.archived_at
.map(|_| " [archived]")
.unwrap_or("");
lines.push(format!(
" - {}: {}{}",
summary.session_id, summary.title, archived_info
));
}
lines.push("".to_string());
lines.push("Use /use <session_id> to switch to a session".to_string());
lines.join("\n")
};
Ok(CommandResponse::success(ctx.request_id)
.with_message(MessageKind::Notification, &message)
.with_metadata("sessions", &sessions_json)
.with_metadata("count", &summaries.len().to_string()))
}
/// 处理加载会话命令
async fn handle_load_session(
handler: &SessionQueryCommandHandler,
session_id: String,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
let record = handler
.cli_sessions
.get(&session_id)
.map_err(|e| CommandError::new("LOAD_SESSION_ERROR", e.to_string()))?
.ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", format!("Session not found: {}", session_id)))?;
Ok(CommandResponse::success(ctx.request_id)
.with_message(MessageKind::Notification, &record.title)
.with_metadata("session_id", &record.id)
.with_metadata("title", &record.title)
.with_metadata("message_count", &record.message_count.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::SessionStore;
use std::sync::Arc;
fn create_test_service() -> CliSessionService {
let store = Arc::new(SessionStore::in_memory().unwrap());
CliSessionService::new(store)
}
#[tokio::test]
async fn test_list_sessions_empty() {
let service = create_test_service();
let handler = SessionQueryCommandHandler::new(service);
let ctx = CommandContext::new("test", "test");
let cmd = Command::ListSessions {
include_archived: false,
};
let result = handler.handle(cmd, ctx).await;
assert!(result.is_ok());
let resp = result.unwrap();
assert!(resp.success);
assert!(resp.messages[0].content.contains("No sessions"));
}
#[tokio::test]
async fn test_list_sessions_with_items() {
let service = create_test_service();
let handler = SessionQueryCommandHandler::new(service.clone());
// 创建一些会话
service.create(Some("test session")).unwrap();
let ctx = CommandContext::new("test", "test");
let cmd = Command::ListSessions {
include_archived: false,
};
let result = handler.handle(cmd, ctx).await;
assert!(result.is_ok());
let resp = result.unwrap();
assert!(resp.success);
assert!(resp.metadata.contains_key("sessions"));
}
#[tokio::test]
async fn test_load_session_not_found() {
let service = create_test_service();
let handler = SessionQueryCommandHandler::new(service);
let ctx = CommandContext::new("test", "test");
let cmd = Command::LoadSession {
session_id: "nonexistent".to_string(),
};
let result = handler.handle(cmd, ctx).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_load_session_success() {
let service = create_test_service();
let handler = SessionQueryCommandHandler::new(service.clone());
// 创建会话
let record = service.create(Some("test session")).unwrap();
let ctx = CommandContext::new("test", "test");
let cmd = Command::LoadSession {
session_id: record.id.clone(),
};
let result = handler.handle(cmd, ctx).await;
assert!(result.is_ok());
let resp = result.unwrap();
assert!(resp.success);
assert_eq!(resp.metadata.get("session_id").unwrap(), &record.id);
}
}

View File

@ -18,6 +18,10 @@ pub enum Command {
filepath: Option<String>,
include_all: bool,
},
/// 列出会话
ListSessions { include_archived: bool },
/// 加载指定会话
LoadSession { session_id: String },
}
impl Command {
@ -26,6 +30,8 @@ impl Command {
match self {
Command::CreateSession { .. } => "create_session",
Command::SaveSession { .. } => "save_session",
Command::ListSessions { .. } => "list_sessions",
Command::LoadSession { .. } => "load_session",
}
}
}

View File

@ -1472,7 +1472,7 @@ mod tests {
assert!(config.scheduler.jobs.is_empty());
let effective_jobs = config.scheduler.effective_jobs(&config.time);
assert_eq!(effective_jobs.len(), 1);
assert_eq!(effective_jobs.len(), 2);
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
assert_eq!(effective_jobs[0].kind, SchedulerJobKind::InternalEvent);
assert_eq!(
@ -1481,6 +1481,8 @@ mod tests {
expression: "0 */4 * * *".to_string(),
}
);
// 第二个内置作业是会话清理
assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
}
#[test]
@ -1516,7 +1518,8 @@ mod tests {
let effective_jobs = scheduler.effective_jobs(&TimeConfig {
timezone: "Asia/Shanghai".to_string(),
});
assert_eq!(effective_jobs.len(), 2);
assert_eq!(effective_jobs.len(), 3); // 2个内置 + 1个自定义
// 第一个作业:内存维护(被覆盖为禁用)
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
assert!(!effective_jobs[0].enabled);
assert_eq!(
@ -1525,7 +1528,11 @@ mod tests {
expression: "15 2 * * *".to_string(),
}
);
assert_eq!(effective_jobs[1].id, "custom.reminder");
// 第二个作业:会话清理(保持默认)
assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
assert!(effective_jobs[1].enabled);
// 第三个作业:自定义提醒
assert_eq!(effective_jobs[2].id, "custom.reminder");
}
#[test]

View File

@ -19,6 +19,17 @@ impl CliSessionService {
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
/// 创建指定通道的会话
pub(crate) fn create_with_channel(
&self,
channel_name: &str,
title: Option<&str>,
) -> Result<SessionRecord, AgentError> {
self.store
.create_session(channel_name, title)
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
}
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
self.store
.get_session(session_id)

View File

@ -1,165 +1,2 @@
use crate::agent::AgentError;
use super::session::Session;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InChatCommand {
FreshConversation,
}
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
match content.trim() {
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
_ => None,
}
}
pub(crate) fn handle_in_chat_command(
session: &mut Session,
chat_id: &str,
content: &str,
) -> Result<Option<String>, AgentError> {
match parse_in_chat_command(content) {
Some(InChatCommand::FreshConversation) => {
session.reset_chat_context(chat_id)?;
Ok(Some("Started a fresh conversation.".to_string()))
}
None => Ok(None),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::skills::SkillRuntime;
use crate::storage::SessionStore;
use crate::tools::ToolRegistry;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
const TEST_CHANNEL: &str = "test-channel";
fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig {
provider_type: "openai".to_string(),
name: "test".to_string(),
base_url: "http://localhost".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
memory_maintenance_timeout_secs: 600,
model_id: "test-model".to_string(),
temperature: Some(0.0),
max_tokens: Some(32),
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 1,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 20_000,
}
}
#[test]
fn test_parse_in_chat_command_aliases() {
assert_eq!(
parse_in_chat_command("/new"),
Some(InChatCommand::FreshConversation)
);
assert_eq!(
parse_in_chat_command(" /reset \n"),
Some(InChatCommand::FreshConversation)
);
assert_eq!(parse_in_chat_command("/new planning"), None);
assert_eq!(parse_in_chat_command("please /reset"), None);
}
#[tokio::test]
async fn test_handle_in_chat_command_resets_active_history_only() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
TEST_CHANNEL.to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store.clone(),
100,
Some(4),
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
.unwrap()
.unwrap();
assert_eq!(response, "Started a fresh conversation.");
assert!(session.get_history("chat-1").unwrap().is_empty());
assert!(
store
.load_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.is_empty()
);
assert_eq!(
store
.load_all_messages(&session.persistent_session_id("chat-1"))
.unwrap()
.len(),
// 新设计:系统提示词不再持久化,只有 1 条用户消息
1,
);
session.ensure_chat_loaded("chat-1").unwrap();
let history = session.get_history("chat-1").unwrap();
// 新设计:系统提示词不再持久化到历史记录
assert_eq!(history.len(), 0);
}
#[tokio::test]
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let (user_tx, _user_rx) = mpsc::channel(4);
let skills = Arc::new(SkillRuntime::default());
let tools = Arc::new(ToolRegistry::new());
let mut session = Session::new(
TEST_CHANNEL.to_string(),
test_provider_config(),
user_tx,
tools,
skills,
store,
100,
Some(4),
)
.await
.unwrap();
session.ensure_persistent_session("chat-1").unwrap();
session.ensure_chat_loaded("chat-1").unwrap();
session
.append_persisted_message("chat-1", ChatMessage::user("hello"))
.unwrap();
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
session
.ensure_agent_prompt_before_user_message("chat-1")
.unwrap();
// 新设计:系统提示词不再持久化到历史记录
let history = session.get_history("chat-1").unwrap();
assert_eq!(history.len(), 0);
}
}
// 此文件已废弃InChatCommand 功能已合并到 Command 系统
// 保留文件以避免破坏现有 import但内容为空

View File

@ -7,7 +7,6 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL
use crate::config::LLMProviderConfig;
use tokio::sync::Mutex;
use super::command::handle_in_chat_command;
use super::compaction::schedule_background_history_compaction;
use super::message_prepare::enrich_user_content_with_media_refs;
use super::session::Session;
@ -138,18 +137,6 @@ impl AgentExecutionService {
session_guard.ensure_persistent_session(request.chat_id)?;
session_guard.ensure_chat_loaded(request.chat_id)?;
if let Some(command_response) =
handle_in_chat_command(&mut session_guard, request.chat_id, request.content)?
{
return Ok(vec![OutboundMessage::assistant(
request.channel_name.to_string(),
request.chat_id.to_string(),
command_response,
None,
HashMap::new(),
)]);
}
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
let media_refs: Vec<String> = request

View File

@ -4,7 +4,12 @@ use tokio::sync::Semaphore;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
use crate::command::handler::InChatCommandRouter;
use crate::command::adapter::InputAdapter;
use crate::command::adapters::channel::ChannelInputAdapter;
use crate::command::handler::CommandRouter;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::session_query::SessionQueryCommandHandler;
use crate::command::Command;
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
@ -18,7 +23,7 @@ pub struct InboundProcessor {
session_manager: SessionManager,
semaphore: Arc<Semaphore>,
_provider_config: LLMProviderConfig,
command_router: Arc<InChatCommandRouter>,
command_router: Arc<CommandRouter>,
}
impl InboundProcessor {
@ -29,7 +34,14 @@ impl InboundProcessor {
provider_config: LLMProviderConfig,
) -> Self {
// 创建命令路由器并注册处理器
let mut command_router = InChatCommandRouter::new();
let mut command_router = CommandRouter::new();
// 注册 Session 处理器
let cli_sessions = session_manager.cli_sessions();
command_router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone())));
// 注册 session_query 处理器
command_router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions)));
// 注册 save_session 处理器
let store = session_manager.store();
@ -43,7 +55,7 @@ impl InboundProcessor {
)),
Box::new(SkillPromptProvider::new(skills)),
]));
command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new(
command_router.register(Box::new(SaveSessionCommandHandler::new(
store,
system_prompt_provider,
)));
@ -103,18 +115,43 @@ impl InboundProcessor {
}
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
// 尝试解析为命令
if let Some(cmd) = parse_in_chat_command(&inbound.content) {
// 使用 ChannelInputAdapter 尝试解析命令
let adapter = ChannelInputAdapter::new();
let ctx = crate::command::context::AdapterContext::new(&inbound.channel)
.with_session_id(&inbound.chat_id);
if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) {
// 使用命令路由器处理
match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? {
Some(response_msg) => {
// 发送命令执行结果给用户
let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel, &inbound.channel)
.with_session_id(&inbound.chat_id);
// 记录是否是创建会话命令(用于后续自动切换)
let is_create_session = matches!(cmd, Command::CreateSession { .. });
let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await;
// 发送响应给用户
if response.success {
// 如果是创建会话,更新 chat_id 到新会话
let target_chat_id = if let Some(session_id) = response.metadata.get("session_id") {
if is_create_session {
// 自动切换到新会话
session_id.clone()
} else {
inbound.chat_id.clone()
}
} else {
inbound.chat_id.clone()
};
// 提取响应消息
for msg in &response.messages {
if let Err(error) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
inbound.chat_id.clone(),
response_msg,
target_chat_id.clone(),
msg.content.clone(),
None,
inbound.forwarded_metadata.clone(),
))
@ -122,13 +159,23 @@ impl InboundProcessor {
{
tracing::error!(error = %error, "Failed to publish command response");
}
return Ok(());
}
None => {
// 命令已处理但没有返回消息
return Ok(());
} else if let Some(error) = response.error {
if let Err(e) = self
.bus
.publish_outbound(OutboundMessage::assistant(
inbound.channel.clone(),
inbound.chat_id.clone(),
format!("Error [{}]: {}", error.code, error.message),
None,
inbound.forwarded_metadata.clone(),
))
.await
{
tracing::error!(error = %e, "Failed to publish error response");
}
}
return Ok(());
}
// 普通消息进入 AgentLoop
@ -183,41 +230,3 @@ impl InboundProcessor {
Ok(())
}
}
/// 解析聊天中的命令
///
/// 支持格式:
/// - `/save` - 保存活跃会话消息(到 cutoff
/// - `/save all` - 保存全部会话消息(包括 cutoff 之前)
/// - `/save <filepath>` - 保存活跃消息到指定路径
/// - `/save all <filepath>` - 保存全部消息到指定路径
///
/// 返回 Some(Command) 如果是命令
/// 返回 None 如果不是命令
fn parse_in_chat_command(content: &str) -> Option<Command> {
let trimmed = content.trim();
if trimmed.starts_with("/save") {
let args = trimmed[5..].trim();
// 解析参数
let (include_all, filepath) = if args.is_empty() {
// /save 无参数 - 只保存活跃消息
(false, None)
} else if args == "all" {
// /save all - 保存全部消息
(true, None)
} else if args.starts_with("all ") {
// /save all <filepath> - 保存全部消息到指定路径
let path = args[4..].trim();
(true, Some(path.to_string()))
} else {
// /save <filepath> - 保存活跃消息到指定路径
(false, Some(args.to_string()))
};
Some(Command::SaveSession { filepath, include_all })
} else {
None
}
}

View File

@ -1,14 +1,15 @@
use super::GatewayState;
use crate::agent::{AgentError, CompositeSystemPromptProvider};
use crate::bus::InboundMessage;
use crate::command::adapter::OutputAdapter;
use crate::command::adapter::{InputAdapter, OutputAdapter};
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
use crate::command::context::CommandContext;
use crate::command::handler::CommandRouter;
use crate::command::handlers::save_session::SaveSessionCommandHandler;
use crate::command::handlers::session::SessionCommandHandler;
use crate::command::handlers::session_query::SessionQueryCommandHandler;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
use crate::skills::SkillPromptProvider;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
@ -125,17 +126,6 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
}
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
SessionSummary {
session_id: record.id,
title: record.title,
channel_name: record.channel_name,
chat_id: record.chat_id,
message_count: record.message_count,
last_active_at: record.last_active_at,
archived_at: record.archived_at,
}
}
async fn handle_inbound(
state: &Arc<GatewayState>,
@ -143,9 +133,9 @@ async fn handle_inbound(
runtime_session_id: &str,
current_session_id: &mut String,
inbound: WsInbound,
) -> Result<(), crate::agent::AgentError> {
) -> Result<(), AgentError> {
match inbound {
WsInbound::UserInput {
WsInbound::Message {
content,
chat_id,
sender_id,
@ -181,53 +171,74 @@ async fn handle_inbound(
Ok(())
}
WsInbound::ClearHistory {
session_id,
chat_id,
} => {
let target = session_id
.or(chat_id)
.unwrap_or_else(|| current_session_id.clone());
state
.session_manager
.cli_sessions()
.clear_messages(&target)?;
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
session.lock().await.remove_history(&target);
}
let _ = sender
.send(WsOutbound::HistoryCleared { session_id: target })
.await;
Ok(())
}
WsInbound::CreateSession { title } => {
// 使用新的命令层处理
let _input_adapter = WebSocketInputAdapter::new();
WsInbound::Command { payload } => {
// 使用 Command 系统处理命令
let input_adapter = WebSocketInputAdapter::new();
let output_adapter = WebSocketOutputAdapter::new();
let cli_sessions = state.session_manager.cli_sessions();
let handler = SessionCommandHandler::new(cli_sessions);
let router = {
let mut r = CommandRouter::new();
r.register(Box::new(handler));
r
// 解析命令
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
.with_session_id(current_session_id.as_str());
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
Ok(Some(cmd)) => cmd,
Ok(None) => {
// 不是命令,返回错误
let _ = sender
.send(WsOutbound::Error {
code: "INVALID_COMMAND".to_string(),
message: "Invalid command payload".to_string(),
})
.await;
return Ok(());
}
Err(e) => {
let _ = sender
.send(WsOutbound::Error {
code: "PARSE_ERROR".to_string(),
message: e.to_string(),
})
.await;
return Ok(());
}
};
// 构建命令
let cmd = crate::command::Command::CreateSession { title };
let cmd_ctx = CommandContext::new("websocket")
// 创建命令路由器
let cli_sessions = state.session_manager.cli_sessions();
let store = state.session_manager.store();
let skills = state.session_manager.skills();
let provider_config = state.config.get_provider_config("default")
.map_err(|e| AgentError::Other(e.to_string()))?;
let prompt_repository = state.session_manager.store().clone();
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
0,
provider_config.clone(),
prompt_repository.clone(),
)),
Box::new(SkillPromptProvider::new(skills)),
]));
let mut router = CommandRouter::new();
router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone())));
router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions)));
router.register(Box::new(SaveSessionCommandHandler::new(
store,
system_prompt_provider,
)));
// 构建命令上下文
let cmd_ctx = CommandContext::new("websocket", "cli")
.with_session_id(current_session_id.as_str());
// 执行命令
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
// 适配输出
let outbounds = output_adapter.adapt(response);
// 处理响应
for msg in outbounds {
if let WsOutbound::SessionCreated { session_id, title: _ } = &msg {
if response.success {
// 更新当前会话 ID如果是创建会话
if let Some(session_id) = response.metadata.get("session_id") {
*current_session_id = session_id.clone();
state
.channel_manager
@ -239,178 +250,14 @@ async fn handle_inbound(
)
.await;
}
}
// 适配并发送响应
let outbounds = output_adapter.adapt(response);
for msg in outbounds {
let _ = sender.send(msg).await;
}
Ok(())
}
WsInbound::ListSessions { include_archived } => {
let records = state
.session_manager
.cli_sessions()
.list(include_archived)?;
let summaries = records.into_iter().map(to_session_summary).collect();
let _ = sender
.send(WsOutbound::SessionList {
sessions: summaries,
current_session_id: Some(current_session_id.clone()),
})
.await;
Ok(())
}
WsInbound::LoadSession { session_id } => {
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
let _ = sender
.send(WsOutbound::Error {
code: "SESSION_NOT_FOUND".to_string(),
message: format!("Session not found: {}", session_id),
})
.await;
return Ok(());
};
*current_session_id = record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
.send(WsOutbound::SessionLoaded {
session_id: record.id,
title: record.title,
message_count: record.message_count,
})
.await;
Ok(())
}
WsInbound::RenameSession { session_id, title } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state
.session_manager
.cli_sessions()
.rename(&target, &title)?;
let _ = sender
.send(WsOutbound::SessionRenamed {
session_id: target,
title,
})
.await;
Ok(())
}
WsInbound::ArchiveSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.cli_sessions().archive(&target)?;
let _ = sender
.send(WsOutbound::SessionArchived { session_id: target })
.await;
Ok(())
}
WsInbound::DeleteSession { session_id } => {
let target = session_id.unwrap_or_else(|| current_session_id.clone());
state.session_manager.cli_sessions().delete(&target)?;
let replacement = if target == *current_session_id {
Some(state.session_manager.cli_sessions().create(None)?)
} else {
None
};
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
session.lock().await.remove_history(&target);
}
let _ = sender
.send(WsOutbound::SessionDeleted {
session_id: target.clone(),
})
.await;
if let Some(record) = replacement {
*current_session_id = record.id.clone();
state
.channel_manager
.cli_channel()
.register_connection(
record.id.clone(),
runtime_session_id.to_string(),
sender.clone(),
)
.await;
let _ = sender
.send(WsOutbound::SessionCreated {
session_id: record.id,
title: record.title,
})
.await;
}
Ok(())
}
WsInbound::SaveSession { filepath, session_id } => {
let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone());
// 获取所需依赖
let store = state.session_manager.store();
let skills = state.session_manager.skills();
let provider_config = state.config.get_provider_config("default")
.map_err(|e| AgentError::Other(e.to_string()))?;
let prompt_repository = state.session_manager.store().clone();
// 构建组合系统提示词提供者(与运行时一致)
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
Box::new(AgentPromptProvider::new(
0, // save_session 不需要 reinject 逻辑
provider_config.clone(),
prompt_repository,
)),
Box::new(SkillPromptProvider::new(skills)),
]));
// 构建处理器
let handler = SaveSessionCommandHandler::new(store, system_prompt_provider);
let router = {
let mut r = CommandRouter::new();
r.register(Box::new(handler));
r
};
// 构建命令
let cmd = crate::command::Command::SaveSession { filepath, include_all: true };
let cmd_ctx = CommandContext::new("websocket")
.with_session_id(&target_session_id);
// 执行命令
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
// 处理响应
if response.success {
let filepath = response
.metadata
.get("filepath")
.cloned()
.unwrap_or_default();
let _ = sender
.send(WsOutbound::SessionSaved {
session_id: target_session_id,
filepath,
})
.await;
} else {
let error = response.error.unwrap_or_else(|| {
crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error")
});
let _ = sender
.send(WsOutbound::Error {
code: error.code,
message: error.message,
})
.await;
}
Ok(())
}
WsInbound::Ping => {

View File

@ -17,8 +17,9 @@ pub struct SessionSummary {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WsInbound {
#[serde(rename = "user_input")]
UserInput {
/// 普通用户消息
#[serde(rename = "message")]
Message {
content: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
channel: Option<String>,
@ -27,48 +28,9 @@ pub enum WsInbound {
#[serde(default, skip_serializing_if = "Option::is_none")]
sender_id: Option<String>,
},
#[serde(rename = "clear_history")]
ClearHistory {
#[serde(default, skip_serializing_if = "Option::is_none")]
chat_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
},
#[serde(rename = "create_session")]
CreateSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
title: Option<String>,
},
#[serde(rename = "list_sessions")]
ListSessions {
#[serde(default)]
include_archived: bool,
},
#[serde(rename = "load_session")]
LoadSession { session_id: String },
#[serde(rename = "rename_session")]
RenameSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
title: String,
},
#[serde(rename = "archive_session")]
ArchiveSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
},
#[serde(rename = "delete_session")]
DeleteSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
},
#[serde(rename = "save_session")]
SaveSession {
#[serde(default, skip_serializing_if = "Option::is_none")]
filepath: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
},
/// 命令JSON 格式)
#[serde(rename = "command")]
Command { payload: String },
#[serde(rename = "ping")]
Ping,
}
@ -126,14 +88,6 @@ pub enum WsOutbound {
title: String,
message_count: i64,
},
#[serde(rename = "session_renamed")]
SessionRenamed { session_id: String, title: String },
#[serde(rename = "session_archived")]
SessionArchived { session_id: String },
#[serde(rename = "session_deleted")]
SessionDeleted { session_id: String },
#[serde(rename = "history_cleared")]
HistoryCleared { session_id: String },
#[serde(rename = "session_saved")]
SessionSaved { session_id: String, filepath: String },
#[serde(rename = "pong")]

View File

@ -204,14 +204,24 @@ impl SessionStore {
Self::from_connection(Connection::open_in_memory()?)
}
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
pub fn create_session(
&self,
channel_name: &str,
title: Option<&str>,
) -> Result<SessionRecord, StorageError> {
let now = current_timestamp();
let id = uuid::Uuid::new_v4().to_string();
let title = title
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToOwned::to_owned)
.unwrap_or_else(|| format!("CLI Session {}", &id[..8]));
.unwrap_or_else(|| {
if channel_name == "cli" {
format!("CLI Session {}", &id[..8])
} else {
format!("Session {}", &id[..8])
}
});
let conn = self.conn.lock().expect("session db mutex poisoned");
conn.execute(
@ -220,9 +230,9 @@ impl SessionStore {
id, title, channel_name, chat_id, summary,
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0)
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
",
params![id, title, id, now],
params![id, title, channel_name, id, now],
)?;
drop(conn);
@ -230,6 +240,10 @@ impl SessionStore {
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
}
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
self.create_session("cli", title)
}
pub fn ensure_channel_session(
&self,
channel_name: &str,

View File

@ -1,4 +1,4 @@
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
use picobot::protocol::{WsInbound, WsOutbound};
use picobot::providers::{ChatCompletionRequest, Message};
/// Test that message with special characters is properly escaped
@ -53,70 +53,65 @@ fn test_chat_request_serialization() {
}
#[test]
fn test_session_inbound_serialization() {
let msg = WsInbound::CreateSession {
title: Some("demo".to_string()),
fn test_command_inbound_serialization() {
// Command is now sent as payload in WsInbound::Command
let command_json = r#"{"type":"create_session","title":"demo"}"#;
let msg = WsInbound::Command {
payload: command_json.to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"create_session""#));
assert!(json.contains(r#""title":"demo""#));
assert!(json.contains(r#""type":"command""#));
assert!(json.contains(r#""payload":""#));
assert!(json.contains(r#"create_session"#));
}
#[test]
fn test_message_inbound_serialization() {
let msg = WsInbound::Message {
content: "Hello world".to_string(),
channel: None,
chat_id: Some("session-1".to_string()),
sender_id: Some("user-1".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"message""#));
assert!(json.contains(r#""content":"Hello world""#));
assert!(json.contains(r#""chat_id":"session-1""#));
let decoded: WsInbound = serde_json::from_str(&json).unwrap();
match decoded {
WsInbound::CreateSession { title } => {
assert_eq!(title.as_deref(), Some("demo"));
WsInbound::Message { content, chat_id, .. } => {
assert_eq!(content, "Hello world");
assert_eq!(chat_id.as_deref(), Some("session-1"));
}
other => panic!("unexpected decoded variant: {:?}", other),
}
}
#[test]
fn test_session_list_outbound_serialization() {
let msg = WsOutbound::SessionList {
sessions: vec![SessionSummary {
session_id: "session-1".to_string(),
title: "demo".to_string(),
channel_name: "cli".to_string(),
chat_id: "session-1".to_string(),
message_count: 2,
last_active_at: 123,
archived_at: None,
}],
current_session_id: Some("session-1".to_string()),
fn test_session_created_outbound_serialization() {
let msg = WsOutbound::SessionCreated {
session_id: "session-1".to_string(),
title: "demo".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"session_list""#));
assert!(json.contains(r#""type":"session_created""#));
assert!(json.contains(r#""session_id":"session-1""#));
assert!(json.contains(r#""message_count":2"#));
assert!(json.contains(r#""title":"demo""#));
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
match decoded {
WsOutbound::SessionList {
sessions,
current_session_id,
} => {
assert_eq!(sessions.len(), 1);
assert_eq!(sessions[0].title, "demo");
assert_eq!(current_session_id.as_deref(), Some("session-1"));
WsOutbound::SessionCreated { session_id, title } => {
assert_eq!(session_id, "session-1");
assert_eq!(title, "demo");
}
other => panic!("unexpected decoded variant: {:?}", other),
}
}
#[test]
fn test_clear_history_with_session_id_serialization() {
let msg = WsInbound::ClearHistory {
chat_id: None,
session_id: Some("session-1".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""type":"clear_history""#));
assert!(json.contains(r#""session_id":"session-1""#));
}
#[test]
fn test_tool_call_outbound_serialization() {
let msg = WsOutbound::ToolCall {