/dump 命令会输出当前 session 的完整信息：
- Session 元信息（ID、channel、chat_id、model 等）
- 全部对话历史（system、user、assistant、tool）
- 每条消息的角色、时间戳、内容以及工具调用
598 lines · 20 KiB · Rust
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use std::time::{Duration, Instant};
|
||
|
||
use tokio::sync::{Mutex, mpsc};
|
||
use uuid::Uuid;
|
||
|
||
use crate::bus::ChatMessage;
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||
use crate::agent::system_prompt::build_system_prompt;
|
||
use crate::agent::context_compressor::ContextCompressionConfig;
|
||
use crate::protocol::WsOutbound;
|
||
use crate::providers::{create_provider, LLMProvider};
|
||
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
|
||
use crate::session::events::DialogInfo;
|
||
use crate::skills::SkillsLoader;
|
||
use crate::tools::{
|
||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||
GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool,
|
||
};
|
||
|
||
/// Generate a short ID (8 characters) from a UUID
|
||
fn short_id() -> String {
|
||
Uuid::new_v4().to_string()[..8].to_string()
|
||
}
|
||
|
||
/// Session = 一个 dialog
|
||
/// 每个 Session 对应一个 UnifiedSessionId,有独立的 messages history
|
||
pub struct Session {
|
||
pub id: UnifiedSessionId,
|
||
messages: Vec<ChatMessage>,
|
||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||
provider_config: LLMProviderConfig,
|
||
provider: Arc<dyn LLMProvider>,
|
||
tools: Arc<ToolRegistry>,
|
||
compressor: ContextCompressor,
|
||
}
|
||
|
||
impl Session {
|
||
pub async fn new(
|
||
id: UnifiedSessionId,
|
||
provider_config: LLMProviderConfig,
|
||
user_tx: mpsc::Sender<WsOutbound>,
|
||
tools: Arc<ToolRegistry>,
|
||
) -> Result<Self, AgentError> {
|
||
let provider_box = create_provider(provider_config.clone())
|
||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||
|
||
let compressor_config = ContextCompressionConfig {
|
||
protect_first_n: 2,
|
||
..Default::default()
|
||
};
|
||
|
||
Ok(Self {
|
||
id,
|
||
messages: Vec::new(),
|
||
user_tx,
|
||
provider_config: provider_config.clone(),
|
||
provider: provider.clone(),
|
||
tools,
|
||
compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config),
|
||
})
|
||
}
|
||
|
||
/// 获取 session ID
|
||
pub fn session_id(&self) -> String {
|
||
self.id.to_string()
|
||
}
|
||
|
||
/// 添加消息到历史
|
||
pub fn add_message(&mut self, message: ChatMessage) {
|
||
self.messages.push(message);
|
||
}
|
||
|
||
/// 获取消息历史
|
||
pub fn get_history(&self) -> &[ChatMessage] {
|
||
&self.messages
|
||
}
|
||
|
||
/// 清除历史消息
|
||
pub fn clear_history(&mut self) {
|
||
let len = self.messages.len();
|
||
self.messages.clear();
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
|
||
}
|
||
|
||
/// 重置对话上下文
|
||
pub fn reset_context(&mut self) {
|
||
let len = self.messages.len();
|
||
self.messages.clear();
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
|
||
}
|
||
|
||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||
if media_refs.is_empty() {
|
||
ChatMessage::user(content)
|
||
} else {
|
||
ChatMessage::user_with_media(content, media_refs)
|
||
}
|
||
}
|
||
|
||
pub async fn send(&self, msg: WsOutbound) {
|
||
let _ = self.user_tx.send(msg).await;
|
||
}
|
||
|
||
/// 获取 provider_config 引用
|
||
pub fn provider_config(&self) -> &LLMProviderConfig {
|
||
&self.provider_config
|
||
}
|
||
|
||
/// 获取 compressor 引用
|
||
pub fn compressor(&self) -> &ContextCompressor {
|
||
&self.compressor
|
||
}
|
||
|
||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||
Ok(AgentLoop::with_provider_and_tools(
|
||
self.provider.clone(),
|
||
self.tools.clone(),
|
||
self.provider_config.max_tool_iterations,
|
||
self.provider_config.model_id.clone(),
|
||
self.provider_config.workspace_dir.clone(),
|
||
))
|
||
}
|
||
|
||
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills)
|
||
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
|
||
let base_prompt = build_system_prompt(
|
||
&self.provider_config.workspace_dir,
|
||
&self.provider_config.model_id,
|
||
&self.tools,
|
||
);
|
||
|
||
if skills_prompt.trim().is_empty() {
|
||
base_prompt
|
||
} else {
|
||
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
|
||
}
|
||
}
|
||
|
||
/// 将当前 session 导出为 markdown 文档
|
||
pub fn dump_as_markdown(&self) -> String {
|
||
use chrono::{DateTime, Local};
|
||
|
||
let now = Local::now().format("%Y-%m-%d %H:%M:%S");
|
||
|
||
let mut md = String::new();
|
||
md.push_str(&format!("# Session Dump\n\n"));
|
||
md.push_str(&format!("- **Session ID**: `{}`\n", self.id));
|
||
md.push_str(&format!("- **Channel**: `{}`\n", self.id.channel));
|
||
md.push_str(&format!("- **Chat ID**: `{}`\n", self.id.chat_id));
|
||
md.push_str(&format!("- **Dialog ID**: `{}`\n", self.id.dialog_id));
|
||
md.push_str(&format!("- **Message Count**: {}\n", self.messages.len()));
|
||
md.push_str(&format!("- **Model**: `{}`\n", self.provider_config.model_id));
|
||
md.push_str(&format!("- **Exported At**: {}\n", now));
|
||
md.push_str("\n---\n\n");
|
||
|
||
md.push_str("## Conversation History\n\n");
|
||
|
||
for (i, msg) in self.messages.iter().enumerate() {
|
||
let role = match msg.role.as_str() {
|
||
"system" => "System",
|
||
"user" => "User",
|
||
"assistant" => "Assistant",
|
||
"tool" => "Tool",
|
||
r => r,
|
||
};
|
||
|
||
let timestamp = if msg.timestamp > 0 {
|
||
DateTime::from_timestamp_millis(msg.timestamp)
|
||
.map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
|
||
.unwrap_or_default()
|
||
} else {
|
||
String::new()
|
||
};
|
||
|
||
md.push_str(&format!("### [{:03}] {} {}\n\n", i + 1, role, timestamp));
|
||
md.push_str("```\n");
|
||
|
||
if let Some(ref tool_calls) = msg.tool_calls {
|
||
md.push_str(&format!("[Tool Calls]\n"));
|
||
for tc in tool_calls {
|
||
md.push_str(&format!("- {}: {:?}\n", tc.name, tc.arguments));
|
||
}
|
||
}
|
||
|
||
if let Some(ref tool_name) = msg.tool_name {
|
||
md.push_str(&format!("[Tool: {}]\n", tool_name));
|
||
}
|
||
|
||
if let Some(ref tool_call_id) = msg.tool_call_id {
|
||
md.push_str(&format!("[Tool Call ID: {}]\n", tool_call_id));
|
||
}
|
||
|
||
md.push_str(&msg.content);
|
||
md.push_str("\n```\n\n");
|
||
|
||
if !msg.media_refs.is_empty() {
|
||
md.push_str(&format!("**Media**: {:?}\n\n", msg.media_refs));
|
||
}
|
||
}
|
||
|
||
md
|
||
}
|
||
}
|
||
|
||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||
#[derive(Clone)]
|
||
pub struct SessionManager {
|
||
inner: Arc<Mutex<SessionManagerInner>>,
|
||
provider_config: LLMProviderConfig,
|
||
tools: Arc<ToolRegistry>,
|
||
skills_loader: Arc<SkillsLoader>,
|
||
}
|
||
|
||
struct SessionManagerInner {
|
||
/// Sessions keyed by UnifiedSessionId.to_string()
|
||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||
session_timestamps: HashMap<String, Instant>,
|
||
session_ttl: Duration,
|
||
}
|
||
|
||
fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
|
||
let mut registry = ToolRegistry::new();
|
||
registry.register(CalculatorTool::new());
|
||
registry.register(FileReadTool::new());
|
||
registry.register(FileWriteTool::new());
|
||
registry.register(FileEditTool::new());
|
||
registry.register(BashTool::new());
|
||
registry.register(HttpRequestTool::new(
|
||
vec!["*".to_string()],
|
||
1_000_000,
|
||
30,
|
||
false,
|
||
));
|
||
registry.register(WebFetchTool::new(50_000, 30));
|
||
registry.register(GetSkillTool::new(skills_loader));
|
||
registry
|
||
}
|
||
|
||
/// A slash command that can be typed into a conversation.
#[derive(Debug, Clone)]
pub struct SlashCommand {
    /// Internal command name, used to dispatch the command.
    pub name: &'static str,
    /// Human-readable description of the command.
    pub description: &'static str,
    /// Trigger words (for example `"/new"`) that invoke the command.
    pub aliases: &'static [&'static str],
}

impl SlashCommand {
    /// Returns `true` if `content` invokes this command.
    ///
    /// Surrounding whitespace is ignored. The trimmed content matches when it is
    /// exactly one of the aliases, or when it is an alias followed by whitespace
    /// (space, tab, newline, …) and arguments. Content such as `"/newer"`, which
    /// merely starts with an alias, does not match.
    pub fn matches(&self, content: &str) -> bool {
        let trimmed = content.trim();
        self.aliases.iter().any(|&alias| {
            // `strip_prefix` avoids allocating an "alias + space" string per alias.
            matches!(
                trimmed.strip_prefix(alias),
                Some(rest) if rest.is_empty() || rest.starts_with(char::is_whitespace)
            )
        })
    }
}
|
||
|
||
/// Session 支持的斜杠命令列表
|
||
pub static SLASH_COMMANDS: &[SlashCommand] = &[
|
||
SlashCommand {
|
||
name: "new",
|
||
description: "Create a new conversation",
|
||
aliases: &["/new"],
|
||
},
|
||
SlashCommand {
|
||
name: "delete",
|
||
description: "Delete current conversation and start a new one",
|
||
aliases: &["/delete"],
|
||
},
|
||
SlashCommand {
|
||
name: "compact",
|
||
description: "Manually trigger context compression",
|
||
aliases: &["/compact"],
|
||
},
|
||
SlashCommand {
|
||
name: "info",
|
||
description: "Print current session information",
|
||
aliases: &["/info"],
|
||
},
|
||
SlashCommand {
|
||
name: "dump",
|
||
description: "Save current session as markdown document",
|
||
aliases: &["/dump"],
|
||
},
|
||
];
|
||
|
||
impl SessionManager {
|
||
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||
let skills_loader = SkillsLoader::new();
|
||
skills_loader.load_skills();
|
||
let skills_loader = Arc::new(skills_loader);
|
||
|
||
let tools = Arc::new(create_default_tools(skills_loader.clone()));
|
||
|
||
Ok(Self {
|
||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||
sessions: HashMap::new(),
|
||
session_timestamps: HashMap::new(),
|
||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||
})),
|
||
provider_config,
|
||
tools,
|
||
skills_loader,
|
||
})
|
||
}
|
||
|
||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
||
self.tools.clone()
|
||
}
|
||
|
||
/// 获取所有可用的斜杠命令
|
||
pub fn get_slash_commands(&self) -> &[SlashCommand] {
|
||
SLASH_COMMANDS
|
||
}
|
||
|
||
/// 执行斜杠命令
|
||
/// 返回 (新session_id, 响应消息)
|
||
pub async fn execute_slash_command(
|
||
&self,
|
||
command: &str,
|
||
args: Option<&str>,
|
||
channel: &str,
|
||
chat_id: &str,
|
||
current_session_id: Option<&UnifiedSessionId>,
|
||
) -> Result<(Option<UnifiedSessionId>, String), AgentError> {
|
||
let cmd = SLASH_COMMANDS
|
||
.iter()
|
||
.find(|c| c.name == command)
|
||
.ok_or_else(|| AgentError::Other(format!("Unknown command: {}", command)))?;
|
||
|
||
tracing::info!(cmd = %cmd.name, args = ?args, "Executing slash command");
|
||
|
||
match cmd.name {
|
||
"new" => {
|
||
let title = args.map(|s| s.to_string());
|
||
let (new_id, title) = self.create_session(channel, chat_id, title.as_deref()).await?;
|
||
Ok((Some(new_id), format!("New conversation '{}' created.", title)))
|
||
}
|
||
"delete" => {
|
||
let (new_id, _title) = self.create_session(channel, chat_id, None).await?;
|
||
Ok((Some(new_id), "Conversation deleted. New conversation created.".to_string()))
|
||
}
|
||
"compact" => {
|
||
if let Some(sid) = current_session_id {
|
||
let session = self.get_or_create_session(sid).await?;
|
||
let mut session_guard = session.lock().await;
|
||
let original_count = session_guard.get_history().len();
|
||
let history = session_guard.get_history().to_vec();
|
||
let compressed = session_guard.compressor
|
||
.compress_if_needed(history)
|
||
.await?;
|
||
let compressed_count = compressed.len();
|
||
session_guard.clear_history();
|
||
for msg in compressed {
|
||
session_guard.add_message(msg);
|
||
}
|
||
Ok((None, format!(
|
||
"Context compressed: {} → {} messages.",
|
||
original_count, compressed_count
|
||
)))
|
||
} else {
|
||
Ok((None, "No active conversation to compress.".to_string()))
|
||
}
|
||
}
|
||
"info" => {
|
||
if let Some(sid) = current_session_id {
|
||
let session = self.get_or_create_session(sid).await?;
|
||
let session_guard = session.lock().await;
|
||
let message_count = session_guard.get_history().len();
|
||
let session_id_str = session_guard.session_id();
|
||
Ok((None, format!(
|
||
"Session ID: {}\nMessage count: {}",
|
||
session_id_str, message_count
|
||
)))
|
||
} else {
|
||
Ok((None, "No active session.".to_string()))
|
||
}
|
||
}
|
||
"dump" => {
|
||
if let Some(sid) = current_session_id {
|
||
let session = self.get_or_create_session(sid).await?;
|
||
let session_guard = session.lock().await;
|
||
let md = session_guard.dump_as_markdown();
|
||
Ok((None, md))
|
||
} else {
|
||
Ok((None, "No active session.".to_string()))
|
||
}
|
||
}
|
||
_ => Err(AgentError::Other(format!("Command not implemented: {}", cmd.name))),
|
||
}
|
||
}
|
||
|
||
pub async fn create_session(
|
||
&self,
|
||
channel: &str,
|
||
chat_id: &str,
|
||
title: Option<&str>,
|
||
) -> Result<(UnifiedSessionId, String), AgentError> {
|
||
let dialog_id = short_id();
|
||
let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id);
|
||
let session_id_str = unified_id.to_string();
|
||
|
||
let title = title
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.map(ToOwned::to_owned)
|
||
.unwrap_or_else(|| format!("Dialog {}", &dialog_id));
|
||
|
||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||
let session = Session::new(
|
||
unified_id.clone(),
|
||
self.provider_config.clone(),
|
||
user_tx,
|
||
self.tools.clone(),
|
||
).await?;
|
||
|
||
let arc = Arc::new(Mutex::new(session));
|
||
let inner = &mut *self.inner.lock().await;
|
||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||
inner.session_timestamps.insert(session_id_str, Instant::now());
|
||
|
||
Ok((unified_id, title))
|
||
}
|
||
|
||
pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result<Arc<Mutex<Session>>, AgentError> {
|
||
let session_id_str = unified_id.to_string();
|
||
let inner = &mut *self.inner.lock().await;
|
||
|
||
if let Some(session) = inner.sessions.get(&session_id_str) {
|
||
inner.session_timestamps.insert(session_id_str, Instant::now());
|
||
return Ok(session.clone());
|
||
}
|
||
|
||
// Create new session
|
||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||
let session = Session::new(
|
||
unified_id.clone(),
|
||
self.provider_config.clone(),
|
||
user_tx,
|
||
self.tools.clone(),
|
||
).await?;
|
||
|
||
let arc = Arc::new(Mutex::new(session));
|
||
inner.sessions.insert(session_id_str.clone(), arc.clone());
|
||
inner.session_timestamps.insert(session_id_str, Instant::now());
|
||
Ok(arc)
|
||
}
|
||
|
||
pub async fn create_dialog(
|
||
&self,
|
||
channel: &str,
|
||
chat_id: &str,
|
||
title: Option<&str>,
|
||
) -> Result<(UnifiedSessionId, String), AgentError> {
|
||
self.create_session(channel, chat_id, title).await
|
||
}
|
||
|
||
pub async fn get_current_dialog(
|
||
&self,
|
||
_channel: &str,
|
||
_chat_id: &str,
|
||
) -> Result<Option<UnifiedSessionId>, AgentError> {
|
||
Ok(None)
|
||
}
|
||
|
||
pub async fn switch_dialog(
|
||
&self,
|
||
_channel: &str,
|
||
_chat_id: &str,
|
||
_dialog_id: &str,
|
||
) -> Result<UnifiedSessionId, AgentError> {
|
||
Err(AgentError::Other("switch_dialog not applicable in new architecture".to_string()))
|
||
}
|
||
|
||
pub async fn list_dialogs(
|
||
&self,
|
||
_channel: &str,
|
||
_chat_id: &str,
|
||
_include_archived: bool,
|
||
) -> Result<(Vec<DialogInfo>, Option<String>), AgentError> {
|
||
Ok((vec![], None))
|
||
}
|
||
|
||
pub fn rename_dialog(&self, _session_id: &UnifiedSessionId, _title: &str) -> Result<(), AgentError> {
|
||
Err(AgentError::Other("rename_dialog not available".to_string()))
|
||
}
|
||
|
||
pub fn archive_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||
Err(AgentError::Other("archive_dialog not available".to_string()))
|
||
}
|
||
|
||
pub fn delete_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||
Err(AgentError::Other("delete_dialog not available".to_string()))
|
||
}
|
||
|
||
pub fn clear_dialog_history(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||
Err(AgentError::Other("clear_dialog_history not available".to_string()))
|
||
}
|
||
|
||
pub async fn handle_message(
|
||
&self,
|
||
channel: &str,
|
||
_sender_id: &str,
|
||
chat_id: &str,
|
||
dialog_id: Option<&str>,
|
||
content: &str,
|
||
media: Vec<crate::bus::MediaItem>,
|
||
) -> Result<String, AgentError> {
|
||
let dialog_id = dialog_id.unwrap_or(DEFAULT_DIALOG_ID);
|
||
let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id);
|
||
let session = self.get_or_create_session(&unified_id).await?;
|
||
|
||
let response: String = {
|
||
let mut session_guard = session.lock().await;
|
||
|
||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||
#[cfg(debug_assertions)]
|
||
if !media_refs.is_empty() {
|
||
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||
}
|
||
|
||
let user_message = session_guard.create_user_message(content, media_refs);
|
||
session_guard.add_message(user_message);
|
||
|
||
let mut history = session_guard.get_history().to_vec();
|
||
|
||
// Build skills prompt
|
||
let skills_prompt = self.skills_loader.build_skills_prompt();
|
||
|
||
// Build combined system prompt and inject at position 0
|
||
// This ensures AgentLoop.process() sees a system message and doesn't inject its own
|
||
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
|
||
history.insert(0, ChatMessage::system(system_prompt));
|
||
|
||
let history = session_guard.compressor
|
||
.compress_if_needed(history)
|
||
.await?;
|
||
|
||
let agent = session_guard.create_agent()?;
|
||
let result = agent.process(history).await?;
|
||
|
||
for msg in result.emitted_messages {
|
||
session_guard.add_message(msg);
|
||
}
|
||
|
||
result.final_response.content
|
||
};
|
||
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(
|
||
channel = %channel,
|
||
chat_id = %chat_id,
|
||
response_len = %response.len(),
|
||
"Agent response received"
|
||
);
|
||
|
||
Ok(response)
|
||
}
|
||
|
||
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
|
||
let session = self.get_or_create_session(unified_id).await?;
|
||
let mut session_guard = session.lock().await;
|
||
session_guard.clear_history();
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal provider configuration for tests.
    ///
    /// It uses a placeholder endpoint (`http://localhost`) and a dummy API key, a
    /// small `max_tokens`, one tool iteration and a 4096-token limit.
    fn test_provider_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            token_limit: 4096,
            workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"),
        }
    }

    /// Pins the fixture's invariants. It also keeps the fixture in use, which avoids
    /// a dead-code warning under `cargo test`.
    #[test]
    fn test_provider_config_is_usable_fixture() {
        let config = test_provider_config();
        assert_eq!(config.provider_type, "openai");
        assert_eq!(config.model_id, "test-model");
        // `Session::new` sizes its compressor from `token_limit`, and `create_agent`
        // forwards `max_tool_iterations` to `AgentLoop`; both must stay usable.
        assert!(config.token_limit > 0);
        assert!(config.max_tool_iterations >= 1);
    }
}
|