Compare commits
3 Commits
b33350c410
...
e005d06a9b
| Author | SHA1 | Date | |
|---|---|---|---|
| e005d06a9b | |||
| 5eb9a26843 | |||
| b77fc93d71 |
@ -10,14 +10,10 @@ pub enum InputEvent {
|
|||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum InputCommand {
|
pub enum InputCommand {
|
||||||
Exit,
|
Exit,
|
||||||
Clear,
|
|
||||||
New(Option<String>),
|
New(Option<String>),
|
||||||
Save(Option<String>),
|
Save(Option<String>),
|
||||||
Sessions,
|
Sessions,
|
||||||
Use(String),
|
Use(String),
|
||||||
Rename(String),
|
|
||||||
Archive,
|
|
||||||
Delete,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InputHandler {
|
pub struct InputHandler {
|
||||||
@ -74,14 +70,10 @@ impl InputHandler {
|
|||||||
|
|
||||||
match command {
|
match command {
|
||||||
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
"/quit" | "/exit" | "/q" => Some(InputCommand::Exit),
|
||||||
"/clear" => Some(InputCommand::Clear),
|
|
||||||
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
|
"/new" => Some(InputCommand::New(arg.map(ToOwned::to_owned))),
|
||||||
"/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))),
|
"/save" => Some(InputCommand::Save(arg.map(ToOwned::to_owned))),
|
||||||
"/sessions" => Some(InputCommand::Sessions),
|
"/sessions" | "/list" => Some(InputCommand::Sessions),
|
||||||
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
|
"/use" => arg.map(|value| InputCommand::Use(value.to_string())),
|
||||||
"/rename" => arg.map(|value| InputCommand::Rename(value.to_string())),
|
|
||||||
"/archive" => Some(InputCommand::Archive),
|
|
||||||
"/delete" => Some(InputCommand::Delete),
|
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -120,10 +112,6 @@ mod tests {
|
|||||||
handler.handle_special_commands("/quit"),
|
handler.handle_special_commands("/quit"),
|
||||||
Some(InputCommand::Exit)
|
Some(InputCommand::Exit)
|
||||||
);
|
);
|
||||||
assert_eq!(
|
|
||||||
handler.handle_special_commands("/clear"),
|
|
||||||
Some(InputCommand::Clear)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
handler.handle_special_commands("/new"),
|
handler.handle_special_commands("/new"),
|
||||||
Some(InputCommand::New(None))
|
Some(InputCommand::New(None))
|
||||||
@ -140,6 +128,10 @@ mod tests {
|
|||||||
handler.handle_special_commands("/save ./debug/session.md"),
|
handler.handle_special_commands("/save ./debug/session.md"),
|
||||||
Some(InputCommand::Save(Some("./debug/session.md".to_string())))
|
Some(InputCommand::Save(Some("./debug/session.md".to_string())))
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
handler.handle_special_commands("/list"),
|
||||||
|
Some(InputCommand::Sessions)
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
handler.handle_special_commands("/sessions"),
|
handler.handle_special_commands("/sessions"),
|
||||||
Some(InputCommand::Sessions)
|
Some(InputCommand::Sessions)
|
||||||
@ -148,18 +140,6 @@ mod tests {
|
|||||||
handler.handle_special_commands("/use abc123"),
|
handler.handle_special_commands("/use abc123"),
|
||||||
Some(InputCommand::Use("abc123".to_string()))
|
Some(InputCommand::Use("abc123".to_string()))
|
||||||
);
|
);
|
||||||
assert_eq!(
|
|
||||||
handler.handle_special_commands("/rename project alpha"),
|
|
||||||
Some(InputCommand::Rename("project alpha".to_string()))
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
handler.handle_special_commands("/archive"),
|
|
||||||
Some(InputCommand::Archive)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
handler.handle_special_commands("/delete"),
|
|
||||||
Some(InputCommand::Delete)
|
|
||||||
);
|
|
||||||
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
assert_eq!(handler.handle_special_commands("/unknown"), None);
|
||||||
assert_eq!(handler.handle_special_commands("/use"), None);
|
assert_eq!(handler.handle_special_commands("/use"), None);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,36 +8,6 @@ use tokio_tungstenite::{connect_async, tungstenite::Message};
|
|||||||
|
|
||||||
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
use crate::cli::{InputCommand, InputEvent, InputHandler};
|
||||||
|
|
||||||
fn format_session_list(
|
|
||||||
sessions: &[crate::protocol::SessionSummary],
|
|
||||||
current_session_id: Option<&str>,
|
|
||||||
) -> String {
|
|
||||||
if sessions.is_empty() {
|
|
||||||
return "No sessions found.".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut lines = Vec::with_capacity(sessions.len() + 1);
|
|
||||||
lines.push("Sessions:".to_string());
|
|
||||||
for session in sessions {
|
|
||||||
let marker = if current_session_id == Some(session.session_id.as_str()) {
|
|
||||||
"*"
|
|
||||||
} else {
|
|
||||||
"-"
|
|
||||||
};
|
|
||||||
let archived = if session.archived_at.is_some() {
|
|
||||||
" [archived]"
|
|
||||||
} else {
|
|
||||||
""
|
|
||||||
};
|
|
||||||
lines.push(format!(
|
|
||||||
"{} {} | {} | {} messages{}",
|
|
||||||
marker, session.session_id, session.title, session.message_count, archived,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
lines.join("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
fn parse_message(raw: &str) -> Result<WsOutbound, serde_json::Error> {
|
||||||
serde_json::from_str(raw)
|
serde_json::from_str(raw)
|
||||||
}
|
}
|
||||||
@ -54,7 +24,7 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let mut input = InputHandler::new();
|
let mut input = InputHandler::new();
|
||||||
let mut current_session_id: Option<String> = None;
|
let mut current_session_id: Option<String> = None;
|
||||||
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /reset, /sessions, /use <session>, /rename <title>, /archive, /delete, /clear, /quit\n").await?;
|
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?;
|
||||||
|
|
||||||
// Main loop: poll both stdin and WebSocket
|
// Main loop: poll both stdin and WebSocket
|
||||||
loop {
|
loop {
|
||||||
@ -91,29 +61,6 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
current_session_id = Some(session_id.clone());
|
current_session_id = Some(session_id.clone());
|
||||||
input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?;
|
input.write_output(&format!("Created session: {} ({})\n", session_id, title)).await?;
|
||||||
}
|
}
|
||||||
WsOutbound::SessionList { sessions, current_session_id: listed_current } => {
|
|
||||||
let display = format_session_list(&sessions, listed_current.as_deref());
|
|
||||||
input.write_output(&format!("{}\n", display)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::SessionLoaded { session_id, title, message_count } => {
|
|
||||||
current_session_id = Some(session_id.clone());
|
|
||||||
input.write_output(&format!("Loaded session: {} ({}, {} messages)\n", session_id, title, message_count)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::SessionRenamed { session_id, title } => {
|
|
||||||
input.write_output(&format!("Renamed session: {} -> {}\n", session_id, title)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::SessionArchived { session_id } => {
|
|
||||||
input.write_output(&format!("Archived session: {}\n", session_id)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::SessionDeleted { session_id } => {
|
|
||||||
if current_session_id.as_deref() == Some(session_id.as_str()) {
|
|
||||||
current_session_id = None;
|
|
||||||
}
|
|
||||||
input.write_output(&format!("Deleted session: {}\n", session_id)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::HistoryCleared { session_id } => {
|
|
||||||
input.write_output(&format!("Cleared history for session: {}\n", session_id)).await?;
|
|
||||||
}
|
|
||||||
WsOutbound::SessionSaved { session_id, filepath } => {
|
WsOutbound::SessionSaved { session_id, filepath } => {
|
||||||
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
||||||
}
|
}
|
||||||
@ -138,39 +85,25 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
input.write_output("Goodbye!").await?;
|
input.write_output("Goodbye!").await?;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
InputEvent::Command(InputCommand::Clear) => {
|
|
||||||
let inbound = WsInbound::ClearHistory {
|
|
||||||
chat_id: None,
|
|
||||||
session_id: current_session_id.clone(),
|
|
||||||
};
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
InputEvent::Command(InputCommand::New(title)) => {
|
InputEvent::Command(InputCommand::New(title)) => {
|
||||||
// 使用新的命令层:通过 CliInputAdapter 构建 Command
|
// 使用 CliInputAdapter 构建 Command
|
||||||
let adapter = CliInputAdapter::new();
|
let adapter = CliInputAdapter::new();
|
||||||
let ctx = AdapterContext::new("cli")
|
let ctx = AdapterContext::new("cli")
|
||||||
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
||||||
|
|
||||||
// 构建输入字符串
|
// 构建输入字符串
|
||||||
let input = match title {
|
let input_str = match title {
|
||||||
Some(t) => format!("/new {}", t),
|
Some(t) => format!("/new {}", t),
|
||||||
None => "/new".to_string(),
|
None => "/new".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// 解析为 Command
|
// 解析为 Command
|
||||||
match adapter.try_parse(&input, ctx) {
|
match adapter.try_parse(&input_str, ctx) {
|
||||||
Ok(Some(command)) => {
|
Ok(Some(command)) => {
|
||||||
// 序列化为 JSON 通过 WebSocket 发送
|
// 序列化为 JSON
|
||||||
let json = serde_json::to_string(&command).unwrap_or_default();
|
let json = serde_json::to_string(&command).unwrap_or_default();
|
||||||
let inbound = WsInbound::UserInput {
|
// 通过 Command 消息发送
|
||||||
content: json,
|
let inbound = WsInbound::Command { payload: json };
|
||||||
channel: None,
|
|
||||||
chat_id: current_session_id.clone(),
|
|
||||||
sender_id: None,
|
|
||||||
};
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
}
|
}
|
||||||
@ -184,62 +117,97 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
InputEvent::Command(InputCommand::Sessions) => {
|
InputEvent::Command(InputCommand::Save(filepath)) => {
|
||||||
let inbound = WsInbound::ListSessions {
|
// 使用 CliInputAdapter 构建 Command
|
||||||
include_archived: true,
|
let adapter = CliInputAdapter::new();
|
||||||
|
let ctx = AdapterContext::new("cli")
|
||||||
|
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
||||||
|
|
||||||
|
// 构建输入字符串
|
||||||
|
let input_str = match filepath {
|
||||||
|
Some(p) => format!("/save {}", p),
|
||||||
|
None => "/save".to_string(),
|
||||||
};
|
};
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
// 解析为 Command
|
||||||
|
match adapter.try_parse(&input_str, ctx) {
|
||||||
|
Ok(Some(command)) => {
|
||||||
|
// 序列化为 JSON
|
||||||
|
let json = serde_json::to_string(&command).unwrap_or_default();
|
||||||
|
// 通过 Command 消息发送
|
||||||
|
let inbound = WsInbound::Command { payload: json };
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
tracing::warn!("Failed to parse /save command");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "Error parsing /save command");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
InputEvent::Command(InputCommand::Sessions) => {
|
||||||
|
// 使用 CliInputAdapter 构建 Command
|
||||||
|
let adapter = CliInputAdapter::new();
|
||||||
|
let ctx = AdapterContext::new("cli")
|
||||||
|
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
||||||
|
|
||||||
|
// 解析为 Command
|
||||||
|
match adapter.try_parse("/list", ctx) {
|
||||||
|
Ok(Some(command)) => {
|
||||||
|
// 序列化为 JSON
|
||||||
|
let json = serde_json::to_string(&command).unwrap_or_default();
|
||||||
|
// 通过 Command 消息发送
|
||||||
|
let inbound = WsInbound::Command { payload: json };
|
||||||
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
tracing::warn!("Failed to parse /list command");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "Error parsing /list command");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
InputEvent::Command(InputCommand::Use(session_id)) => {
|
InputEvent::Command(InputCommand::Use(session_id)) => {
|
||||||
let inbound = WsInbound::LoadSession { session_id };
|
// 使用 CliInputAdapter 构建 Command
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
let adapter = CliInputAdapter::new();
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
let ctx = AdapterContext::new("cli")
|
||||||
}
|
.with_session_id(current_session_id.as_deref().unwrap_or(""));
|
||||||
continue;
|
|
||||||
}
|
// 构建输入字符串
|
||||||
InputEvent::Command(InputCommand::Rename(title)) => {
|
let input_str = format!("/use {}", session_id);
|
||||||
let inbound = WsInbound::RenameSession {
|
|
||||||
session_id: current_session_id.clone(),
|
// 解析为 Command
|
||||||
title,
|
match adapter.try_parse(&input_str, ctx) {
|
||||||
};
|
Ok(Some(command)) => {
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
// 序列化为 JSON
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
let json = serde_json::to_string(&command).unwrap_or_default();
|
||||||
}
|
// 通过 Command 消息发送
|
||||||
continue;
|
let inbound = WsInbound::Command { payload: json };
|
||||||
}
|
if let Ok(text) = serialize_inbound(&inbound) {
|
||||||
InputEvent::Command(InputCommand::Archive) => {
|
let _ = sender.send(Message::Text(text.into())).await;
|
||||||
let inbound = WsInbound::ArchiveSession {
|
}
|
||||||
session_id: current_session_id.clone(),
|
// 更新当前会话 ID
|
||||||
};
|
current_session_id = Some(session_id.clone());
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
}
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
Ok(None) => {
|
||||||
}
|
tracing::warn!("Failed to parse /use command");
|
||||||
continue;
|
}
|
||||||
}
|
Err(e) => {
|
||||||
InputEvent::Command(InputCommand::Delete) => {
|
tracing::error!(error = %e, "Error parsing /use command");
|
||||||
let inbound = WsInbound::DeleteSession {
|
}
|
||||||
session_id: current_session_id.clone(),
|
|
||||||
};
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
InputEvent::Command(InputCommand::Save(filepath)) => {
|
|
||||||
let inbound = WsInbound::SaveSession {
|
|
||||||
filepath,
|
|
||||||
session_id: current_session_id.clone(),
|
|
||||||
};
|
|
||||||
if let Ok(text) = serialize_inbound(&inbound) {
|
|
||||||
let _ = sender.send(Message::Text(text.into())).await;
|
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
InputEvent::Message(msg) => {
|
InputEvent::Message(msg) => {
|
||||||
let inbound = WsInbound::UserInput {
|
let inbound = WsInbound::Message {
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
channel: None,
|
channel: None,
|
||||||
chat_id: current_session_id.clone(),
|
chat_id: current_session_id.clone(),
|
||||||
|
|||||||
92
src/command/adapters/channel.rs
Normal file
92
src/command/adapters/channel.rs
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
use crate::command::adapter::{AdapterError, InputAdapter};
|
||||||
|
use crate::command::context::AdapterContext;
|
||||||
|
use crate::command::Command;
|
||||||
|
|
||||||
|
/// Channel 输入适配器
|
||||||
|
///
|
||||||
|
/// 将 Channel 消息中的文本命令(如 "/new", "/save")转换为 Command
|
||||||
|
pub struct ChannelInputAdapter;
|
||||||
|
|
||||||
|
impl ChannelInputAdapter {
|
||||||
|
/// 创建新的 Channel 输入适配器
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ChannelInputAdapter {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InputAdapter for ChannelInputAdapter {
|
||||||
|
fn try_parse(
|
||||||
|
&self,
|
||||||
|
input: &str,
|
||||||
|
_ctx: AdapterContext,
|
||||||
|
) -> Result<Option<Command>, AdapterError> {
|
||||||
|
let trimmed = input.trim();
|
||||||
|
|
||||||
|
// 解析 /new 命令
|
||||||
|
if trimmed == "/new" {
|
||||||
|
return Ok(Some(Command::CreateSession { title: None }));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(title) = trimmed.strip_prefix("/new ") {
|
||||||
|
let title = title.trim();
|
||||||
|
return Ok(Some(Command::CreateSession {
|
||||||
|
title: Some(title.to_string()),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 /save 命令
|
||||||
|
if trimmed == "/save" {
|
||||||
|
return Ok(Some(Command::SaveSession {
|
||||||
|
filepath: None,
|
||||||
|
include_all: false,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(args) = trimmed.strip_prefix("/save ") {
|
||||||
|
let args = args.trim();
|
||||||
|
// 解析参数:可能是 "all"、路径、或 "all 路径"
|
||||||
|
let (include_all, filepath) = if args == "all" {
|
||||||
|
// /save all - 保存全部消息
|
||||||
|
(true, None)
|
||||||
|
} else if args.starts_with("all ") {
|
||||||
|
// /save all <filepath> - 保存全部消息到指定路径
|
||||||
|
let path = args[4..].trim();
|
||||||
|
(true, Some(path.to_string()))
|
||||||
|
} else {
|
||||||
|
// /save <filepath> - 保存活跃消息到指定路径
|
||||||
|
(false, Some(args.to_string()))
|
||||||
|
};
|
||||||
|
return Ok(Some(Command::SaveSession { filepath, include_all }));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 /list 命令
|
||||||
|
if trimmed == "/list" {
|
||||||
|
return Ok(Some(Command::ListSessions {
|
||||||
|
include_archived: false,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
if trimmed == "/list all" {
|
||||||
|
return Ok(Some(Command::ListSessions {
|
||||||
|
include_archived: true,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 /use 命令
|
||||||
|
if let Some(session_id) = trimmed.strip_prefix("/use ") {
|
||||||
|
let session_id = session_id.trim();
|
||||||
|
return Ok(Some(Command::LoadSession {
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不是命令,返回 None
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -66,6 +66,27 @@ impl InputAdapter for CliInputAdapter {
|
|||||||
return Ok(Some(Command::SaveSession { filepath, include_all }));
|
return Ok(Some(Command::SaveSession { filepath, include_all }));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析 /list 命令
|
||||||
|
if trimmed == "/list" {
|
||||||
|
return Ok(Some(Command::ListSessions {
|
||||||
|
include_archived: false,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
if trimmed == "/list all" {
|
||||||
|
return Ok(Some(Command::ListSessions {
|
||||||
|
include_archived: true,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 /use 命令
|
||||||
|
if let Some(session_id) = trimmed.strip_prefix("/use ") {
|
||||||
|
let session_id = session_id.trim();
|
||||||
|
return Ok(Some(Command::LoadSession {
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
// 不是命令,返回 None
|
// 不是命令,返回 None
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,2 +1,3 @@
|
|||||||
|
pub mod channel;
|
||||||
pub mod cli;
|
pub mod cli;
|
||||||
pub mod websocket;
|
pub mod websocket;
|
||||||
|
|||||||
@ -12,18 +12,21 @@ pub struct CommandContext {
|
|||||||
pub chat_id: Option<String>,
|
pub chat_id: Option<String>,
|
||||||
/// 发送者ID
|
/// 发送者ID
|
||||||
pub sender_id: String,
|
pub sender_id: String,
|
||||||
|
/// 通道名称(如 "cli", "feishu", "wechat")
|
||||||
|
pub channel_name: String,
|
||||||
/// 额外元数据
|
/// 额外元数据
|
||||||
pub metadata: HashMap<String, String>,
|
pub metadata: HashMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CommandContext {
|
impl CommandContext {
|
||||||
/// 创建新的命令上下文
|
/// 创建新的命令上下文
|
||||||
pub fn new(sender_id: impl Into<String>) -> Self {
|
pub fn new(sender_id: impl Into<String>, channel_name: impl Into<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
request_id: Uuid::new_v4(),
|
request_id: Uuid::new_v4(),
|
||||||
session_id: None,
|
session_id: None,
|
||||||
chat_id: None,
|
chat_id: None,
|
||||||
sender_id: sender_id.into(),
|
sender_id: sender_id.into(),
|
||||||
|
channel_name: channel_name.into(),
|
||||||
metadata: HashMap::new(),
|
metadata: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -238,7 +238,7 @@ mod tests {
|
|||||||
router.register(Box::new(TestHandler));
|
router.register(Box::new(TestHandler));
|
||||||
router.register(Box::new(NoOpHandler));
|
router.register(Box::new(NoOpHandler));
|
||||||
|
|
||||||
let ctx = CommandContext::new("test");
|
let ctx = CommandContext::new("test", "test");
|
||||||
let cmd = Command::CreateSession { title: None };
|
let cmd = Command::CreateSession { title: None };
|
||||||
|
|
||||||
let result = router.dispatch(cmd, ctx).await;
|
let result = router.dispatch(cmd, ctx).await;
|
||||||
@ -252,7 +252,7 @@ mod tests {
|
|||||||
async fn test_router_no_handler() {
|
async fn test_router_no_handler() {
|
||||||
let router = CommandRouter::new();
|
let router = CommandRouter::new();
|
||||||
|
|
||||||
let ctx = CommandContext::new("test");
|
let ctx = CommandContext::new("test", "test");
|
||||||
let cmd = Command::CreateSession { title: None };
|
let cmd = Command::CreateSession { title: None };
|
||||||
|
|
||||||
let result = router.dispatch(cmd, ctx).await;
|
let result = router.dispatch(cmd, ctx).await;
|
||||||
|
|||||||
@ -1,2 +1,3 @@
|
|||||||
pub mod save_session;
|
pub mod save_session;
|
||||||
pub mod session;
|
pub mod session;
|
||||||
|
pub mod session_query;
|
||||||
|
|||||||
@ -36,6 +36,7 @@ impl CommandHandler for SessionCommandHandler {
|
|||||||
match cmd {
|
match cmd {
|
||||||
Command::CreateSession { title } => handle_create_session(self, title, ctx).await,
|
Command::CreateSession { title } => handle_create_session(self, title, ctx).await,
|
||||||
Command::SaveSession { .. } => unreachable!("SaveSession should be handled by SaveSessionCommandHandler"),
|
Command::SaveSession { .. } => unreachable!("SaveSession should be handled by SaveSessionCommandHandler"),
|
||||||
|
_ => unreachable!("Other commands should be handled by other handlers"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,7 +49,7 @@ async fn handle_create_session(
|
|||||||
) -> Result<CommandResponse, CommandError> {
|
) -> Result<CommandResponse, CommandError> {
|
||||||
let record = handler
|
let record = handler
|
||||||
.cli_sessions
|
.cli_sessions
|
||||||
.create(title.as_deref())
|
.create_with_channel(&ctx.channel_name, title.as_deref())
|
||||||
.map_err(|e| CommandError::new("CREATE_SESSION_ERROR", e.to_string()))?;
|
.map_err(|e| CommandError::new("CREATE_SESSION_ERROR", e.to_string()))?;
|
||||||
|
|
||||||
Ok(CommandResponse::success(ctx.request_id)
|
Ok(CommandResponse::success(ctx.request_id)
|
||||||
@ -74,7 +75,7 @@ mod tests {
|
|||||||
async fn test_create_session_with_title() {
|
async fn test_create_session_with_title() {
|
||||||
let service = create_test_service();
|
let service = create_test_service();
|
||||||
let handler = SessionCommandHandler::new(service);
|
let handler = SessionCommandHandler::new(service);
|
||||||
let ctx = CommandContext::new("test");
|
let ctx = CommandContext::new("test", "test");
|
||||||
let cmd = Command::CreateSession {
|
let cmd = Command::CreateSession {
|
||||||
title: Some("my session".to_string()),
|
title: Some("my session".to_string()),
|
||||||
};
|
};
|
||||||
@ -93,7 +94,7 @@ mod tests {
|
|||||||
async fn test_create_session_without_title() {
|
async fn test_create_session_without_title() {
|
||||||
let service = create_test_service();
|
let service = create_test_service();
|
||||||
let handler = SessionCommandHandler::new(service);
|
let handler = SessionCommandHandler::new(service);
|
||||||
let ctx = CommandContext::new("test");
|
let ctx = CommandContext::new("test", "test");
|
||||||
let cmd = Command::CreateSession { title: None };
|
let cmd = Command::CreateSession { title: None };
|
||||||
|
|
||||||
let result = handler.handle(cmd, ctx).await;
|
let result = handler.handle(cmd, ctx).await;
|
||||||
|
|||||||
188
src/command/handlers/session_query.rs
Normal file
188
src/command/handlers/session_query.rs
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
use crate::command::context::CommandContext;
|
||||||
|
use crate::command::handler::CommandHandler;
|
||||||
|
use crate::command::response::{CommandError, CommandResponse, MessageKind};
|
||||||
|
use crate::command::Command;
|
||||||
|
use crate::gateway::cli_session::CliSessionService;
|
||||||
|
use crate::protocol::SessionSummary;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// 会话查询命令处理器
|
||||||
|
///
|
||||||
|
/// 处理 ListSessions 和 LoadSession 命令
|
||||||
|
pub struct SessionQueryCommandHandler {
|
||||||
|
cli_sessions: CliSessionService,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionQueryCommandHandler {
|
||||||
|
/// 创建新的会话查询命令处理器
|
||||||
|
pub fn new(cli_sessions: CliSessionService) -> Self {
|
||||||
|
Self { cli_sessions }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CommandHandler for SessionQueryCommandHandler {
|
||||||
|
fn can_handle(&self, cmd: &Command) -> bool {
|
||||||
|
matches!(cmd, Command::ListSessions { .. } | Command::LoadSession { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle(
|
||||||
|
&self,
|
||||||
|
cmd: Command,
|
||||||
|
ctx: CommandContext,
|
||||||
|
) -> Result<CommandResponse, CommandError> {
|
||||||
|
match cmd {
|
||||||
|
Command::ListSessions { include_archived } => {
|
||||||
|
handle_list_sessions(self, include_archived, ctx).await
|
||||||
|
}
|
||||||
|
Command::LoadSession { session_id } => {
|
||||||
|
handle_load_session(self, session_id, ctx).await
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 处理列出会话命令
|
||||||
|
async fn handle_list_sessions(
|
||||||
|
handler: &SessionQueryCommandHandler,
|
||||||
|
include_archived: bool,
|
||||||
|
ctx: CommandContext,
|
||||||
|
) -> Result<CommandResponse, CommandError> {
|
||||||
|
let records = handler
|
||||||
|
.cli_sessions
|
||||||
|
.list(include_archived)
|
||||||
|
.map_err(|e| CommandError::new("LIST_SESSIONS_ERROR", e.to_string()))?;
|
||||||
|
|
||||||
|
let summaries: Vec<SessionSummary> = records
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| SessionSummary {
|
||||||
|
session_id: r.id,
|
||||||
|
title: r.title,
|
||||||
|
channel_name: r.channel_name,
|
||||||
|
chat_id: r.chat_id,
|
||||||
|
message_count: r.message_count,
|
||||||
|
last_active_at: r.last_active_at,
|
||||||
|
archived_at: r.archived_at,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 将会话列表序列化为 JSON 存储在 metadata 中
|
||||||
|
let sessions_json =
|
||||||
|
serde_json::to_string(&summaries).map_err(|e| CommandError::new("SERIALIZE_ERROR", e.to_string()))?;
|
||||||
|
|
||||||
|
let message = if summaries.is_empty() {
|
||||||
|
"No sessions found.".to_string()
|
||||||
|
} else {
|
||||||
|
format!("Found {} session(s)", summaries.len())
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(CommandResponse::success(ctx.request_id)
|
||||||
|
.with_message(MessageKind::Notification, &message)
|
||||||
|
.with_metadata("sessions", &sessions_json)
|
||||||
|
.with_metadata("count", &summaries.len().to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 处理加载会话命令
|
||||||
|
async fn handle_load_session(
|
||||||
|
handler: &SessionQueryCommandHandler,
|
||||||
|
session_id: String,
|
||||||
|
ctx: CommandContext,
|
||||||
|
) -> Result<CommandResponse, CommandError> {
|
||||||
|
let record = handler
|
||||||
|
.cli_sessions
|
||||||
|
.get(&session_id)
|
||||||
|
.map_err(|e| CommandError::new("LOAD_SESSION_ERROR", e.to_string()))?
|
||||||
|
.ok_or_else(|| CommandError::new("SESSION_NOT_FOUND", format!("Session not found: {}", session_id)))?;
|
||||||
|
|
||||||
|
Ok(CommandResponse::success(ctx.request_id)
|
||||||
|
.with_message(MessageKind::Notification, &record.title)
|
||||||
|
.with_metadata("session_id", &record.id)
|
||||||
|
.with_metadata("title", &record.title)
|
||||||
|
.with_metadata("message_count", &record.message_count.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::storage::SessionStore;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn create_test_service() -> CliSessionService {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
CliSessionService::new(store)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_sessions_empty() {
|
||||||
|
let service = create_test_service();
|
||||||
|
let handler = SessionQueryCommandHandler::new(service);
|
||||||
|
let ctx = CommandContext::new("test", "test");
|
||||||
|
let cmd = Command::ListSessions {
|
||||||
|
include_archived: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = handler.handle(cmd, ctx).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let resp = result.unwrap();
|
||||||
|
assert!(resp.success);
|
||||||
|
assert!(resp.messages[0].content.contains("No sessions"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_sessions_with_items() {
|
||||||
|
let service = create_test_service();
|
||||||
|
let handler = SessionQueryCommandHandler::new(service.clone());
|
||||||
|
|
||||||
|
// 创建一些会话
|
||||||
|
service.create(Some("test session")).unwrap();
|
||||||
|
|
||||||
|
let ctx = CommandContext::new("test", "test");
|
||||||
|
let cmd = Command::ListSessions {
|
||||||
|
include_archived: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = handler.handle(cmd, ctx).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let resp = result.unwrap();
|
||||||
|
assert!(resp.success);
|
||||||
|
assert!(resp.metadata.contains_key("sessions"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_session_not_found() {
|
||||||
|
let service = create_test_service();
|
||||||
|
let handler = SessionQueryCommandHandler::new(service);
|
||||||
|
let ctx = CommandContext::new("test", "test");
|
||||||
|
let cmd = Command::LoadSession {
|
||||||
|
session_id: "nonexistent".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = handler.handle(cmd, ctx).await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_load_session_success() {
|
||||||
|
let service = create_test_service();
|
||||||
|
let handler = SessionQueryCommandHandler::new(service.clone());
|
||||||
|
|
||||||
|
// 创建会话
|
||||||
|
let record = service.create(Some("test session")).unwrap();
|
||||||
|
|
||||||
|
let ctx = CommandContext::new("test", "test");
|
||||||
|
let cmd = Command::LoadSession {
|
||||||
|
session_id: record.id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = handler.handle(cmd, ctx).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let resp = result.unwrap();
|
||||||
|
assert!(resp.success);
|
||||||
|
assert_eq!(resp.metadata.get("session_id").unwrap(), &record.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -18,6 +18,10 @@ pub enum Command {
|
|||||||
filepath: Option<String>,
|
filepath: Option<String>,
|
||||||
include_all: bool,
|
include_all: bool,
|
||||||
},
|
},
|
||||||
|
/// 列出会话
|
||||||
|
ListSessions { include_archived: bool },
|
||||||
|
/// 加载指定会话
|
||||||
|
LoadSession { session_id: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Command {
|
impl Command {
|
||||||
@ -26,6 +30,8 @@ impl Command {
|
|||||||
match self {
|
match self {
|
||||||
Command::CreateSession { .. } => "create_session",
|
Command::CreateSession { .. } => "create_session",
|
||||||
Command::SaveSession { .. } => "save_session",
|
Command::SaveSession { .. } => "save_session",
|
||||||
|
Command::ListSessions { .. } => "list_sessions",
|
||||||
|
Command::LoadSession { .. } => "load_session",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1472,7 +1472,7 @@ mod tests {
|
|||||||
assert!(config.scheduler.jobs.is_empty());
|
assert!(config.scheduler.jobs.is_empty());
|
||||||
|
|
||||||
let effective_jobs = config.scheduler.effective_jobs(&config.time);
|
let effective_jobs = config.scheduler.effective_jobs(&config.time);
|
||||||
assert_eq!(effective_jobs.len(), 1);
|
assert_eq!(effective_jobs.len(), 2);
|
||||||
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
|
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
|
||||||
assert_eq!(effective_jobs[0].kind, SchedulerJobKind::InternalEvent);
|
assert_eq!(effective_jobs[0].kind, SchedulerJobKind::InternalEvent);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -1481,6 +1481,8 @@ mod tests {
|
|||||||
expression: "0 */4 * * *".to_string(),
|
expression: "0 */4 * * *".to_string(),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
// 第二个内置作业是会话清理
|
||||||
|
assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1516,7 +1518,8 @@ mod tests {
|
|||||||
let effective_jobs = scheduler.effective_jobs(&TimeConfig {
|
let effective_jobs = scheduler.effective_jobs(&TimeConfig {
|
||||||
timezone: "Asia/Shanghai".to_string(),
|
timezone: "Asia/Shanghai".to_string(),
|
||||||
});
|
});
|
||||||
assert_eq!(effective_jobs.len(), 2);
|
assert_eq!(effective_jobs.len(), 3); // 2个内置 + 1个自定义
|
||||||
|
// 第一个作业:内存维护(被覆盖为禁用)
|
||||||
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
|
assert_eq!(effective_jobs[0].id, BUILTIN_MEMORY_MAINTENANCE_JOB_ID);
|
||||||
assert!(!effective_jobs[0].enabled);
|
assert!(!effective_jobs[0].enabled);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -1525,7 +1528,11 @@ mod tests {
|
|||||||
expression: "15 2 * * *".to_string(),
|
expression: "15 2 * * *".to_string(),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
assert_eq!(effective_jobs[1].id, "custom.reminder");
|
// 第二个作业:会话清理(保持默认)
|
||||||
|
assert_eq!(effective_jobs[1].id, BUILTIN_SESSION_CLEANUP_JOB_ID);
|
||||||
|
assert!(effective_jobs[1].enabled);
|
||||||
|
// 第三个作业:自定义提醒
|
||||||
|
assert_eq!(effective_jobs[2].id, "custom.reminder");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -19,6 +19,17 @@ impl CliSessionService {
|
|||||||
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
|
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 创建指定通道的会话
|
||||||
|
pub(crate) fn create_with_channel(
|
||||||
|
&self,
|
||||||
|
channel_name: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
) -> Result<SessionRecord, AgentError> {
|
||||||
|
self.store
|
||||||
|
.create_session(channel_name, title)
|
||||||
|
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
|
pub(crate) fn get(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
|
||||||
self.store
|
self.store
|
||||||
.get_session(session_id)
|
.get_session(session_id)
|
||||||
|
|||||||
@ -1,165 +1,2 @@
|
|||||||
use crate::agent::AgentError;
|
// 此文件已废弃,InChatCommand 功能已合并到 Command 系统
|
||||||
|
// 保留文件以避免破坏现有 import,但内容为空
|
||||||
use super::session::Session;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
||||||
enum InChatCommand {
|
|
||||||
FreshConversation,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
|
|
||||||
match content.trim() {
|
|
||||||
"/new" | "/reset" => Some(InChatCommand::FreshConversation),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn handle_in_chat_command(
|
|
||||||
session: &mut Session,
|
|
||||||
chat_id: &str,
|
|
||||||
content: &str,
|
|
||||||
) -> Result<Option<String>, AgentError> {
|
|
||||||
match parse_in_chat_command(content) {
|
|
||||||
Some(InChatCommand::FreshConversation) => {
|
|
||||||
session.reset_chat_context(chat_id)?;
|
|
||||||
Ok(Some("Started a fresh conversation.".to_string()))
|
|
||||||
}
|
|
||||||
None => Ok(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
use crate::skills::SkillRuntime;
|
|
||||||
use crate::storage::SessionStore;
|
|
||||||
use crate::tools::ToolRegistry;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
|
|
||||||
const TEST_CHANNEL: &str = "test-channel";
|
|
||||||
|
|
||||||
fn test_provider_config() -> LLMProviderConfig {
|
|
||||||
LLMProviderConfig {
|
|
||||||
provider_type: "openai".to_string(),
|
|
||||||
name: "test".to_string(),
|
|
||||||
base_url: "http://localhost".to_string(),
|
|
||||||
api_key: "test-key".to_string(),
|
|
||||||
extra_headers: HashMap::new(),
|
|
||||||
llm_timeout_secs: 120,
|
|
||||||
memory_maintenance_timeout_secs: 600,
|
|
||||||
model_id: "test-model".to_string(),
|
|
||||||
temperature: Some(0.0),
|
|
||||||
max_tokens: Some(32),
|
|
||||||
context_window_tokens: None,
|
|
||||||
model_extra: HashMap::new(),
|
|
||||||
max_tool_iterations: 1,
|
|
||||||
tool_result_max_chars: 20_000,
|
|
||||||
context_tool_result_trim_chars: 20_000,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_parse_in_chat_command_aliases() {
|
|
||||||
assert_eq!(
|
|
||||||
parse_in_chat_command("/new"),
|
|
||||||
Some(InChatCommand::FreshConversation)
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
parse_in_chat_command(" /reset \n"),
|
|
||||||
Some(InChatCommand::FreshConversation)
|
|
||||||
);
|
|
||||||
assert_eq!(parse_in_chat_command("/new planning"), None);
|
|
||||||
assert_eq!(parse_in_chat_command("please /reset"), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_handle_in_chat_command_resets_active_history_only() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
|
||||||
let tools = Arc::new(ToolRegistry::new());
|
|
||||||
let mut session = Session::new(
|
|
||||||
TEST_CHANNEL.to_string(),
|
|
||||||
test_provider_config(),
|
|
||||||
user_tx,
|
|
||||||
tools,
|
|
||||||
skills,
|
|
||||||
store.clone(),
|
|
||||||
100,
|
|
||||||
Some(4),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
session.ensure_persistent_session("chat-1").unwrap();
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
|
||||||
session
|
|
||||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let response = handle_in_chat_command(&mut session, "chat-1", "/reset")
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(response, "Started a fresh conversation.");
|
|
||||||
assert!(session.get_history("chat-1").unwrap().is_empty());
|
|
||||||
assert!(
|
|
||||||
store
|
|
||||||
.load_messages(&session.persistent_session_id("chat-1"))
|
|
||||||
.unwrap()
|
|
||||||
.is_empty()
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
store
|
|
||||||
.load_all_messages(&session.persistent_session_id("chat-1"))
|
|
||||||
.unwrap()
|
|
||||||
.len(),
|
|
||||||
// 新设计:系统提示词不再持久化,只有 1 条用户消息
|
|
||||||
1,
|
|
||||||
);
|
|
||||||
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
|
||||||
let history = session.get_history("chat-1").unwrap();
|
|
||||||
// 新设计:系统提示词不再持久化到历史记录
|
|
||||||
assert_eq!(history.len(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_reset_reinjects_agent_prompt_before_next_user_message() {
|
|
||||||
let store = Arc::new(SessionStore::in_memory().unwrap());
|
|
||||||
let (user_tx, _user_rx) = mpsc::channel(4);
|
|
||||||
let skills = Arc::new(SkillRuntime::default());
|
|
||||||
let tools = Arc::new(ToolRegistry::new());
|
|
||||||
let mut session = Session::new(
|
|
||||||
TEST_CHANNEL.to_string(),
|
|
||||||
test_provider_config(),
|
|
||||||
user_tx,
|
|
||||||
tools,
|
|
||||||
skills,
|
|
||||||
store,
|
|
||||||
100,
|
|
||||||
Some(4),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
session.ensure_persistent_session("chat-1").unwrap();
|
|
||||||
session.ensure_chat_loaded("chat-1").unwrap();
|
|
||||||
session
|
|
||||||
.append_persisted_message("chat-1", ChatMessage::user("hello"))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
handle_in_chat_command(&mut session, "chat-1", "/reset").unwrap();
|
|
||||||
session
|
|
||||||
.ensure_agent_prompt_before_user_message("chat-1")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// 新设计:系统提示词不再持久化到历史记录
|
|
||||||
let history = session.get_history("chat-1").unwrap();
|
|
||||||
assert_eq!(history.len(), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -7,7 +7,6 @@ use crate::bus::{ChatMessage, MediaItem, OutboundMessage, SYSTEM_CONTEXT_SCHEDUL
|
|||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use super::command::handle_in_chat_command;
|
|
||||||
use super::compaction::schedule_background_history_compaction;
|
use super::compaction::schedule_background_history_compaction;
|
||||||
use super::message_prepare::enrich_user_content_with_media_refs;
|
use super::message_prepare::enrich_user_content_with_media_refs;
|
||||||
use super::session::Session;
|
use super::session::Session;
|
||||||
@ -138,18 +137,6 @@ impl AgentExecutionService {
|
|||||||
session_guard.ensure_persistent_session(request.chat_id)?;
|
session_guard.ensure_persistent_session(request.chat_id)?;
|
||||||
session_guard.ensure_chat_loaded(request.chat_id)?;
|
session_guard.ensure_chat_loaded(request.chat_id)?;
|
||||||
|
|
||||||
if let Some(command_response) =
|
|
||||||
handle_in_chat_command(&mut session_guard, request.chat_id, request.content)?
|
|
||||||
{
|
|
||||||
return Ok(vec![OutboundMessage::assistant(
|
|
||||||
request.channel_name.to_string(),
|
|
||||||
request.chat_id.to_string(),
|
|
||||||
command_response,
|
|
||||||
None,
|
|
||||||
HashMap::new(),
|
|
||||||
)]);
|
|
||||||
}
|
|
||||||
|
|
||||||
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
|
session_guard.ensure_agent_prompt_before_user_message(request.chat_id)?;
|
||||||
|
|
||||||
let media_refs: Vec<String> = request
|
let media_refs: Vec<String> = request
|
||||||
|
|||||||
@ -4,8 +4,12 @@ use tokio::sync::Semaphore;
|
|||||||
|
|
||||||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||||||
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
|
use crate::bus::{InboundMessage, MessageBus, OutboundMessage};
|
||||||
use crate::command::handler::InChatCommandRouter;
|
use crate::command::adapter::InputAdapter;
|
||||||
use crate::command::Command;
|
use crate::command::adapters::channel::ChannelInputAdapter;
|
||||||
|
use crate::command::handler::CommandRouter;
|
||||||
|
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||||
|
use crate::command::handlers::session::SessionCommandHandler;
|
||||||
|
use crate::command::handlers::session_query::SessionQueryCommandHandler;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||||
use crate::skills::SkillPromptProvider;
|
use crate::skills::SkillPromptProvider;
|
||||||
@ -18,7 +22,7 @@ pub struct InboundProcessor {
|
|||||||
session_manager: SessionManager,
|
session_manager: SessionManager,
|
||||||
semaphore: Arc<Semaphore>,
|
semaphore: Arc<Semaphore>,
|
||||||
_provider_config: LLMProviderConfig,
|
_provider_config: LLMProviderConfig,
|
||||||
command_router: Arc<InChatCommandRouter>,
|
command_router: Arc<CommandRouter>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InboundProcessor {
|
impl InboundProcessor {
|
||||||
@ -29,7 +33,14 @@ impl InboundProcessor {
|
|||||||
provider_config: LLMProviderConfig,
|
provider_config: LLMProviderConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// 创建命令路由器并注册处理器
|
// 创建命令路由器并注册处理器
|
||||||
let mut command_router = InChatCommandRouter::new();
|
let mut command_router = CommandRouter::new();
|
||||||
|
|
||||||
|
// 注册 Session 处理器
|
||||||
|
let cli_sessions = session_manager.cli_sessions();
|
||||||
|
command_router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone())));
|
||||||
|
|
||||||
|
// 注册 session_query 处理器
|
||||||
|
command_router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions)));
|
||||||
|
|
||||||
// 注册 save_session 处理器
|
// 注册 save_session 处理器
|
||||||
let store = session_manager.store();
|
let store = session_manager.store();
|
||||||
@ -43,7 +54,7 @@ impl InboundProcessor {
|
|||||||
)),
|
)),
|
||||||
Box::new(SkillPromptProvider::new(skills)),
|
Box::new(SkillPromptProvider::new(skills)),
|
||||||
]));
|
]));
|
||||||
command_router.register(Box::new(crate::command::handlers::save_session::SaveSessionInChatHandler::new(
|
command_router.register(Box::new(SaveSessionCommandHandler::new(
|
||||||
store,
|
store,
|
||||||
system_prompt_provider,
|
system_prompt_provider,
|
||||||
)));
|
)));
|
||||||
@ -103,18 +114,28 @@ impl InboundProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
|
async fn process_one(&self, inbound: InboundMessage) -> Result<(), AgentError> {
|
||||||
// 尝试解析为命令
|
// 使用 ChannelInputAdapter 尝试解析命令
|
||||||
if let Some(cmd) = parse_in_chat_command(&inbound.content) {
|
let adapter = ChannelInputAdapter::new();
|
||||||
|
let ctx = crate::command::context::AdapterContext::new(&inbound.channel)
|
||||||
|
.with_session_id(&inbound.chat_id);
|
||||||
|
|
||||||
|
if let Ok(Some(cmd)) = adapter.try_parse(&inbound.content, ctx) {
|
||||||
// 使用命令路由器处理
|
// 使用命令路由器处理
|
||||||
match self.command_router.dispatch(cmd, &inbound, &self.session_manager).await? {
|
let cmd_ctx = crate::command::context::CommandContext::new(&inbound.channel, &inbound.channel)
|
||||||
Some(response_msg) => {
|
.with_session_id(&inbound.chat_id);
|
||||||
// 发送命令执行结果给用户
|
|
||||||
|
let response = self.command_router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||||
|
|
||||||
|
// 发送响应给用户
|
||||||
|
if response.success {
|
||||||
|
// 提取响应消息
|
||||||
|
for msg in &response.messages {
|
||||||
if let Err(error) = self
|
if let Err(error) = self
|
||||||
.bus
|
.bus
|
||||||
.publish_outbound(OutboundMessage::assistant(
|
.publish_outbound(OutboundMessage::assistant(
|
||||||
inbound.channel.clone(),
|
inbound.channel.clone(),
|
||||||
inbound.chat_id.clone(),
|
inbound.chat_id.clone(),
|
||||||
response_msg,
|
msg.content.clone(),
|
||||||
None,
|
None,
|
||||||
inbound.forwarded_metadata.clone(),
|
inbound.forwarded_metadata.clone(),
|
||||||
))
|
))
|
||||||
@ -122,13 +143,23 @@ impl InboundProcessor {
|
|||||||
{
|
{
|
||||||
tracing::error!(error = %error, "Failed to publish command response");
|
tracing::error!(error = %error, "Failed to publish command response");
|
||||||
}
|
}
|
||||||
return Ok(());
|
|
||||||
}
|
}
|
||||||
None => {
|
} else if let Some(error) = response.error {
|
||||||
// 命令已处理但没有返回消息
|
if let Err(e) = self
|
||||||
return Ok(());
|
.bus
|
||||||
|
.publish_outbound(OutboundMessage::assistant(
|
||||||
|
inbound.channel.clone(),
|
||||||
|
inbound.chat_id.clone(),
|
||||||
|
format!("Error [{}]: {}", error.code, error.message),
|
||||||
|
None,
|
||||||
|
inbound.forwarded_metadata.clone(),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
tracing::error!(error = %e, "Failed to publish error response");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// 普通消息进入 AgentLoop
|
// 普通消息进入 AgentLoop
|
||||||
@ -183,41 +214,3 @@ impl InboundProcessor {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 解析聊天中的命令
|
|
||||||
///
|
|
||||||
/// 支持格式:
|
|
||||||
/// - `/save` - 保存活跃会话消息(到 cutoff)
|
|
||||||
/// - `/save all` - 保存全部会话消息(包括 cutoff 之前)
|
|
||||||
/// - `/save <filepath>` - 保存活跃消息到指定路径
|
|
||||||
/// - `/save all <filepath>` - 保存全部消息到指定路径
|
|
||||||
///
|
|
||||||
/// 返回 Some(Command) 如果是命令
|
|
||||||
/// 返回 None 如果不是命令
|
|
||||||
fn parse_in_chat_command(content: &str) -> Option<Command> {
|
|
||||||
let trimmed = content.trim();
|
|
||||||
|
|
||||||
if trimmed.starts_with("/save") {
|
|
||||||
let args = trimmed[5..].trim();
|
|
||||||
|
|
||||||
// 解析参数
|
|
||||||
let (include_all, filepath) = if args.is_empty() {
|
|
||||||
// /save 无参数 - 只保存活跃消息
|
|
||||||
(false, None)
|
|
||||||
} else if args == "all" {
|
|
||||||
// /save all - 保存全部消息
|
|
||||||
(true, None)
|
|
||||||
} else if args.starts_with("all ") {
|
|
||||||
// /save all <filepath> - 保存全部消息到指定路径
|
|
||||||
let path = args[4..].trim();
|
|
||||||
(true, Some(path.to_string()))
|
|
||||||
} else {
|
|
||||||
// /save <filepath> - 保存活跃消息到指定路径
|
|
||||||
(false, Some(args.to_string()))
|
|
||||||
};
|
|
||||||
|
|
||||||
Some(Command::SaveSession { filepath, include_all })
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
use super::GatewayState;
|
use super::GatewayState;
|
||||||
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
use crate::agent::{AgentError, CompositeSystemPromptProvider};
|
||||||
use crate::bus::InboundMessage;
|
use crate::bus::InboundMessage;
|
||||||
use crate::command::adapter::OutputAdapter;
|
use crate::command::adapter::{InputAdapter, OutputAdapter};
|
||||||
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
use crate::command::adapters::websocket::{WebSocketInputAdapter, WebSocketOutputAdapter};
|
||||||
use crate::command::context::CommandContext;
|
use crate::command::context::CommandContext;
|
||||||
use crate::command::handler::CommandRouter;
|
use crate::command::handler::CommandRouter;
|
||||||
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
use crate::command::handlers::save_session::SaveSessionCommandHandler;
|
||||||
use crate::command::handlers::session::SessionCommandHandler;
|
use crate::command::handlers::session::SessionCommandHandler;
|
||||||
|
use crate::command::handlers::session_query::SessionQueryCommandHandler;
|
||||||
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
|
||||||
use crate::protocol::{SessionSummary, WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
use crate::protocol::{WsInbound, WsOutbound, parse_inbound, serialize_outbound};
|
||||||
use crate::skills::SkillPromptProvider;
|
use crate::skills::SkillPromptProvider;
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
|
||||||
@ -125,17 +126,6 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
|||||||
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
tracing::info!(session_id = %runtime_session_id, current_session_id = %current_session_id, "CLI session ended");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
|
||||||
SessionSummary {
|
|
||||||
session_id: record.id,
|
|
||||||
title: record.title,
|
|
||||||
channel_name: record.channel_name,
|
|
||||||
chat_id: record.chat_id,
|
|
||||||
message_count: record.message_count,
|
|
||||||
last_active_at: record.last_active_at,
|
|
||||||
archived_at: record.archived_at,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_inbound(
|
async fn handle_inbound(
|
||||||
state: &Arc<GatewayState>,
|
state: &Arc<GatewayState>,
|
||||||
@ -143,9 +133,9 @@ async fn handle_inbound(
|
|||||||
runtime_session_id: &str,
|
runtime_session_id: &str,
|
||||||
current_session_id: &mut String,
|
current_session_id: &mut String,
|
||||||
inbound: WsInbound,
|
inbound: WsInbound,
|
||||||
) -> Result<(), crate::agent::AgentError> {
|
) -> Result<(), AgentError> {
|
||||||
match inbound {
|
match inbound {
|
||||||
WsInbound::UserInput {
|
WsInbound::Message {
|
||||||
content,
|
content,
|
||||||
chat_id,
|
chat_id,
|
||||||
sender_id,
|
sender_id,
|
||||||
@ -181,53 +171,74 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
WsInbound::ClearHistory {
|
WsInbound::Command { payload } => {
|
||||||
session_id,
|
// 使用 Command 系统处理命令
|
||||||
chat_id,
|
let input_adapter = WebSocketInputAdapter::new();
|
||||||
} => {
|
|
||||||
let target = session_id
|
|
||||||
.or(chat_id)
|
|
||||||
.unwrap_or_else(|| current_session_id.clone());
|
|
||||||
state
|
|
||||||
.session_manager
|
|
||||||
.cli_sessions()
|
|
||||||
.clear_messages(&target)?;
|
|
||||||
|
|
||||||
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
|
||||||
session.lock().await.remove_history(&target);
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::HistoryCleared { session_id: target })
|
|
||||||
.await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::CreateSession { title } => {
|
|
||||||
// 使用新的命令层处理
|
|
||||||
let _input_adapter = WebSocketInputAdapter::new();
|
|
||||||
let output_adapter = WebSocketOutputAdapter::new();
|
let output_adapter = WebSocketOutputAdapter::new();
|
||||||
let cli_sessions = state.session_manager.cli_sessions();
|
|
||||||
let handler = SessionCommandHandler::new(cli_sessions);
|
// 解析命令
|
||||||
let router = {
|
let adapter_ctx = crate::command::context::AdapterContext::new("websocket")
|
||||||
let mut r = CommandRouter::new();
|
.with_session_id(current_session_id.as_str());
|
||||||
r.register(Box::new(handler));
|
|
||||||
r
|
let cmd = match input_adapter.try_parse(&payload, adapter_ctx) {
|
||||||
|
Ok(Some(cmd)) => cmd,
|
||||||
|
Ok(None) => {
|
||||||
|
// 不是命令,返回错误
|
||||||
|
let _ = sender
|
||||||
|
.send(WsOutbound::Error {
|
||||||
|
code: "INVALID_COMMAND".to_string(),
|
||||||
|
message: "Invalid command payload".to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = sender
|
||||||
|
.send(WsOutbound::Error {
|
||||||
|
code: "PARSE_ERROR".to_string(),
|
||||||
|
message: e.to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 构建命令
|
// 创建命令路由器
|
||||||
let cmd = crate::command::Command::CreateSession { title };
|
let cli_sessions = state.session_manager.cli_sessions();
|
||||||
let cmd_ctx = CommandContext::new("websocket")
|
let store = state.session_manager.store();
|
||||||
|
let skills = state.session_manager.skills();
|
||||||
|
let provider_config = state.config.get_provider_config("default")
|
||||||
|
.map_err(|e| AgentError::Other(e.to_string()))?;
|
||||||
|
let prompt_repository = state.session_manager.store().clone();
|
||||||
|
|
||||||
|
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
||||||
|
Box::new(AgentPromptProvider::new(
|
||||||
|
0,
|
||||||
|
provider_config.clone(),
|
||||||
|
prompt_repository.clone(),
|
||||||
|
)),
|
||||||
|
Box::new(SkillPromptProvider::new(skills)),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut router = CommandRouter::new();
|
||||||
|
router.register(Box::new(SessionCommandHandler::new(cli_sessions.clone())));
|
||||||
|
router.register(Box::new(SessionQueryCommandHandler::new(cli_sessions)));
|
||||||
|
router.register(Box::new(SaveSessionCommandHandler::new(
|
||||||
|
store,
|
||||||
|
system_prompt_provider,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// 构建命令上下文
|
||||||
|
let cmd_ctx = CommandContext::new("websocket", "cli")
|
||||||
.with_session_id(current_session_id.as_str());
|
.with_session_id(current_session_id.as_str());
|
||||||
|
|
||||||
// 执行命令
|
// 执行命令
|
||||||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
||||||
|
|
||||||
// 适配输出
|
|
||||||
let outbounds = output_adapter.adapt(response);
|
|
||||||
|
|
||||||
// 处理响应
|
// 处理响应
|
||||||
for msg in outbounds {
|
if response.success {
|
||||||
if let WsOutbound::SessionCreated { session_id, title: _ } = &msg {
|
// 更新当前会话 ID(如果是创建会话)
|
||||||
|
if let Some(session_id) = response.metadata.get("session_id") {
|
||||||
*current_session_id = session_id.clone();
|
*current_session_id = session_id.clone();
|
||||||
state
|
state
|
||||||
.channel_manager
|
.channel_manager
|
||||||
@ -239,178 +250,14 @@ async fn handle_inbound(
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 适配并发送响应
|
||||||
|
let outbounds = output_adapter.adapt(response);
|
||||||
|
for msg in outbounds {
|
||||||
let _ = sender.send(msg).await;
|
let _ = sender.send(msg).await;
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::ListSessions { include_archived } => {
|
|
||||||
let records = state
|
|
||||||
.session_manager
|
|
||||||
.cli_sessions()
|
|
||||||
.list(include_archived)?;
|
|
||||||
let summaries = records.into_iter().map(to_session_summary).collect();
|
|
||||||
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionList {
|
|
||||||
sessions: summaries,
|
|
||||||
current_session_id: Some(current_session_id.clone()),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::LoadSession { session_id } => {
|
|
||||||
let Some(record) = state.session_manager.cli_sessions().get(&session_id)? else {
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::Error {
|
|
||||||
code: "SESSION_NOT_FOUND".to_string(),
|
|
||||||
message: format!("Session not found: {}", session_id),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
return Ok(());
|
|
||||||
};
|
|
||||||
|
|
||||||
*current_session_id = record.id.clone();
|
|
||||||
state
|
|
||||||
.channel_manager
|
|
||||||
.cli_channel()
|
|
||||||
.register_connection(
|
|
||||||
record.id.clone(),
|
|
||||||
runtime_session_id.to_string(),
|
|
||||||
sender.clone(),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionLoaded {
|
|
||||||
session_id: record.id,
|
|
||||||
title: record.title,
|
|
||||||
message_count: record.message_count,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::RenameSession { session_id, title } => {
|
|
||||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
||||||
state
|
|
||||||
.session_manager
|
|
||||||
.cli_sessions()
|
|
||||||
.rename(&target, &title)?;
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionRenamed {
|
|
||||||
session_id: target,
|
|
||||||
title,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::ArchiveSession { session_id } => {
|
|
||||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
||||||
state.session_manager.cli_sessions().archive(&target)?;
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionArchived { session_id: target })
|
|
||||||
.await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::DeleteSession { session_id } => {
|
|
||||||
let target = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
||||||
state.session_manager.cli_sessions().delete(&target)?;
|
|
||||||
|
|
||||||
let replacement = if target == *current_session_id {
|
|
||||||
Some(state.session_manager.cli_sessions().create(None)?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(session) = state.session_manager.get(CLI_CHANNEL_NAME).await {
|
|
||||||
session.lock().await.remove_history(&target);
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionDeleted {
|
|
||||||
session_id: target.clone(),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if let Some(record) = replacement {
|
|
||||||
*current_session_id = record.id.clone();
|
|
||||||
state
|
|
||||||
.channel_manager
|
|
||||||
.cli_channel()
|
|
||||||
.register_connection(
|
|
||||||
record.id.clone(),
|
|
||||||
runtime_session_id.to_string(),
|
|
||||||
sender.clone(),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionCreated {
|
|
||||||
session_id: record.id,
|
|
||||||
title: record.title,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
WsInbound::SaveSession { filepath, session_id } => {
|
|
||||||
let target_session_id = session_id.unwrap_or_else(|| current_session_id.clone());
|
|
||||||
|
|
||||||
// 获取所需依赖
|
|
||||||
let store = state.session_manager.store();
|
|
||||||
let skills = state.session_manager.skills();
|
|
||||||
let provider_config = state.config.get_provider_config("default")
|
|
||||||
.map_err(|e| AgentError::Other(e.to_string()))?;
|
|
||||||
let prompt_repository = state.session_manager.store().clone();
|
|
||||||
|
|
||||||
// 构建组合系统提示词提供者(与运行时一致)
|
|
||||||
let system_prompt_provider: Arc<dyn crate::agent::SystemPromptProvider> = Arc::new(CompositeSystemPromptProvider::new(vec![
|
|
||||||
Box::new(AgentPromptProvider::new(
|
|
||||||
0, // save_session 不需要 reinject 逻辑
|
|
||||||
provider_config.clone(),
|
|
||||||
prompt_repository,
|
|
||||||
)),
|
|
||||||
Box::new(SkillPromptProvider::new(skills)),
|
|
||||||
]));
|
|
||||||
|
|
||||||
// 构建处理器
|
|
||||||
let handler = SaveSessionCommandHandler::new(store, system_prompt_provider);
|
|
||||||
let router = {
|
|
||||||
let mut r = CommandRouter::new();
|
|
||||||
r.register(Box::new(handler));
|
|
||||||
r
|
|
||||||
};
|
|
||||||
|
|
||||||
// 构建命令
|
|
||||||
let cmd = crate::command::Command::SaveSession { filepath, include_all: true };
|
|
||||||
let cmd_ctx = CommandContext::new("websocket")
|
|
||||||
.with_session_id(&target_session_id);
|
|
||||||
|
|
||||||
// 执行命令
|
|
||||||
let response = router.dispatch_with_response(cmd, cmd_ctx).await;
|
|
||||||
|
|
||||||
// 处理响应
|
|
||||||
if response.success {
|
|
||||||
let filepath = response
|
|
||||||
.metadata
|
|
||||||
.get("filepath")
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_default();
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::SessionSaved {
|
|
||||||
session_id: target_session_id,
|
|
||||||
filepath,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
} else {
|
|
||||||
let error = response.error.unwrap_or_else(|| {
|
|
||||||
crate::command::response::CommandError::new("SAVE_ERROR", "Unknown error")
|
|
||||||
});
|
|
||||||
let _ = sender
|
|
||||||
.send(WsOutbound::Error {
|
|
||||||
code: error.code,
|
|
||||||
message: error.message,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
WsInbound::Ping => {
|
WsInbound::Ping => {
|
||||||
|
|||||||
@ -17,8 +17,9 @@ pub struct SessionSummary {
|
|||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum WsInbound {
|
pub enum WsInbound {
|
||||||
#[serde(rename = "user_input")]
|
/// 普通用户消息
|
||||||
UserInput {
|
#[serde(rename = "message")]
|
||||||
|
Message {
|
||||||
content: String,
|
content: String,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
channel: Option<String>,
|
channel: Option<String>,
|
||||||
@ -27,48 +28,9 @@ pub enum WsInbound {
|
|||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
sender_id: Option<String>,
|
sender_id: Option<String>,
|
||||||
},
|
},
|
||||||
#[serde(rename = "clear_history")]
|
/// 命令(JSON 格式)
|
||||||
ClearHistory {
|
#[serde(rename = "command")]
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
Command { payload: String },
|
||||||
chat_id: Option<String>,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
session_id: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "create_session")]
|
|
||||||
CreateSession {
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
title: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "list_sessions")]
|
|
||||||
ListSessions {
|
|
||||||
#[serde(default)]
|
|
||||||
include_archived: bool,
|
|
||||||
},
|
|
||||||
#[serde(rename = "load_session")]
|
|
||||||
LoadSession { session_id: String },
|
|
||||||
#[serde(rename = "rename_session")]
|
|
||||||
RenameSession {
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
session_id: Option<String>,
|
|
||||||
title: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "archive_session")]
|
|
||||||
ArchiveSession {
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
session_id: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "delete_session")]
|
|
||||||
DeleteSession {
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
session_id: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "save_session")]
|
|
||||||
SaveSession {
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
filepath: Option<String>,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
session_id: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "ping")]
|
#[serde(rename = "ping")]
|
||||||
Ping,
|
Ping,
|
||||||
}
|
}
|
||||||
@ -126,14 +88,6 @@ pub enum WsOutbound {
|
|||||||
title: String,
|
title: String,
|
||||||
message_count: i64,
|
message_count: i64,
|
||||||
},
|
},
|
||||||
#[serde(rename = "session_renamed")]
|
|
||||||
SessionRenamed { session_id: String, title: String },
|
|
||||||
#[serde(rename = "session_archived")]
|
|
||||||
SessionArchived { session_id: String },
|
|
||||||
#[serde(rename = "session_deleted")]
|
|
||||||
SessionDeleted { session_id: String },
|
|
||||||
#[serde(rename = "history_cleared")]
|
|
||||||
HistoryCleared { session_id: String },
|
|
||||||
#[serde(rename = "session_saved")]
|
#[serde(rename = "session_saved")]
|
||||||
SessionSaved { session_id: String, filepath: String },
|
SessionSaved { session_id: String, filepath: String },
|
||||||
#[serde(rename = "pong")]
|
#[serde(rename = "pong")]
|
||||||
|
|||||||
@ -204,14 +204,24 @@ impl SessionStore {
|
|||||||
Self::from_connection(Connection::open_in_memory()?)
|
Self::from_connection(Connection::open_in_memory()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
|
pub fn create_session(
|
||||||
|
&self,
|
||||||
|
channel_name: &str,
|
||||||
|
title: Option<&str>,
|
||||||
|
) -> Result<SessionRecord, StorageError> {
|
||||||
let now = current_timestamp();
|
let now = current_timestamp();
|
||||||
let id = uuid::Uuid::new_v4().to_string();
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
let title = title
|
let title = title
|
||||||
.map(str::trim)
|
.map(str::trim)
|
||||||
.filter(|value| !value.is_empty())
|
.filter(|value| !value.is_empty())
|
||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
.unwrap_or_else(|| format!("CLI Session {}", &id[..8]));
|
.unwrap_or_else(|| {
|
||||||
|
if channel_name == "cli" {
|
||||||
|
format!("CLI Session {}", &id[..8])
|
||||||
|
} else {
|
||||||
|
format!("Session {}", &id[..8])
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@ -220,9 +230,9 @@ impl SessionStore {
|
|||||||
id, title, channel_name, chat_id, summary,
|
id, title, channel_name, chat_id, summary,
|
||||||
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
created_at, updated_at, last_active_at, archived_at, deleted_at, message_count,
|
||||||
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
|
reset_cutoff_seq, user_turn_count, agent_prompt_reinjection_count
|
||||||
) VALUES (?1, ?2, 'cli', ?3, NULL, ?4, ?4, ?4, NULL, NULL, 0, 0, 0, 0)
|
) VALUES (?1, ?2, ?3, ?4, NULL, ?5, ?5, ?5, NULL, NULL, 0, 0, 0, 0)
|
||||||
",
|
",
|
||||||
params![id, title, id, now],
|
params![id, title, channel_name, id, now],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
drop(conn);
|
drop(conn);
|
||||||
@ -230,6 +240,10 @@ impl SessionStore {
|
|||||||
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
.ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, StorageError> {
|
||||||
|
self.create_session("cli", title)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn ensure_channel_session(
|
pub fn ensure_channel_session(
|
||||||
&self,
|
&self,
|
||||||
channel_name: &str,
|
channel_name: &str,
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
use picobot::protocol::{SessionSummary, WsInbound, WsOutbound};
|
use picobot::protocol::{WsInbound, WsOutbound};
|
||||||
use picobot::providers::{ChatCompletionRequest, Message};
|
use picobot::providers::{ChatCompletionRequest, Message};
|
||||||
|
|
||||||
/// Test that message with special characters is properly escaped
|
/// Test that message with special characters is properly escaped
|
||||||
@ -53,70 +53,65 @@ fn test_chat_request_serialization() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_session_inbound_serialization() {
|
fn test_command_inbound_serialization() {
|
||||||
let msg = WsInbound::CreateSession {
|
// Command is now sent as payload in WsInbound::Command
|
||||||
title: Some("demo".to_string()),
|
let command_json = r#"{"type":"create_session","title":"demo"}"#;
|
||||||
|
let msg = WsInbound::Command {
|
||||||
|
payload: command_json.to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&msg).unwrap();
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
assert!(json.contains(r#""type":"create_session""#));
|
assert!(json.contains(r#""type":"command""#));
|
||||||
assert!(json.contains(r#""title":"demo""#));
|
assert!(json.contains(r#""payload":""#));
|
||||||
|
assert!(json.contains(r#"create_session"#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_inbound_serialization() {
|
||||||
|
let msg = WsInbound::Message {
|
||||||
|
content: "Hello world".to_string(),
|
||||||
|
channel: None,
|
||||||
|
chat_id: Some("session-1".to_string()),
|
||||||
|
sender_id: Some("user-1".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(json.contains(r#""type":"message""#));
|
||||||
|
assert!(json.contains(r#""content":"Hello world""#));
|
||||||
|
assert!(json.contains(r#""chat_id":"session-1""#));
|
||||||
|
|
||||||
let decoded: WsInbound = serde_json::from_str(&json).unwrap();
|
let decoded: WsInbound = serde_json::from_str(&json).unwrap();
|
||||||
match decoded {
|
match decoded {
|
||||||
WsInbound::CreateSession { title } => {
|
WsInbound::Message { content, chat_id, .. } => {
|
||||||
assert_eq!(title.as_deref(), Some("demo"));
|
assert_eq!(content, "Hello world");
|
||||||
|
assert_eq!(chat_id.as_deref(), Some("session-1"));
|
||||||
}
|
}
|
||||||
other => panic!("unexpected decoded variant: {:?}", other),
|
other => panic!("unexpected decoded variant: {:?}", other),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_session_list_outbound_serialization() {
|
fn test_session_created_outbound_serialization() {
|
||||||
let msg = WsOutbound::SessionList {
|
let msg = WsOutbound::SessionCreated {
|
||||||
sessions: vec![SessionSummary {
|
session_id: "session-1".to_string(),
|
||||||
session_id: "session-1".to_string(),
|
title: "demo".to_string(),
|
||||||
title: "demo".to_string(),
|
|
||||||
channel_name: "cli".to_string(),
|
|
||||||
chat_id: "session-1".to_string(),
|
|
||||||
message_count: 2,
|
|
||||||
last_active_at: 123,
|
|
||||||
archived_at: None,
|
|
||||||
}],
|
|
||||||
current_session_id: Some("session-1".to_string()),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&msg).unwrap();
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
assert!(json.contains(r#""type":"session_list""#));
|
assert!(json.contains(r#""type":"session_created""#));
|
||||||
assert!(json.contains(r#""session_id":"session-1""#));
|
assert!(json.contains(r#""session_id":"session-1""#));
|
||||||
assert!(json.contains(r#""message_count":2"#));
|
assert!(json.contains(r#""title":"demo""#));
|
||||||
|
|
||||||
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
let decoded: WsOutbound = serde_json::from_str(&json).unwrap();
|
||||||
match decoded {
|
match decoded {
|
||||||
WsOutbound::SessionList {
|
WsOutbound::SessionCreated { session_id, title } => {
|
||||||
sessions,
|
assert_eq!(session_id, "session-1");
|
||||||
current_session_id,
|
assert_eq!(title, "demo");
|
||||||
} => {
|
|
||||||
assert_eq!(sessions.len(), 1);
|
|
||||||
assert_eq!(sessions[0].title, "demo");
|
|
||||||
assert_eq!(current_session_id.as_deref(), Some("session-1"));
|
|
||||||
}
|
}
|
||||||
other => panic!("unexpected decoded variant: {:?}", other),
|
other => panic!("unexpected decoded variant: {:?}", other),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_clear_history_with_session_id_serialization() {
|
|
||||||
let msg = WsInbound::ClearHistory {
|
|
||||||
chat_id: None,
|
|
||||||
session_id: Some("session-1".to_string()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&msg).unwrap();
|
|
||||||
assert!(json.contains(r#""type":"clear_history""#));
|
|
||||||
assert!(json.contains(r#""session_id":"session-1""#));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_outbound_serialization() {
|
fn test_tool_call_outbound_serialization() {
|
||||||
let msg = WsOutbound::ToolCall {
|
let msg = WsOutbound::ToolCall {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user