PicoBot/src/session/session.rs
xiaoxixi e235268133 fix(session): /new 后仍停留在旧对话的问题
问题原因:/new 创建新 session 后,客户端下次发消息仍带着旧的
dialog_id,导致服务端找到旧 session。

解决方案:在 SessionManager 中新增 current_sessions 跟踪
每个 channel:chat_id 的当前活跃 session:
- create_session / get_or_create_session 时更新 current_sessions
- switch_dialog / delete_dialog 时同步更新 current_sessions
- handle_message 无 dialog_id 时优先使用 current_sessions
2026-04-28 22:53:37 +08:00

1200 lines
44 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::ChatMessage;
use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc;
/// Result of handling a message - either an AI response or a command output.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log it, compare it in
/// tests and use `Result::unwrap_err` on `Result<HandleResult, _>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandleResult {
    /// AI response to be sent as AssistantResponse
    AgentResponse(String),
    /// Command output to be sent as CommandExecuted
    CommandOutput(String),
}
use crate::channels::slash_command::parse_slash_command;
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::agent::system_prompt::build_system_prompt;
use crate::agent::context_compressor::ContextCompressionConfig;
use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider};
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
use crate::session::events::DialogInfo;
use crate::skills::SkillsLoader;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool,
};
/// Returns a short (8-character) identifier cut from the front of a random v4 UUID.
///
/// Used as the `dialog_id` of newly created dialogs.
fn short_id() -> String {
    let mut id = Uuid::new_v4().to_string();
    id.truncate(8);
    id
}
/// A `Session` is one dialog (conversation thread).
///
/// Each `Session` corresponds to exactly one `UnifiedSessionId` and owns its own
/// in-memory message history.
pub struct Session {
    /// Identity of this dialog (channel, chat_id, dialog_id).
    pub id: UnifiedSessionId,
    /// Display title of the dialog.
    pub title: String,
    /// Creation time; `Session::new` sets it from `timestamp_millis()`.
    pub created_at: i64,
    /// Time of the last appended message (same unit as `created_at`).
    pub last_active_at: i64,
    /// Number of `"user"`-role messages appended so far.
    pub message_count: i64,
    /// Number of messages of any role appended so far.
    pub total_message_count: i64,
    /// In-memory conversation history (no system prompt; that is injected per request).
    messages: Vec<ChatMessage>,
    /// Sequence number that will be assigned to the next appended message.
    seq_counter: i64,
    /// Outbound channel towards the client.
    pub user_tx: mpsc::Sender<WsOutbound>,
    /// Provider settings used for the agent loop, title generation and compression.
    provider_config: LLMProviderConfig,
    /// LLM provider built from `provider_config`.
    provider: Arc<dyn LLMProvider>,
    /// Tools made available to the agent loop.
    tools: Arc<ToolRegistry>,
    /// Compresses history when it grows past the provider's token limit.
    compressor: ContextCompressor,
    /// Persistent store; `None` means the session lives only in memory.
    storage: Option<StdArc<Storage>>,
    /// Opaque routing string; empty means "none" and is persisted as `NULL`.
    routing_info: String,
}
impl Session {
pub async fn new(
id: UnifiedSessionId,
provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>,
storage: Option<StdArc<Storage>>,
routing_info: String,
title: String,
) -> Result<Self, AgentError> {
let provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig {
protect_first_n: 2,
..Default::default()
};
let now = chrono::Utc::now().timestamp_millis();
Ok(Self {
id: id.clone(),
title,
created_at: now,
last_active_at: now,
message_count: 0,
total_message_count: 0,
messages: Vec::new(),
seq_counter: 1,
user_tx,
provider_config: provider_config.clone(),
provider: provider.clone(),
tools,
compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config),
storage,
routing_info,
})
}
/// 从 Storage 恢复 Session
pub async fn from_storage(
id: UnifiedSessionId,
provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>,
storage: StdArc<Storage>,
) -> Result<Self, AgentError> {
let session_meta = storage.get_session(&id.to_string()).await
.map_err(|e| AgentError::Other(format!("failed to load session from storage: {}", e)))?;
let messages = storage.load_messages(&id.to_string(), 0).await
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
let provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig {
protect_first_n: 2,
..Default::default()
};
// Convert MessageMeta to ChatMessage
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
ChatMessage {
id: m.id,
role: m.role,
content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at,
tool_call_id: m.tool_call_id,
tool_name: m.tool_name,
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
}
}).collect();
let seq_counter = chat_messages.len() as i64 + 1;
let total_message_count = chat_messages.len() as i64;
Ok(Self {
id: id.clone(),
title: session_meta.title,
created_at: session_meta.created_at,
last_active_at: session_meta.last_active_at,
message_count: session_meta.message_count,
total_message_count,
messages: chat_messages,
seq_counter,
user_tx,
provider_config: provider_config.clone(),
provider: provider.clone(),
tools,
compressor: ContextCompressor::with_config(provider.clone(), provider_config.token_limit, compressor_config),
storage: Some(storage),
routing_info: session_meta.routing_info.unwrap_or_default(),
})
}
/// 获取 session ID
pub fn session_id(&self) -> String {
self.id.to_string()
}
/// 添加消息到历史仅内存Phase 3 会扩展为持久化)
pub fn add_message(&mut self, message: ChatMessage) {
let is_user = message.role == "user";
let now = chrono::Utc::now().timestamp_millis();
// Assign seq (in-memory only, persistence in Phase 3)
let _seq = self.seq_counter;
self.seq_counter += 1;
// Update in-memory state
self.messages.push(message);
self.total_message_count += 1;
if is_user {
self.message_count += 1;
}
self.last_active_at = now;
}
/// 添加消息到历史并持久化到 StoragePhase 3 使用)
/// 目前 storage 为 None此方法退化为 add_message
pub async fn add_message_and_persist(&mut self, message: ChatMessage) -> Result<(), StorageError> {
let is_user = message.role == "user";
let now = chrono::Utc::now().timestamp_millis();
// Assign seq
let seq = self.seq_counter;
self.seq_counter += 1;
// Persist to Storage (currently None, wired up in Phase 3)
if let Some(ref storage) = self.storage {
let msg_meta = crate::storage::message::MessageMeta {
id: message.id.clone(),
session_id: self.id.to_string(),
seq,
role: message.role.clone(),
content: message.content.clone(),
media_refs: if message.media_refs.is_empty() {
None
} else {
Some(serde_json::to_string(&message.media_refs).unwrap_or_default())
},
tool_call_id: message.tool_call_id.clone(),
tool_name: message.tool_name.clone(),
tool_calls: message.tool_calls.as_ref().map(|tc| serde_json::to_string(tc).unwrap_or_default()),
created_at: now,
};
storage.append_message_with_retry(&self.id.to_string(), &msg_meta).await?;
}
// Update in-memory state
self.messages.push(message);
self.total_message_count += 1;
if is_user {
self.message_count += 1;
}
self.last_active_at = now;
Ok(())
}
/// 获取消息历史
pub fn get_history(&self) -> &[ChatMessage] {
&self.messages
}
/// 清除历史消息
pub fn clear_history(&mut self) {
let len = self.messages.len();
self.messages.clear();
self.seq_counter = 1;
self.total_message_count = 0;
self.message_count = 0;
#[cfg(debug_assertions)]
tracing::debug!(session_id = %self.id, previous_len = len, "Chat history cleared");
}
/// 重置对话上下文
pub fn reset_context(&mut self) {
let len = self.messages.len();
self.messages.clear();
self.seq_counter = 1;
self.total_message_count = 0;
self.message_count = 0;
#[cfg(debug_assertions)]
tracing::debug!(session_id = %self.id, previous_len = len, "Chat context reset in memory");
}
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
if media_refs.is_empty() {
ChatMessage::user(content)
} else {
ChatMessage::user_with_media(content, media_refs)
}
}
pub async fn send(&self, msg: WsOutbound) {
let _ = self.user_tx.send(msg).await;
}
/// 发送系统通知(不记录进 session 历史)
pub async fn send_system_notification(&self, content: &str) {
let msg = WsOutbound::SystemNotification {
content: content.to_string(),
};
let _ = self.user_tx.send(msg).await;
}
/// 将 session 元数据写回 Storage
pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
if let Some(ref storage) = self.storage {
let meta = crate::storage::session::SessionMeta {
id: self.id.to_string(),
channel: self.id.channel.clone(),
chat_id: self.id.chat_id.clone(),
dialog_id: self.id.dialog_id.clone(),
title: self.title.clone(),
created_at: self.created_at,
last_active_at: self.last_active_at,
message_count: self.message_count,
routing_info: if self.routing_info.is_empty() {
None
} else {
Some(self.routing_info.clone())
},
deleted_at: None,
};
storage.upsert_session(&meta).await?;
}
Ok(())
}
/// 检查是否需要自动生成 title10 条用户消息后)
pub fn should_generate_title(&self) -> bool {
self.title == "新对话" && self.message_count >= 10
}
/// 生成标题(调用 LLM
pub async fn generate_title(&mut self) -> Result<(), AgentError> {
if !self.should_generate_title() {
return Ok(());
}
let prompt = format!(
r#"给定以下对话历史生成一个简短的会话标题5-15 个中文字符),概括这个对话的核心内容或用户的主要需求。只返回一个标题,不要解释。
历史:
{}"#,
self.messages.iter()
.filter(|m| m.role == "user" || m.role == "assistant")
.take(20)
.map(|m| format!("[{}]: {}", m.role, m.content))
.collect::<Vec<_>>()
.join("\n")
);
let title = self.call_llm_for_title(&prompt).await?;
if !title.is_empty() {
self.title = title.clone();
if let Err(e) = self.persist_session_meta().await {
tracing::warn!("failed to persist title: {}", e);
}
}
Ok(())
}
/// 调用 LLM 生成标题
async fn call_llm_for_title(&self, prompt: &str) -> Result<String, AgentError> {
use crate::providers::{ChatCompletionRequest, ChatCompletionResponse, Message};
let request = ChatCompletionRequest {
messages: vec![
Message::user(prompt.to_string())
],
temperature: Some(0.3),
max_tokens: Some(20),
tools: None,
};
let response: ChatCompletionResponse = self.provider.chat(request).await
.map_err(|e| AgentError::Other(format!("LLM call failed: {}", e)))?;
Ok(response.content.trim().to_string())
}
/// 获取 provider_config 引用
pub fn provider_config(&self) -> &LLMProviderConfig {
&self.provider_config
}
/// 获取 compressor 引用
pub fn compressor(&self) -> &ContextCompressor {
&self.compressor
}
/// 创建一个临时的 AgentLoop 实例来处理消息
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
Ok(AgentLoop::with_provider_and_tools(
self.provider.clone(),
self.tools.clone(),
self.provider_config.max_tool_iterations,
self.provider_config.model_id.clone(),
self.provider_config.workspace_dir.clone(),
))
}
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
let base_prompt = build_system_prompt(
&self.provider_config.workspace_dir,
&self.provider_config.model_id,
&self.tools,
);
if skills_prompt.trim().is_empty() {
base_prompt
} else {
format!("{}\n\n## Skills\n\n{}\n\nUse the `get_skill` tool to load a skill's full content when needed.", base_prompt, skills_prompt)
}
}
/// 将当前 session 导出为 markdown 文档并保存到文件
pub fn dump_to_file(&self, system_prompt: &str) -> std::io::Result<String> {
use chrono::{DateTime, Local};
use std::fs;
use std::io::Write;
let md = self.dump_as_markdown_with_system_prompt(system_prompt);
// Create dumps directory under workspace
let dumps_dir = self.provider_config.workspace_dir.join("dumps");
fs::create_dir_all(&dumps_dir)?;
// Generate filename based on session info
let timestamp = Local::now().format("%Y%m%d_%H%M%S");
let filename = format!("{}_{}_{}.md", self.id.channel, self.id.chat_id, timestamp);
let filepath = dumps_dir.join(&filename);
// Write to file
let mut file = fs::File::create(&filepath)?;
file.write_all(md.as_bytes())?;
Ok(filepath.to_string_lossy().to_string())
}
/// 将当前 session 导出为 markdown 文档(纯内存版本)
pub fn dump_as_markdown(&self) -> String {
use chrono::{DateTime, Local};
let now = Local::now().format("%Y-%m-%d %H:%M:%S");
let mut md = String::new();
md.push_str(&format!("# Session Dump\n\n"));
md.push_str(&format!("- **Session ID**: `{}`\n", self.id));
md.push_str(&format!("- **Channel**: `{}`\n", self.id.channel));
md.push_str(&format!("- **Chat ID**: `{}`\n", self.id.chat_id));
md.push_str(&format!("- **Dialog ID**: `{}`\n", self.id.dialog_id));
md.push_str(&format!("- **Message Count**: {}\n", self.messages.len()));
md.push_str(&format!("- **Model**: `{}`\n", self.provider_config.model_id));
md.push_str(&format!("- **Exported At**: {}\n", now));
md.push_str("\n---\n\n");
md.push_str("## Conversation History\n\n");
for (i, msg) in self.messages.iter().enumerate() {
let role = match msg.role.as_str() {
"system" => "System",
"user" => "User",
"assistant" => "Assistant",
"tool" => "Tool",
r => r,
};
let timestamp = if msg.timestamp > 0 {
DateTime::from_timestamp_millis(msg.timestamp)
.map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
.unwrap_or_default()
} else {
String::new()
};
md.push_str(&format!("### [{:03}] {} {}\n\n", i + 1, role, timestamp));
md.push_str("```\n");
if let Some(ref tool_calls) = msg.tool_calls {
md.push_str(&format!("[Tool Calls]\n"));
for tc in tool_calls {
md.push_str(&format!("- {}: {:?}\n", tc.name, tc.arguments));
}
}
if let Some(ref tool_name) = msg.tool_name {
md.push_str(&format!("[Tool: {}]\n", tool_name));
}
if let Some(ref tool_call_id) = msg.tool_call_id {
md.push_str(&format!("[Tool Call ID: {}]\n", tool_call_id));
}
md.push_str(&msg.content);
md.push_str("\n```\n\n");
if !msg.media_refs.is_empty() {
md.push_str(&format!("**Media**: {:?}\n\n", msg.media_refs));
}
}
md
}
/// 将当前 session 导出为 markdown 文档(包含系统提示词)
pub fn dump_as_markdown_with_system_prompt(&self, system_prompt: &str) -> String {
use chrono::{DateTime, Local};
let now = Local::now().format("%Y-%m-%d %H:%M:%S");
let mut md = String::new();
md.push_str("# Session Dump\n\n");
md.push_str(&format!("- **Session ID**: `{}`\n", self.id));
md.push_str(&format!("- **Channel**: `{}`\n", self.id.channel));
md.push_str(&format!("- **Chat ID**: `{}`\n", self.id.chat_id));
md.push_str(&format!("- **Dialog ID**: `{}`\n", self.id.dialog_id));
md.push_str(&format!("- **Message Count**: {}\n", self.messages.len()));
md.push_str(&format!("- **Model**: `{}`\n", self.provider_config.model_id));
md.push_str(&format!("- **Exported At**: {}\n", now));
md.push_str("\n---\n\n");
// System Prompt Section
md.push_str("## System Prompt (Injected to Model)\n\n");
md.push_str("```\n");
md.push_str(system_prompt);
md.push_str("\n```\n\n");
md.push_str("---\n\n");
md.push_str("## Conversation History\n\n");
for (i, msg) in self.messages.iter().enumerate() {
let role = match msg.role.as_str() {
"system" => "System",
"user" => "User",
"assistant" => "Assistant",
"tool" => "Tool",
r => r,
};
let timestamp = if msg.timestamp > 0 {
DateTime::from_timestamp_millis(msg.timestamp)
.map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
.unwrap_or_default()
} else {
String::new()
};
md.push_str(&format!("### [{:03}] {} {}\n\n", i + 1, role, timestamp));
md.push_str("```\n");
if let Some(ref tool_calls) = msg.tool_calls {
md.push_str("[Tool Calls]\n");
for tc in tool_calls {
md.push_str(&format!("- {}: {:?}\n", tc.name, tc.arguments));
}
}
if let Some(ref tool_name) = msg.tool_name {
md.push_str(&format!("[Tool: {}]\n", tool_name));
}
if let Some(ref tool_call_id) = msg.tool_call_id {
md.push_str(&format!("[Tool Call ID: {}]\n", tool_call_id));
}
md.push_str(&msg.content);
md.push_str("\n```\n\n");
if !msg.media_refs.is_empty() {
md.push_str(&format!("**Media**: {:?}\n\n", msg.media_refs));
}
}
md
}
}
/// `SessionManager` owns all sessions and routes incoming messages to them by
/// channel / chat.
///
/// Cloning is cheap for the shared parts: `inner`, `tools`, `skills_loader` and
/// `storage` are `Arc`s shared between clones, while `provider_config` is deep-cloned.
#[derive(Clone)]
pub struct SessionManager {
    /// Mutable session cache and bookkeeping, behind a tokio mutex.
    inner: Arc<Mutex<SessionManagerInner>>,
    /// Provider settings cloned into every session this manager creates.
    provider_config: LLMProviderConfig,
    /// Tool registry shared by all sessions.
    tools: Arc<ToolRegistry>,
    /// Loaded skills; used to build the skills section of the system prompt.
    skills_loader: Arc<SkillsLoader>,
    /// Persistent backing store for session metadata and messages.
    storage: Arc<Storage>,
}
/// Mutable state of a `SessionManager`, guarded by a single async (tokio) mutex.
struct SessionManagerInner {
    /// Sessions keyed by UnifiedSessionId.to_string()
    sessions: HashMap<String, Arc<Mutex<Session>>>,
    /// Last time each cached session (same keys as `sessions`) was touched; the cleanup
    /// task evicts entries older than `session_ttl`.
    session_timestamps: HashMap<String, Instant>,
    /// Idle time after which a session is dropped from memory.
    session_ttl: Duration,
    /// Current active session per chat scope, keyed by `"{channel}:{chat_id}"`.
    current_sessions: HashMap<String, String>,
}
/// Builds the tool registry shared by every session.
///
/// Tools are registered in a fixed order: calculator, file read/write/edit, bash,
/// HTTP request, web fetch and finally `get_skill`, which reads from `skills_loader`.
fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
    let mut tools = ToolRegistry::new();

    // Local computation and filesystem access.
    tools.register(CalculatorTool::new());
    tools.register(FileReadTool::new());
    tools.register(FileWriteTool::new());
    tools.register(FileEditTool::new());
    tools.register(BashTool::new());

    // Network access. NOTE(review): the arguments presumably mean allow-list ("*" = any
    // host), max body bytes, timeout seconds and a boolean flag — confirm in
    // HttpRequestTool / WebFetchTool.
    let allowed_hosts = vec!["*".to_string()];
    tools.register(HttpRequestTool::new(allowed_hosts, 1_000_000, 30, false));
    tools.register(WebFetchTool::new(50_000, 30));

    // Skill lookup backed by the shared loader.
    tools.register(GetSkillTool::new(skills_loader));

    tools
}
/// Definition of a slash command.
#[derive(Debug, Clone)]
pub struct SlashCommand {
    /// Command name, used to dispatch in `SessionManager::execute_slash_command`.
    pub name: &'static str,
    /// Human-readable description of the command.
    pub description: &'static str,
    /// Trigger words (aliases) that invoke the command.
    pub aliases: &'static [&'static str],
}

impl SlashCommand {
    /// Returns true when `content` (ignoring surrounding whitespace) is exactly one of
    /// the aliases, or an alias followed by a space and arguments.
    pub fn matches(&self, content: &str) -> bool {
        let input = content.trim();
        self.aliases.iter().any(|&alias| {
            matches!(
                input.strip_prefix(alias),
                Some(rest) if rest.is_empty() || rest.starts_with(' ')
            )
        })
    }
}
/// Slash commands supported by a session.
///
/// Order matters: lookups use the first entry whose `name` matches, and listings show
/// commands in this order. Every `name` here needs a matching arm in
/// `SessionManager::execute_slash_command`, otherwise it reports "Command not implemented".
pub static SLASH_COMMANDS: &[SlashCommand] = &[
    SlashCommand {
        name: "new",
        description: "创建新对话",
        aliases: &["/new"],
    },
    SlashCommand {
        name: "sessions",
        description: "列出最近对话",
        aliases: &["/sessions"],
    },
    SlashCommand {
        name: "switch",
        description: "切换到指定对话",
        aliases: &["/switch"],
    },
    SlashCommand {
        name: "rename",
        description: "重命名当前对话",
        aliases: &["/rename"],
    },
    SlashCommand {
        name: "delete",
        description: "删除当前对话",
        aliases: &["/delete"],
    },
    SlashCommand {
        name: "compact",
        description: "手动触发上下文压缩",
        aliases: &["/compact"],
    },
    SlashCommand {
        name: "info",
        description: "显示当前对话信息",
        aliases: &["/info"],
    },
    SlashCommand {
        name: "dump",
        description: "保存当前对话为 markdown 文档",
        aliases: &["/dump"],
    },
];
impl SessionManager {
pub fn new(
session_ttl_hours: u64,
provider_config: LLMProviderConfig,
storage: Arc<Storage>,
) -> Result<Self, AgentError> {
let skills_loader = SkillsLoader::new();
skills_loader.load_skills();
let skills_loader = Arc::new(skills_loader);
let tools = Arc::new(create_default_tools(skills_loader.clone()));
Ok(Self {
inner: Arc::new(Mutex::new(SessionManagerInner {
sessions: HashMap::new(),
session_timestamps: HashMap::new(),
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
current_sessions: HashMap::new(),
})),
provider_config,
tools,
skills_loader,
storage,
})
}
pub fn tools(&self) -> Arc<ToolRegistry> {
self.tools.clone()
}
/// 启动后台 TTL 清理任务
pub fn start_cleanup_task(self: Arc<Self>, interval_mins: u64) {
let cleanup_interval = Duration::from_secs(interval_mins * 60);
tokio::spawn(async move {
loop {
tokio::time::sleep(cleanup_interval).await;
self.run_cleanup().await;
}
});
}
/// 执行一次 TTL 清理:释放内存中过期的 sessionStorage 记录保留
async fn run_cleanup(&self) {
let inner = self.inner.lock().await;
let now = Instant::now();
let ttl = inner.session_ttl;
let expired: Vec<String> = inner
.session_timestamps
.iter()
.filter(|(_, last_touch)| now.duration_since(**last_touch) > ttl)
.map(|(id, _)| id.clone())
.collect();
drop(inner);
if !expired.is_empty() {
let mut inner = self.inner.lock().await;
for id in &expired {
inner.sessions.remove(id);
inner.session_timestamps.remove(id);
}
tracing::debug!(count = expired.len(), "Cleaned up expired sessions");
}
}
/// 获取所有可用的斜杠命令
pub fn get_slash_commands(&self) -> &[SlashCommand] {
SLASH_COMMANDS
}
/// 执行斜杠命令
/// 返回 (新session_id, 响应消息)
pub async fn execute_slash_command(
&self,
command: &str,
args: Option<&str>,
channel: &str,
chat_id: &str,
current_session_id: Option<&UnifiedSessionId>,
) -> Result<(Option<UnifiedSessionId>, String), AgentError> {
let cmd = SLASH_COMMANDS
.iter()
.find(|c| c.name == command)
.ok_or_else(|| AgentError::Other(format!("Unknown command: {}", command)))?;
tracing::info!(cmd = %cmd.name, args = ?args, "Executing slash command");
match cmd.name {
"new" => {
let title = args.map(|s| s.to_string());
let (new_id, title) = self.create_session(channel, chat_id, title.as_deref(), String::new()).await?;
Ok((Some(new_id), format!("新对话 '{}' 已创建。", title)))
}
"delete" => {
let (new_id, _title) = self.create_session(channel, chat_id, None, String::new()).await?;
Ok((Some(new_id), "对话已删除。新对话已创建。".to_string()))
}
"compact" => {
if let Some(sid) = current_session_id {
let session = self.get_or_create_session(sid).await?;
let mut session_guard = session.lock().await;
let original_count = session_guard.get_history().len();
let history = session_guard.get_history().to_vec();
let compressed = session_guard.compressor
.compress_if_needed(history)
.await?;
let compressed_count = compressed.len();
session_guard.clear_history();
for msg in compressed {
session_guard.add_message(msg);
}
Ok((None, format!(
"Context compressed: {}{} messages.",
original_count, compressed_count
)))
} else {
Ok((None, "No active conversation to compress.".to_string()))
}
}
"info" => {
if let Some(sid) = current_session_id {
let session = self.get_or_create_session(sid).await?;
let session_guard = session.lock().await;
let message_count = session_guard.get_history().len();
let session_id_str = session_guard.session_id();
Ok((None, format!(
"Session ID: {}\nMessage count: {}",
session_id_str, message_count
)))
} else {
Ok((None, "No active session.".to_string()))
}
}
"dump" => {
if let Some(sid) = current_session_id {
let session = self.get_or_create_session(sid).await?;
let session_guard = session.lock().await;
// Build the same system prompt that would be injected to the model
let skills_prompt = self.skills_loader.build_skills_prompt();
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
let filepath = session_guard.dump_to_file(&system_prompt)
.map_err(|e| AgentError::Other(format!("Failed to save dump: {}", e)))?;
Ok((None, format!("Session dump saved to: {}", filepath)))
} else {
Ok((None, "No active session.".to_string()))
}
}
"sessions" => {
let (dialogs, _current) = self.list_dialogs(channel, chat_id, false).await?;
if dialogs.is_empty() {
Ok((None, "暂无对话记录。".to_string()))
} else {
let lines: Vec<String> = dialogs.iter().map(|d| {
let current = if current_session_id.map(|s| s.dialog_id == d.session_id.dialog_id).unwrap_or(false) {
" [当前]"
} else {
""
};
format!("- {} ({}){}{}", d.session_id.dialog_id, d.title, current, chrono::DateTime::from_timestamp_millis(d.last_active_at).map(|dt| dt.format("%m-%d %H:%M").to_string()).unwrap_or_default())
}).collect();
Ok((None, format!("最近对话:\n{}", lines.join("\n"))))
}
}
"switch" => {
let dialog_id = args.ok_or_else(|| AgentError::Other("Usage: /switch <dialog_id>".to_string()))?;
let new_id = self.switch_dialog(channel, chat_id, dialog_id).await?;
Ok((None, format!("已切换到对话:{}", new_id.dialog_id)))
}
"rename" => {
let title = args.ok_or_else(|| AgentError::Other("Usage: /rename <新标题>".to_string()))?;
if let Some(sid) = current_session_id {
self.rename_dialog(sid, title).await?;
Ok((None, format!("对话已重命名为:{}", title)))
} else {
Ok((None, "No active session.".to_string()))
}
}
_ => Err(AgentError::Other(format!("Command not implemented: {}", cmd.name))),
}
}
pub async fn create_session(
&self,
channel: &str,
chat_id: &str,
title: Option<&str>,
routing_info: String,
) -> Result<(UnifiedSessionId, String), AgentError> {
let dialog_id = short_id();
let unified_id = UnifiedSessionId::new(channel, chat_id, &dialog_id);
let session_id_str = unified_id.to_string();
let title = title
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToOwned::to_owned)
.unwrap_or_else(|| "新对话".to_string());
// Write to Storage first
let now = chrono::Utc::now().timestamp_millis();
let meta = crate::storage::session::SessionMeta {
id: session_id_str.clone(),
channel: channel.to_string(),
chat_id: chat_id.to_string(),
dialog_id: dialog_id.clone(),
title: title.clone(),
created_at: now,
last_active_at: now,
message_count: 0,
routing_info: if routing_info.is_empty() { None } else { Some(routing_info.clone()) },
deleted_at: None,
};
self.storage.upsert_session(&meta).await
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(
unified_id.clone(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
Some(self.storage.clone()),
routing_info,
title.clone(),
).await?;
let arc = Arc::new(Mutex::new(session));
let inner = &mut *self.inner.lock().await;
inner.sessions.insert(session_id_str.clone(), arc.clone());
inner.session_timestamps.insert(session_id_str.clone(), Instant::now());
// Set as current session for this channel:chat_id
let chat_scope = format!("{}:{}", channel, chat_id);
inner.current_sessions.insert(chat_scope, session_id_str);
Ok((unified_id, title))
}
pub async fn get_or_create_session(&self, unified_id: &UnifiedSessionId) -> Result<Arc<Mutex<Session>>, AgentError> {
let session_id_str = unified_id.to_string();
let inner = &mut *self.inner.lock().await;
if let Some(session) = inner.sessions.get(&session_id_str) {
inner.session_timestamps.insert(session_id_str, Instant::now());
return Ok(session.clone());
}
// Try to restore from Storage
match self.storage.get_session(&session_id_str).await {
Ok(meta) => {
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::from_storage(
unified_id.clone(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
self.storage.clone(),
).await?;
let arc = Arc::new(Mutex::new(session));
inner.sessions.insert(session_id_str.clone(), arc.clone());
inner.session_timestamps.insert(session_id_str.clone(), Instant::now());
// Set as current session
let chat_scope = format!("{}:{}", unified_id.channel, unified_id.chat_id);
inner.current_sessions.insert(chat_scope, session_id_str);
return Ok(arc);
}
Err(_) => {
// Session not in Storage, create new
}
}
// Create new session
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new(
unified_id.clone(),
self.provider_config.clone(),
user_tx,
self.tools.clone(),
Some(self.storage.clone()),
String::new(),
format!("新对话"),
).await?;
let arc = Arc::new(Mutex::new(session));
inner.sessions.insert(session_id_str.clone(), arc.clone());
inner.session_timestamps.insert(session_id_str.clone(), Instant::now());
// Set as current session
let chat_scope = format!("{}:{}", unified_id.channel, unified_id.chat_id);
inner.current_sessions.insert(chat_scope, session_id_str);
Ok(arc)
}
pub async fn create_dialog(
&self,
channel: &str,
chat_id: &str,
title: Option<&str>,
) -> Result<(UnifiedSessionId, String), AgentError> {
self.create_session(channel, chat_id, title, String::new()).await
}
pub async fn get_current_dialog(
&self,
_channel: &str,
_chat_id: &str,
) -> Result<Option<UnifiedSessionId>, AgentError> {
Ok(None)
}
pub async fn switch_dialog(
&self,
channel: &str,
chat_id: &str,
dialog_id: &str,
) -> Result<UnifiedSessionId, AgentError> {
let unified_id = UnifiedSessionId::new(channel, chat_id, dialog_id);
// Ensure session is loaded into memory
self.get_or_create_session(&unified_id).await?;
// Update current session tracking
let mut inner = self.inner.lock().await;
let chat_scope = format!("{}:{}", channel, chat_id);
inner.current_sessions.insert(chat_scope, unified_id.to_string());
Ok(unified_id)
}
pub async fn list_dialogs(
&self,
channel: &str,
chat_id: &str,
_include_archived: bool,
) -> Result<(Vec<DialogInfo>, Option<String>), AgentError> {
let metas = self.storage.list_sessions(channel, chat_id, 10).await
.map_err(|e| AgentError::Other(format!("failed to list dialogs: {}", e)))?;
let dialogs: Vec<DialogInfo> = metas.into_iter().map(|meta| {
DialogInfo {
session_id: UnifiedSessionId::new(channel, chat_id, &meta.dialog_id),
title: meta.title,
created_at: meta.created_at,
last_active_at: meta.last_active_at,
message_count: meta.message_count,
archived_at: None,
}
}).collect();
Ok((dialogs, None))
}
pub async fn rename_dialog(&self, session_id: &UnifiedSessionId, title: &str) -> Result<(), AgentError> {
// Update in-memory session
let session = self.get_or_create_session(session_id).await?;
let mut session_guard = session.lock().await;
session_guard.title = title.to_string();
session_guard.persist_session_meta().await
.map_err(|e| AgentError::Other(format!("failed to rename dialog: {}", e)))?;
Ok(())
}
pub async fn delete_dialog(&self, session_id: &UnifiedSessionId) -> Result<(), AgentError> {
let session_id_str = session_id.to_string();
// Soft delete from Storage
self.storage.soft_delete_session(&session_id_str).await
.map_err(|e| AgentError::Other(format!("failed to delete dialog: {}", e)))?;
// Remove from memory and current sessions
let mut inner = self.inner.lock().await;
inner.sessions.remove(&session_id_str);
inner.session_timestamps.remove(&session_id_str);
let chat_scope = format!("{}:{}", session_id.channel, session_id.chat_id);
inner.current_sessions.remove(&chat_scope);
Ok(())
}
pub fn archive_dialog(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
// Archive concept removed - just return OK
Ok(())
}
pub fn clear_dialog_history(&self, _session_id: &UnifiedSessionId) -> Result<(), AgentError> {
Err(AgentError::Other("clear_dialog_history not available".to_string()))
}
pub async fn handle_message(
&self,
channel: &str,
_sender_id: &str,
chat_id: &str,
dialog_id: Option<&str>,
content: &str,
media: Vec<crate::bus::MediaItem>,
) -> Result<HandleResult, AgentError> {
// Determine dialog_id: if not provided, use current session or find active or create new
let unified_id = if let Some(did) = dialog_id {
UnifiedSessionId::new(channel, chat_id, did)
} else {
// Check if we have a current session tracked for this channel:chat_id
let chat_scope = format!("{}:{}", channel, chat_id);
let current_session_id = {
let inner = self.inner.lock().await;
inner.current_sessions.get(&chat_scope).cloned()
};
if let Some(current_id) = current_session_id {
// Verify current session still exists in Storage
match self.storage.get_session(&current_id).await {
Ok(meta) => {
// Current session still valid
let parts: Vec<&str> = current_id.split(':').collect();
if parts.len() == 3 {
UnifiedSessionId::new(channel, chat_id, parts[2])
} else {
// Malformed, fallback to find or create
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
Ok(Some(m)) => UnifiedSessionId::new(channel, chat_id, &m.dialog_id),
_ => {
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
new_id
}
}
}
}
Err(_) => {
// Current session no longer exists, create new
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
new_id
}
}
} else {
// No current session tracked, find active or create new
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
Ok(Some(meta)) => {
UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)
}
Ok(None) | Err(_) => {
// Create new session
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
new_id
}
}
}
};
let session = self.get_or_create_session(&unified_id).await?;
// Check for slash command
if let Some((cmd_name, cmd_args)) = parse_slash_command(content) {
let (new_session_id, response) = self.execute_slash_command(
cmd_name,
if cmd_args.is_empty() { None } else { Some(cmd_args) },
channel,
chat_id,
Some(&unified_id),
).await?;
// If a new session was created (e.g., /new, /delete), update the session binding
if let Some(new_id) = new_session_id {
// Update the session in the map with the new ID
let mut inner = self.inner.lock().await;
if let Some(old_session) = inner.sessions.remove(&unified_id.to_string()) {
inner.sessions.insert(new_id.to_string(), old_session);
}
}
return Ok(HandleResult::CommandOutput(response));
}
// Normal message handling through LLM
let response: String = {
let mut session_guard = session.lock().await;
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
#[cfg(debug_assertions)]
if !media_refs.is_empty() {
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
}
let user_message = session_guard.create_user_message(content, media_refs);
session_guard.add_message(user_message);
let mut history = session_guard.get_history().to_vec();
// Build skills prompt
let skills_prompt = self.skills_loader.build_skills_prompt();
// Build combined system prompt and inject at position 0
// This ensures AgentLoop.process() sees a system message and doesn't inject its own
let system_prompt = session_guard.build_system_prompt(&skills_prompt);
history.insert(0, ChatMessage::system(system_prompt));
let history = session_guard.compressor
.compress_if_needed(history)
.await?;
let agent = session_guard.create_agent()?;
let result = agent.process(history).await?;
for msg in result.emitted_messages {
session_guard.add_message(msg);
}
result.final_response.content
};
#[cfg(debug_assertions)]
tracing::debug!(
channel = %channel,
chat_id = %chat_id,
response_len = %response.len(),
"Agent response received"
);
Ok(HandleResult::AgentResponse(response))
}
pub async fn clear_session_history(&self, unified_id: &UnifiedSessionId) -> Result<(), AgentError> {
let session = self.get_or_create_session(unified_id).await?;
let mut session_guard = session.lock().await;
session_guard.clear_history();
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal provider configuration (dummy key, localhost URL) for tests that need
    /// to construct a `Session`.
    #[allow(dead_code)] // Not referenced yet; kept for upcoming Session/SessionManager tests.
    fn test_provider_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            token_limit: 4096,
            workspace_dir: std::path::PathBuf::from("/tmp/test-workspace"),
        }
    }

    #[test]
    fn slash_command_matches_exact_and_with_args() {
        let new_cmd = SLASH_COMMANDS
            .iter()
            .find(|c| c.name == "new")
            .expect("`new` command is registered");
        assert!(new_cmd.matches("/new"));
        assert!(new_cmd.matches("  /new  "));
        assert!(new_cmd.matches("/new 我的标题"));
        assert!(!new_cmd.matches("/newer"));
        assert!(!new_cmd.matches("hello /new"));
    }

    #[test]
    fn slash_command_names_are_unique() {
        let mut names: Vec<&str> = SLASH_COMMANDS.iter().map(|c| c.name).collect();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), SLASH_COMMANDS.len());
    }

    #[test]
    fn every_slash_command_alias_starts_with_slash() {
        assert!(SLASH_COMMANDS
            .iter()
            .all(|c| !c.aliases.is_empty() && c.aliases.iter().all(|a| a.starts_with('/'))));
    }
}