PicoBot/src/session/session.rs
xiaoxixi a3d8ebb534 添加 /dump 命令,导出 session 为 markdown 文档
/dump 命令会输出当前 session 的完整信息:
- Session 元信息 (ID, channel, chat_id, model 等)
- 所有对话历史 (system, user, assistant, tool)
- 每条消息包含角色、时间戳、内容、工具调用等
2026-04-28 20:54:54 +08:00

598 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::agent::system_prompt::build_system_prompt;
use crate::agent::context_compressor::ContextCompressionConfig;
use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider};
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
use crate::session::events::DialogInfo;
use crate::skills::SkillsLoader;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool,
};
/// Returns a short (8-character) identifier.
///
/// The value is the leading 8 characters of the hyphenated string form of a
/// freshly generated random (v4) UUID, i.e. its first 8 hex digits.
fn short_id() -> String {
    let mut id = Uuid::new_v4().to_string();
    // The hyphenated form starts with 8 ASCII hex digits, so this cuts on a char boundary.
    id.truncate(8);
    id
}
/// A `Session` is a single dialog.
///
/// Each session corresponds to one [`UnifiedSessionId`] and owns its own,
/// independent message history.
pub struct Session {
    /// Identity of this dialog.
    pub id: UnifiedSessionId,
    /// Conversation history in insertion order. `SessionManager::handle_message`
    /// inserts the system prompt into a per-request copy, not into this vector.
    messages: Vec<ChatMessage>,
    /// Outbound channel written to by `Session::send`.
    pub user_tx: mpsc::Sender<WsOutbound>,
    /// Configuration the provider, compressor and agents are built from.
    provider_config: LLMProviderConfig,
    /// LLM provider, shared with `compressor` and with every agent from `create_agent`.
    provider: Arc<dyn LLMProvider>,
    /// Tool registry handed to agents and used when building the system prompt.
    tools: Arc<ToolRegistry>,
    /// Context compressor built with the config's `token_limit`; used through
    /// `compress_if_needed`.
    compressor: ContextCompressor,
}
impl Session {
    /// Creates an empty session for `id`.
    ///
    /// Builds a new LLM provider from `provider_config` and a
    /// [`ContextCompressor`] sharing that provider, configured with
    /// `protect_first_n: 2` (all other compression settings use their defaults).
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] if `create_provider` fails.
    pub async fn new(
        id: UnifiedSessionId,
        provider_config: LLMProviderConfig,
        user_tx: mpsc::Sender<WsOutbound>,
        tools: Arc<ToolRegistry>,
    ) -> Result<Self, AgentError> {
        let provider_box = create_provider(provider_config.clone())
            .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
        let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
        let compressor_config = ContextCompressionConfig {
            protect_first_n: 2,
            ..Default::default()
        };
        // Build the compressor before `provider_config` is moved into the struct,
        // so the config does not need a second clone.
        let compressor = ContextCompressor::with_config(
            Arc::clone(&provider),
            provider_config.token_limit,
            compressor_config,
        );
        Ok(Self {
            id,
            messages: Vec::new(),
            user_tx,
            provider_config,
            provider,
            tools,
            compressor,
        })
    }

    /// Returns the session ID as a string (the `Display` form of `self.id`).
    pub fn session_id(&self) -> String {
        self.id.to_string()
    }

    /// Appends `message` to the history.
    pub fn add_message(&mut self, message: ChatMessage) {
        self.messages.push(message);
    }

    /// Returns the message history, oldest first.
    pub fn get_history(&self) -> &[ChatMessage] {
        &self.messages
    }

    /// Removes every message from the history.
    pub fn clear_history(&mut self) {
        // Only the debug log needs the previous length; gating the binding too
        // avoids an unused-variable warning in release builds.
        #[cfg(debug_assertions)]
        let len = self.messages.len();
        self.messages.clear();
        #[cfg(debug_assertions)]
        tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
    }

    /// Resets the in-memory conversation context (removes every message).
    pub fn reset_context(&mut self) {
        #[cfg(debug_assertions)]
        let len = self.messages.len();
        self.messages.clear();
        #[cfg(debug_assertions)]
        tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
    }

    /// Builds a user message, attaching `media_refs` only when there are any.
    pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
        if media_refs.is_empty() {
            ChatMessage::user(content)
        } else {
            ChatMessage::user_with_media(content, media_refs)
        }
    }

    /// Sends `msg` on the outbound channel. Send errors (e.g. a closed
    /// receiver) are deliberately ignored.
    pub async fn send(&self, msg: WsOutbound) {
        let _ = self.user_tx.send(msg).await;
    }

    /// Returns the provider configuration this session was created with.
    pub fn provider_config(&self) -> &LLMProviderConfig {
        &self.provider_config
    }

    /// Returns the session's context compressor.
    pub fn compressor(&self) -> &ContextCompressor {
        &self.compressor
    }

    /// Creates a temporary [`AgentLoop`] to process messages, sharing this
    /// session's provider and tools and using the limits and model from its
    /// provider config.
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
        Ok(AgentLoop::with_provider_and_tools(
            self.provider.clone(),
            self.tools.clone(),
            self.provider_config.max_tool_iterations,
            self.provider_config.model_id.clone(),
            self.provider_config.workspace_dir.clone(),
        ))
    }

    /// Builds the system prompt: the base prompt from `build_system_prompt`
    /// (workspace, model, tools), followed by a `## Skills` section when
    /// `skills_prompt` is not blank.
    pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
        let base_prompt = build_system_prompt(
            &self.provider_config.workspace_dir,
            &self.provider_config.model_id,
            &self.tools,
        );
        if skills_prompt.trim().is_empty() {
            base_prompt
        } else {
            format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
        }
    }

    /// Exports the session (metadata plus the full message history) as a
    /// Markdown document.
    ///
    /// Each message gets a numbered `###` heading with its role and, when
    /// `timestamp > 0`, its time. Message times and the "Exported At" time are
    /// both shown in the local timezone. The message body (tool calls, tool
    /// name, tool call id, then content) is wrapped in a fenced code block whose
    /// fence is longer than any backtick run in the body. This keeps messages
    /// that themselves contain ``` fences from breaking out of the block.
    pub fn dump_as_markdown(&self) -> String {
        use chrono::{DateTime, Local};
        use std::fmt::Write as _;

        const TIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S";

        let now = Local::now().format(TIME_FORMAT);
        let mut md = String::new();
        // Writing into a `String` cannot fail, so the `fmt::Result`s are discarded.
        let _ = writeln!(md, "# Session Dump\n");
        let _ = writeln!(md, "- **Session ID**: `{}`", self.id);
        let _ = writeln!(md, "- **Channel**: `{}`", self.id.channel);
        let _ = writeln!(md, "- **Chat ID**: `{}`", self.id.chat_id);
        let _ = writeln!(md, "- **Dialog ID**: `{}`", self.id.dialog_id);
        let _ = writeln!(md, "- **Message Count**: {}", self.messages.len());
        let _ = writeln!(md, "- **Model**: `{}`", self.provider_config.model_id);
        let _ = writeln!(md, "- **Exported At**: {}", now);
        md.push_str("\n---\n\n");
        md.push_str("## Conversation History\n\n");

        for (i, msg) in self.messages.iter().enumerate() {
            let role = match msg.role.as_str() {
                "system" => "System",
                "user" => "User",
                "assistant" => "Assistant",
                "tool" => "Tool",
                r => r,
            };
            // Timestamps are epoch milliseconds; convert to local time so they
            // agree with the "Exported At" line.
            let timestamp = if msg.timestamp > 0 {
                DateTime::from_timestamp_millis(msg.timestamp)
                    .map(|dt| dt.with_timezone(&Local).format(TIME_FORMAT).to_string())
            } else {
                None
            };
            let _ = match &timestamp {
                Some(ts) => writeln!(md, "### [{:03}] {} {}\n", i + 1, role, ts),
                None => writeln!(md, "### [{:03}] {}\n", i + 1, role),
            };

            let mut body = String::new();
            if let Some(tool_calls) = &msg.tool_calls {
                body.push_str("[Tool Calls]\n");
                for tc in tool_calls {
                    let _ = writeln!(body, "- {}: {:?}", tc.name, tc.arguments);
                }
            }
            if let Some(tool_name) = &msg.tool_name {
                let _ = writeln!(body, "[Tool: {}]", tool_name);
            }
            if let Some(tool_call_id) = &msg.tool_call_id {
                let _ = writeln!(body, "[Tool Call ID: {}]", tool_call_id);
            }
            body.push_str(&msg.content);

            let fence = Self::code_fence_for(&body);
            let _ = write!(md, "{}\n{}\n{}\n\n", fence, body, fence);

            if !msg.media_refs.is_empty() {
                let _ = write!(md, "**Media**: {:?}\n\n", msg.media_refs);
            }
        }
        md
    }

    /// Returns a backtick fence, at least three long and strictly longer than
    /// the longest run of backticks in `text`. `text` can then be wrapped in it
    /// as a fenced code block without the block ending early.
    fn code_fence_for(text: &str) -> String {
        let mut longest = 0usize;
        let mut current = 0usize;
        for ch in text.chars() {
            if ch == '`' {
                current += 1;
                longest = longest.max(current);
            } else {
                current = 0;
            }
        }
        "`".repeat((longest + 1).max(3))
    }
}
/// Manages every [`Session`].
///
/// Sessions are keyed by the string form of their [`UnifiedSessionId`], which
/// is built from the channel, chat id and dialog id.
///
/// Clones are cheap and share state. The session map sits behind
/// `Arc<Mutex<..>>` and the tool registry and skills loader are `Arc`s. Only
/// `provider_config` is cloned by value.
#[derive(Clone)]
pub struct SessionManager {
    /// Session map and per-session bookkeeping, shared across clones.
    inner: Arc<Mutex<SessionManagerInner>>,
    /// Provider configuration given to every session this manager creates.
    provider_config: LLMProviderConfig,
    /// Tool registry shared by all sessions.
    tools: Arc<ToolRegistry>,
    /// Source of the skills section appended to each request's system prompt.
    skills_loader: Arc<SkillsLoader>,
}
/// Mutable state of a [`SessionManager`]; only ever reached through the
/// manager's `Arc<Mutex<..>>`.
struct SessionManagerInner {
    /// Sessions keyed by UnifiedSessionId.to_string()
    sessions: HashMap<String, Arc<Mutex<Session>>>,
    /// Last-access instant per session key (same keys as `sessions`). Set when a
    /// session is created and refreshed on every `get_or_create_session` hit.
    session_timestamps: HashMap<String, Instant>,
    /// Session time-to-live, derived from `session_ttl_hours` in `SessionManager::new`.
    session_ttl: Duration,
}
/// Builds the tool registry that every session of a `SessionManager` shares.
///
/// Registration order is fixed: calculator, file read / write / edit, bash,
/// HTTP request, web fetch, and finally `get_skill` backed by `skills_loader`.
fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
    let mut tools = ToolRegistry::new();

    // Local computation, filesystem and shell access.
    tools.register(CalculatorTool::new());
    tools.register(FileReadTool::new());
    tools.register(FileWriteTool::new());
    tools.register(FileEditTool::new());
    tools.register(BashTool::new());

    // Network access. NOTE(review): the "*" entry presumably allows requests to
    // any host; confirm this is intended. The numeric arguments are passed
    // through unchanged; see `HttpRequestTool::new` / `WebFetchTool::new` for
    // their meaning.
    let allow_list = vec!["*".to_string()];
    tools.register(HttpRequestTool::new(allow_list, 1_000_000, 30, false));
    tools.register(WebFetchTool::new(50_000, 30));

    // On-demand loading of full skill content.
    tools.register(GetSkillTool::new(skills_loader));

    tools
}
/// Definition of a slash command that can be typed into a conversation.
#[derive(Debug, Clone)]
pub struct SlashCommand {
    /// Command name; `SessionManager::execute_slash_command` dispatches on it.
    pub name: &'static str,
    /// Human-readable description of the command.
    pub description: &'static str,
    /// Aliases (trigger words) that invoke the command, e.g. `"/new"`.
    pub aliases: &'static [&'static str],
}

impl SlashCommand {
    /// Returns `true` if `content` invokes this command.
    ///
    /// After trimming surrounding whitespace, `content` matches when it is
    /// exactly one of [`Self::aliases`], or an alias followed by any whitespace
    /// (space, tab, newline, ...) and arguments. A longer word that merely
    /// starts with an alias (e.g. `/newer` for `/new`) does not match.
    pub fn matches(&self, content: &str) -> bool {
        let trimmed = content.trim();
        self.aliases
            .iter()
            .any(|&alias| match trimmed.strip_prefix(alias) {
                Some(rest) => rest.is_empty() || rest.starts_with(char::is_whitespace),
                None => false,
            })
    }
}

/// Slash commands supported by sessions.
pub static SLASH_COMMANDS: &[SlashCommand] = &[
    SlashCommand {
        name: "new",
        description: "Create a new conversation",
        aliases: &["/new"],
    },
    SlashCommand {
        name: "delete",
        description: "Delete current conversation and start a new one",
        aliases: &["/delete"],
    },
    SlashCommand {
        name: "compact",
        description: "Manually trigger context compression",
        aliases: &["/compact"],
    },
    SlashCommand {
        name: "info",
        description: "Print current session information",
        aliases: &["/info"],
    },
    SlashCommand {
        name: "dump",
        description: "Save current session as markdown document",
        aliases: &["/dump"],
    },
];
impl SessionManager {
    /// Creates a manager with no sessions.
    ///
    /// Loads skills through a new [`SkillsLoader`] and builds the tool registry
    /// (`create_default_tools`) shared by every session created here.
    /// `session_ttl_hours` is converted to a [`Duration`]; the conversion
    /// saturates instead of overflowing for out-of-range values.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
        let skills_loader = SkillsLoader::new();
        skills_loader.load_skills();
        let skills_loader = Arc::new(skills_loader);
        let tools = Arc::new(create_default_tools(skills_loader.clone()));
        Ok(Self {
            inner: Arc::new(Mutex::new(SessionManagerInner {
                sessions: HashMap::new(),
                session_timestamps: HashMap::new(),
                // Saturate: a plain multiply overflows (and panics in debug builds)
                // for very large hour counts.
                session_ttl: Duration::from_secs(session_ttl_hours.saturating_mul(3600)),
            })),
            provider_config,
            tools,
            skills_loader,
        })
    }

    /// Returns a handle to the tool registry shared by all sessions.
    pub fn tools(&self) -> Arc<ToolRegistry> {
        self.tools.clone()
    }

    /// Returns all available slash commands.
    pub fn get_slash_commands(&self) -> &[SlashCommand] {
        SLASH_COMMANDS
    }

    /// Executes the slash command whose [`SlashCommand::name`] equals `command`.
    ///
    /// Returns `(new_session_id, response_message)`. `new_session_id` is `Some`
    /// only for commands that create a session (`new`, `delete`).
    ///
    /// `compact`, `info` and `dump` act on `current_session_id` through
    /// `get_or_create_session`, so an id with no stored session gets a fresh,
    /// empty one.
    ///
    /// NOTE(review): `delete` only creates a new session; the previous session
    /// is not removed from the session map.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] for an unknown or unimplemented command.
    /// Errors from session creation and context compression are propagated.
    pub async fn execute_slash_command(
        &self,
        command: &str,
        args: Option<&str>,
        channel: &str,
        chat_id: &str,
        current_session_id: Option<&UnifiedSessionId>,
    ) -> Result<(Option<UnifiedSessionId>, String), AgentError> {
        let cmd = SLASH_COMMANDS
            .iter()
            .find(|c| c.name == command)
            .ok_or_else(|| AgentError::Other(format!("Unknown command: {}", command)))?;
        tracing::info!(cmd = %cmd.name, args = ?args, "Executing slash command");
        match cmd.name {
            "new" => {
                let title = args.map(|s| s.to_string());
                let (new_id, title) = self.create_session(channel, chat_id, title.as_deref()).await?;
                Ok((Some(new_id), format!("New conversation '{}' created.", title)))
            }
            "delete" => {
                let (new_id, _title) = self.create_session(channel, chat_id, None).await?;
                Ok((Some(new_id), "Conversation deleted. New conversation created.".to_string()))
            }
            "compact" => {
                if let Some(sid) = current_session_id {
                    let session = self.get_or_create_session(sid).await?;
                    let mut session_guard = session.lock().await;
                    let original_count = session_guard.get_history().len();
                    let history = session_guard.get_history().to_vec();
                    let compressed = session_guard.compressor
                        .compress_if_needed(history)
                        .await?;
                    let compressed_count = compressed.len();
                    session_guard.clear_history();
                    for msg in compressed {
                        session_guard.add_message(msg);
                    }
                    // Separate the two counts; "{}{}" rendered e.g. 10 -> 5 as "105".
                    Ok((None, format!(
                        "Context compressed: {} -> {} messages.",
                        original_count, compressed_count
                    )))
                } else {
                    Ok((None, "No active conversation to compress.".to_string()))
                }
            }
            "info" => {
                if let Some(sid) = current_session_id {
                    let session = self.get_or_create_session(sid).await?;
                    let session_guard = session.lock().await;
                    let message_count = session_guard.get_history().len();
                    let session_id_str = session_guard.session_id();
                    Ok((None, format!(
                        "Session ID: {}\nMessage count: {}",
                        session_id_str, message_count
                    )))
                } else {
                    Ok((None, "No active session.".to_string()))
                }
            }
            "dump" => {
                if let Some(sid) = current_session_id {
                    let session = self.get_or_create_session(sid).await?;
                    let session_guard = session.lock().await;
                    let md = session_guard.dump_as_markdown();
                    Ok((None, md))
                } else {
                    Ok((None, "No active session.".to_string()))
                }
            }
            _ => Err(AgentError::Other(format!("Command not implemented: {}", cmd.name))),
        }
    }

    /// Builds a fresh, empty session for `unified_id`. It uses this manager's
    /// provider config and shared tool registry.
    ///
    /// NOTE(review): the receiving half of the outbound channel is dropped when
    /// this function returns. `Session::send` on the resulting session
    /// therefore never delivers; its error is ignored. Confirm whether outbound
    /// messages are meant to flow through this channel.
    async fn new_session_handle(
        &self,
        unified_id: &UnifiedSessionId,
    ) -> Result<Arc<Mutex<Session>>, AgentError> {
        let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
        let session = Session::new(
            unified_id.clone(),
            self.provider_config.clone(),
            user_tx,
            self.tools.clone(),
        )
        .await?;
        Ok(Arc::new(Mutex::new(session)))
    }

    /// Creates and registers a new session with a random short dialog id.
    ///
    /// Returns the new id and the dialog title. The title is the trimmed
    /// `title` when non-blank, otherwise `"Dialog <dialog_id>"`. The title is
    /// only returned, not stored.
    pub async fn create_session(
        &self,
        channel: &str,
        chat_id: &str,
        title: Option<&str>,
    ) -> Result<(UnifiedSessionId, String), AgentError> {
        let dialog_id = short_id();
        let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id);
        let session_id_str = unified_id.to_string();
        let title = title
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .map(ToOwned::to_owned)
            .unwrap_or_else(|| format!("Dialog {}", &dialog_id));
        let arc = self.new_session_handle(&unified_id).await?;
        let inner = &mut *self.inner.lock().await;
        inner.sessions.insert(session_id_str.clone(), arc);
        inner.session_timestamps.insert(session_id_str, Instant::now());
        Ok((unified_id, title))
    }

    /// Returns the session for `unified_id`, creating and registering an empty
    /// one if none exists. Either way its last-access time is set to now.
    ///
    /// The manager lock is held while a missing session is built, so two
    /// concurrent calls for the same id cannot create two sessions.
    pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result<Arc<Mutex<Session>>, AgentError> {
        let session_id_str = unified_id.to_string();
        let inner = &mut *self.inner.lock().await;
        if let Some(session) = inner.sessions.get(&session_id_str) {
            inner.session_timestamps.insert(session_id_str, Instant::now());
            return Ok(session.clone());
        }
        let arc = self.new_session_handle(unified_id).await?;
        inner.sessions.insert(session_id_str.clone(), arc.clone());
        inner.session_timestamps.insert(session_id_str, Instant::now());
        Ok(arc)
    }

    /// Removes every session whose last recorded access is older than the
    /// configured TTL, and returns how many were removed.
    ///
    /// Nothing in this type calls this automatically; invoke it periodically to
    /// enforce the TTL. Callers that still hold an `Arc` to a removed session
    /// can keep using it, but it is no longer reachable through the manager.
    pub async fn prune_expired_sessions(&self) -> usize {
        let inner = &mut *self.inner.lock().await;
        let ttl = inner.session_ttl;
        let before = inner.sessions.len();
        inner
            .session_timestamps
            .retain(|_, last_access| last_access.elapsed() <= ttl);
        let live = &inner.session_timestamps;
        inner.sessions.retain(|key, _| live.contains_key(key));
        before - inner.sessions.len()
    }

    /// Alias for [`Self::create_session`].
    pub async fn create_dialog(
        &self,
        channel: &str,
        chat_id: &str,
        title: Option<&str>,
    ) -> Result<(UnifiedSessionId, String), AgentError> {
        self.create_session(channel, chat_id, title).await
    }

    /// Stub kept for API compatibility; always returns `Ok(None)`.
    pub async fn get_current_dialog(
        &self,
        _channel: &str,
        _chat_id: &str,
    ) -> Result<Option<UnifiedSessionId>, AgentError> {
        Ok(None)
    }

    /// Stub kept for API compatibility; always returns an error.
    pub async fn switch_dialog(
        &self,
        _channel: &str,
        _chat_id: &str,
        _dialog_id: &str,
    ) -> Result<UnifiedSessionId, AgentError> {
        Err(AgentError::Other("switch_dialog not applicable in new architecture".to_string()))
    }

    /// Stub kept for API compatibility; always returns an empty list.
    pub async fn list_dialogs(
        &self,
        _channel: &str,
        _chat_id: &str,
        _include_archived: bool,
    ) -> Result<(Vec<DialogInfo>, Option<String>), AgentError> {
        Ok((vec![], None))
    }

    /// Stub kept for API compatibility; always returns an error.
    pub fn rename_dialog(&self, _session_id: &UnifiedSessionId, _title: &str) -> Result<(), AgentError> {
        Err(AgentError::Other("rename_dialog not available".to_string()))
    }

    /// Stub kept for API compatibility; always returns an error.
    pub fn archive_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
        Err(AgentError::Other("archive_dialog not available".to_string()))
    }

    /// Stub kept for API compatibility; always returns an error.
    pub fn delete_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
        Err(AgentError::Other("delete_dialog not available".to_string()))
    }

    /// Stub kept for API compatibility; always returns an error.
    pub fn clear_dialog_history(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
        Err(AgentError::Other("clear_dialog_history not available".to_string()))
    }

    /// Handles one user message and returns the agent's final response text.
    ///
    /// Steps:
    /// 1. Resolve (or create) the session for `channel`/`chat_id`/`dialog_id`,
    ///    using `DEFAULT_DIALOG_ID` when `dialog_id` is `None`.
    /// 2. Append the user message to the session history.
    /// 3. Build a per-request copy of the history with the system prompt (plus
    ///    skills) at position 0, and compress it if needed.
    /// 4. Run an agent over the copy and append every emitted message to the
    ///    history.
    ///
    /// The session lock is held for the whole turn, so turns on one session are
    /// serialized.
    ///
    /// NOTE(review): if compression or the agent fails, the user message has
    /// already been added and stays in the history without a reply.
    pub async fn handle_message(
        &self,
        channel: &str,
        _sender_id: &str,
        chat_id: &str,
        dialog_id: Option<&str>,
        content: &str,
        media: Vec<crate::bus::MediaItem>,
    ) -> Result<String, AgentError> {
        let dialog_id = dialog_id.unwrap_or(DEFAULT_DIALOG_ID);
        let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id);
        let session = self.get_or_create_session(&unified_id).await?;
        let response: String = {
            let mut session_guard = session.lock().await;
            let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
            #[cfg(debug_assertions)]
            if !media_refs.is_empty() {
                tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
            }
            let user_message = session_guard.create_user_message(content, media_refs);
            session_guard.add_message(user_message);
            let mut history = session_guard.get_history().to_vec();
            // Build skills prompt
            let skills_prompt = self.skills_loader.build_skills_prompt();
            // Build combined system prompt and inject at position 0
            // This ensures AgentLoop.process() sees a system message and doesn't inject its own
            let system_prompt = session_guard.build_system_prompt(&skills_prompt);
            history.insert(0, ChatMessage::system(system_prompt));
            let history = session_guard.compressor
                .compress_if_needed(history)
                .await?;
            let agent = session_guard.create_agent()?;
            let result = agent.process(history).await?;
            for msg in result.emitted_messages {
                session_guard.add_message(msg);
            }
            result.final_response.content
        };
        #[cfg(debug_assertions)]
        tracing::debug!(
            channel = %channel,
            chat_id = %chat_id,
            response_len = %response.len(),
            "Agent response received"
        );
        Ok(response)
    }

    /// Clears the history of the session for `unified_id`, creating an empty
    /// session first if none exists.
    pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
        let session = self.get_or_create_session(unified_id).await?;
        let mut session_guard = session.lock().await;
        session_guard.clear_history();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::collections::HashSet;

    /// Minimal provider config for session-level tests.
    // Not used by the current tests; kept as a ready-made fixture for future
    // Session/SessionManager tests, so silence the dead-code lint explicitly.
    #[allow(dead_code)]
    fn test_provider_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            token_limit: 4096,
            workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"),
        }
    }

    #[test]
    fn slash_command_names_are_unique_and_aliases_start_with_slash() {
        let mut seen = HashSet::new();
        for cmd in SLASH_COMMANDS {
            assert!(seen.insert(cmd.name), "duplicate slash command name: {}", cmd.name);
            assert!(!cmd.aliases.is_empty(), "{} has no aliases", cmd.name);
            assert!(
                cmd.aliases.iter().all(|alias| alias.starts_with('/')),
                "{} has an alias without a leading '/'",
                cmd.name
            );
        }
    }

    #[test]
    fn slash_command_matches_alias_with_optional_arguments() {
        let new_cmd = SLASH_COMMANDS
            .iter()
            .find(|c| c.name == "new")
            .expect("`new` slash command should be registered");
        assert!(new_cmd.matches("/new"));
        assert!(new_cmd.matches("  /new  "));
        assert!(new_cmd.matches("/new My title"));
        assert!(!new_cmd.matches("/newer"));
        assert!(!new_cmd.matches("new"));
    }
}