PicoBot/src/gateway/session.rs

547 lines
19 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider};
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
HttpRequestTool, ToolRegistry, WebFetchTool,
};
/// Conversation state for a single channel: there is one `Session` per channel.
///
/// Message histories are kept separately for every `chat_id` and are owned and
/// managed by the session.
pub struct Session {
    /// Identifier assigned when the session is constructed.
    pub id: Uuid,
    /// Name of the channel this session belongs to.
    pub channel_name: String,
    /// In-memory histories keyed by `chat_id`, so that several users and
    /// conversations can coexist inside one channel.
    chat_histories: HashMap<String, Vec<ChatMessage>>,
    /// Sender used to push outbound messages for this session.
    pub user_tx: mpsc::Sender<WsOutbound>,
    /// Provider settings the session was created with.
    provider_config: LLMProviderConfig,
    /// LLM provider shared by the agent loop and the context compressor.
    provider: Arc<dyn LLMProvider>,
    /// Tool registry handed to every agent created by this session.
    tools: Arc<ToolRegistry>,
    /// Compressor applied to a history before it is sent to the agent.
    compressor: ContextCompressor,
    /// Persistent store backing the in-memory histories.
    store: Arc<SessionStore>,
}
impl Session {
    /// Builds a session for `channel_name`.
    ///
    /// The LLM provider is created from `provider_config` and shared between
    /// the session and its context compressor. The function body performs no
    /// awaits; it stays `async` so existing callers keep compiling.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the provider cannot be created.
    pub async fn new(
        channel_name: String,
        provider_config: LLMProviderConfig,
        user_tx: mpsc::Sender<WsOutbound>,
        tools: Arc<ToolRegistry>,
        store: Arc<SessionStore>,
    ) -> Result<Self, AgentError> {
        let provider_box = create_provider(provider_config.clone())
            .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
        let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
        // Build the compressor first so `provider` and `provider_config` can be
        // moved into the struct afterwards without extra clones.
        let compressor = ContextCompressor::new(provider.clone(), provider_config.token_limit);
        Ok(Self {
            id: Uuid::new_v4(),
            channel_name,
            chat_histories: HashMap::new(),
            user_tx,
            provider_config,
            provider,
            tools,
            compressor,
            store,
        })
    }

    /// Returns the persistent-store session id for `chat_id` within this channel.
    pub fn persistent_session_id(&self, chat_id: &str) -> String {
        persistent_session_id(&self.channel_name, chat_id)
    }

    /// Asks the store to ensure a session record exists for this channel and
    /// `chat_id`, returning that record.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the store call fails.
    pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
        self.store
            .ensure_channel_session(&self.channel_name, chat_id)
            .map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
    }

    /// Makes sure the history for `chat_id` is cached in memory.
    ///
    /// Does nothing when the history is already cached; otherwise loads the
    /// messages from the store and caches them.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when loading from the store fails.
    pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
        if self.chat_histories.contains_key(chat_id) {
            return Ok(());
        }
        let history = self
            .store
            .load_messages(&self.persistent_session_id(chat_id))
            .map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
        self.chat_histories.insert(chat_id.to_string(), history);
        Ok(())
    }

    /// Returns the in-memory history for `chat_id`, creating an empty one if
    /// none exists yet.
    pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
        self.chat_histories.entry(chat_id.to_string()).or_default()
    }

    /// Returns the in-memory history for `chat_id` without creating it.
    pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
        self.chat_histories.get(chat_id)
    }

    /// Appends a complete message to the in-memory history only (nothing is
    /// written to the store).
    pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
        self.get_or_create_history(chat_id).push(message);
    }

    /// Drops the in-memory history entry for `chat_id`; the store is untouched.
    pub fn remove_history(&mut self, chat_id: &str) {
        self.chat_histories.remove(chat_id);
    }

    /// Clears the messages of `chat_id` in the store and in memory.
    ///
    /// The store is updated first, mirroring `append_persisted_message`: if the
    /// store call fails the cached history is left as it was, so memory and
    /// store never disagree.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the store call fails.
    pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.store
            .clear_messages(&self.persistent_session_id(chat_id))
            .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?;
        if let Some(history) = self.chat_histories.get_mut(chat_id) {
            // Only needed for the debug log; gated to avoid an unused binding
            // in release builds.
            #[cfg(debug_assertions)]
            let len = history.len();
            history.clear();
            #[cfg(debug_assertions)]
            tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
        }
        Ok(())
    }

    /// Resets the stored session for `chat_id` via the store's `reset_session`
    /// and empties the in-memory history.
    ///
    /// The store is updated first so that a failure leaves the cached history
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the store call fails.
    pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
        self.store
            .reset_session(&self.persistent_session_id(chat_id))
            .map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))?;
        if let Some(history) = self.chat_histories.get_mut(chat_id) {
            #[cfg(debug_assertions)]
            let len = history.len();
            history.clear();
            #[cfg(debug_assertions)]
            tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
        }
        Ok(())
    }

    /// Writes `message` to the store and then appends it to the in-memory
    /// history.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the store call fails; in that case
    /// the in-memory history is not modified.
    pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> {
        let session_id = self.persistent_session_id(chat_id);
        self.store
            .append_message(&session_id, &message)
            .map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?;
        self.add_message(chat_id, message);
        Ok(())
    }

    /// Persists and appends every message in order, stopping at the first
    /// failure (messages before the failing one stay appended).
    ///
    /// # Errors
    ///
    /// Propagates the first error from `append_persisted_message`.
    pub fn append_persisted_messages<I>(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError>
    where
        I: IntoIterator<Item = ChatMessage>,
    {
        messages
            .into_iter()
            .try_for_each(|message| self.append_persisted_message(chat_id, message))
    }

    /// Builds a user message, attaching `media_refs` when there are any.
    pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
        if media_refs.is_empty() {
            ChatMessage::user(content)
        } else {
            ChatMessage::user_with_media(content, media_refs)
        }
    }

    /// Clears every cached history and the stored messages of each cached chat.
    ///
    /// The in-memory map is emptied up front, so a chat whose store call fails
    /// is simply reloaded from the store the next time it is needed.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] on the first store failure; chats after
    /// the failing one are not cleared in the store.
    pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
        let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
        // Only needed for the debug log; gated to avoid an unused binding in
        // release builds.
        #[cfg(debug_assertions)]
        let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
        self.chat_histories.clear();
        #[cfg(debug_assertions)]
        tracing::debug!(previous_total = total, "All chat histories cleared");
        for chat_id in chat_ids {
            self.store
                .clear_messages(&self.persistent_session_id(&chat_id))
                .map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?;
        }
        Ok(())
    }

    /// Sends `msg` through `user_tx` on a best-effort basis.
    ///
    /// A failed send is not reported to the caller, but it is logged at debug
    /// level instead of being silently discarded.
    pub async fn send(&self, msg: WsOutbound) {
        if let Err(err) = self.user_tx.send(msg).await {
            tracing::debug!(
                channel = %self.channel_name,
                error = %err,
                "Outbound message dropped"
            );
        }
    }

    /// Returns a reference to the provider configuration.
    pub fn provider_config(&self) -> &LLMProviderConfig {
        &self.provider_config
    }

    /// Returns a reference to the context compressor.
    pub fn compressor(&self) -> &ContextCompressor {
        &self.compressor
    }

    /// Creates a short-lived `AgentLoop` that shares this session's provider
    /// and tools and is limited to `max_tool_iterations` from the provider
    /// configuration.
    pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
        Ok(AgentLoop::with_provider_and_tools(
            self.provider.clone(),
            self.tools.clone(),
            self.provider_config.max_tool_iterations,
        ))
    }
}
/// Owns every [`Session`] and routes to them by `channel_name`.
///
/// Cloning is cheap: the session table, tool registry, and store all sit
/// behind `Arc`s, so clones share the same underlying state.
#[derive(Clone)]
pub struct SessionManager {
    /// Session table and bookkeeping, guarded by an async mutex.
    inner: Arc<Mutex<SessionManagerInner>>,
    /// Provider settings used for every session this manager creates.
    provider_config: LLMProviderConfig,
    /// Tool registry shared by all sessions.
    tools: Arc<ToolRegistry>,
    /// Persistent store shared by all sessions.
    store: Arc<SessionStore>,
}
/// Mutable state of a `SessionManager`, kept behind its mutex.
struct SessionManagerInner {
    /// Live sessions keyed by channel name.
    sessions: HashMap<String, Arc<Mutex<Session>>>,
    /// Last-activity instant per channel name, compared against `session_ttl`.
    session_timestamps: HashMap<String, Instant>,
    /// Idle time after which a session is rebuilt.
    session_ttl: Duration,
}
/// Builds the tool registry used by every session: calculator, file
/// read/write/edit, bash, HTTP request, and web fetch.
fn default_tools() -> ToolRegistry {
    // Allows every domain; restricting this list is recommended for real use.
    let allowed_domains = vec![String::from("*")];

    let mut tools = ToolRegistry::new();
    tools.register(CalculatorTool::new());
    tools.register(FileReadTool::new());
    tools.register(FileWriteTool::new());
    tools.register(FileEditTool::new());
    tools.register(BashTool::new());
    tools.register(HttpRequestTool::new(
        allowed_domains,
        1_000_000, // max_response_size
        30,        // timeout_secs
        false,     // allow_private_hosts
    ));
    // Arguments: max_chars, timeout_secs.
    tools.register(WebFetchTool::new(50_000, 30));
    tools
}
/// Commands a user can type directly into a chat.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InChatCommand {
    /// Start over with an empty conversation context.
    FreshConversation,
}

/// Recognises an in-chat command.
///
/// The message must consist of nothing but the command once surrounding
/// whitespace is trimmed; `/new` and `/reset` are aliases. Any other text,
/// including a command followed by extra words, yields `None`.
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
    let command = content.trim();
    if matches!(command, "/new" | "/reset") {
        Some(InChatCommand::FreshConversation)
    } else {
        None
    }
}
/// Executes `content` as an in-chat command when it is one.
///
/// Returns `Ok(Some(reply))` with the text to send back when a command was
/// handled, and `Ok(None)` when `content` is an ordinary message that should
/// be processed normally.
///
/// # Errors
///
/// Propagates the error from `Session::reset_chat_context`.
pub(crate) fn handle_in_chat_command(
    session: &mut Session,
    chat_id: &str,
    content: &str,
) -> Result<Option<String>, AgentError> {
    let command = match parse_in_chat_command(content) {
        Some(command) => command,
        None => return Ok(None),
    };

    match command {
        InChatCommand::FreshConversation => {
            session.reset_chat_context(chat_id)?;
            Ok(Some(String::from("Started a fresh conversation.")))
        }
    }
}
impl SessionManager {
    /// Creates a manager whose sessions are rebuilt after `session_ttl_hours`
    /// of inactivity.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when the session store cannot be
    /// initialised.
    pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
        let store = Arc::new(
            SessionStore::new()
                .map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
        );
        // Saturate instead of overflowing for very large hour counts (a plain
        // multiplication panics in debug builds and wraps in release builds).
        let session_ttl = Duration::from_secs(session_ttl_hours.saturating_mul(3600));
        Ok(Self {
            inner: Arc::new(Mutex::new(SessionManagerInner {
                sessions: HashMap::new(),
                session_timestamps: HashMap::new(),
                session_ttl,
            })),
            provider_config,
            tools: Arc::new(default_tools()),
            store,
        })
    }

    /// Returns a shared handle to the tool registry.
    pub fn tools(&self) -> Arc<ToolRegistry> {
        self.tools.clone()
    }

    /// Returns a shared handle to the session store.
    pub fn store(&self) -> Arc<SessionStore> {
        self.store.clone()
    }

    /// Creates a CLI session record in the store, optionally with a title.
    pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
        self.store
            .create_cli_session(title)
            .map_err(|err| AgentError::Other(format!("create session error: {}", err)))
    }

    /// Looks up a session record by id; `Ok(None)` is whatever the store
    /// reports for an unknown id.
    pub fn get_session_record(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
        self.store
            .get_session(session_id)
            .map_err(|err| AgentError::Other(format!("get session error: {}", err)))
    }

    /// Lists the store's sessions for the `"cli"` channel.
    pub fn list_cli_sessions(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
        self.store
            .list_sessions("cli", include_archived)
            .map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
    }

    /// Renames a stored session.
    pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
        self.store
            .rename_session(session_id, title)
            .map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
    }

    /// Archives a stored session.
    pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> {
        self.store
            .archive_session(session_id)
            .map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
    }

    /// Deletes a stored session.
    pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> {
        self.store
            .delete_session(session_id)
            .map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
    }

    /// Clears the stored messages of a session (in-memory caches held by live
    /// sessions are not touched here).
    pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
        self.store
            .clear_messages(session_id)
            .map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
    }

    /// Loads the stored messages of a session.
    pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
        self.store
            .load_messages(session_id)
            .map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
    }

    /// Makes sure a live session exists for `channel_name`, rebuilding it when
    /// it has been idle longer than the TTL or is missing.
    ///
    /// # Errors
    ///
    /// Returns the error from `Session::new` when the session cannot be built.
    pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
        let mut inner = self.inner.lock().await;
        let expired_or_unknown = match inner.session_timestamps.get(channel_name) {
            Some(last_active) => {
                let elapsed = last_active.elapsed();
                if elapsed > inner.session_ttl {
                    tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
                    true
                } else {
                    false
                }
            }
            None => {
                #[cfg(debug_assertions)]
                tracing::debug!(channel = %channel_name, "Creating new session");
                true
            }
        };
        // A timestamp can exist without a session: `touch` records one
        // unconditionally, and a failed rebuild removes the session but keeps
        // its timestamp. Check the session table itself so such a channel is
        // still created instead of failing later with "Session not found".
        let should_recreate = expired_or_unknown || !inner.sessions.contains_key(channel_name);
        if should_recreate {
            // Drop the old session, if any.
            inner.sessions.remove(channel_name);
            // Sessions built here get a throwaway `user_tx`, because Feishu
            // does not go through the WebSocket path.
            let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
            let session = Session::new(
                channel_name.to_string(),
                self.provider_config.clone(),
                user_tx,
                self.tools.clone(),
                self.store.clone(),
            )
            .await?;
            inner
                .sessions
                .insert(channel_name.to_string(), Arc::new(Mutex::new(session)));
            inner
                .session_timestamps
                .insert(channel_name.to_string(), Instant::now());
        }
        Ok(())
    }

    /// Returns the session for `channel_name` without checking its TTL.
    pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
        let inner = self.inner.lock().await;
        inner.sessions.get(channel_name).cloned()
    }

    /// Records "now" as the last-activity time of `channel_name`.
    pub async fn touch(&self, channel_name: &str) {
        let mut inner = self.inner.lock().await;
        inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
    }

    /// Routes an incoming message to the agent of the matching session and
    /// returns the text of the agent's final response.
    ///
    /// In-chat commands (see `handle_in_chat_command`) are answered directly
    /// without invoking the agent. The session lock is held for the whole
    /// exchange, so messages for one channel are processed one at a time.
    ///
    /// # Errors
    ///
    /// Propagates session creation, persistence, compression, and agent
    /// errors.
    pub async fn handle_message(
        &self,
        channel_name: &str,
        _sender_id: &str,
        chat_id: &str,
        content: &str,
        media: Vec<crate::bus::MediaItem>,
    ) -> Result<String, AgentError> {
        #[cfg(debug_assertions)]
        {
            tracing::debug!(
                channel = %channel_name,
                chat_id = %chat_id,
                content_len = content.len(),
                media_count = %media.len(),
                "Routing message to agent"
            );
            for (i, m) in media.iter().enumerate() {
                tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
            }
        }
        // Make sure the session exists (it may need to be rebuilt).
        self.ensure_session(channel_name).await?;
        // Refresh the last-activity time.
        self.touch(channel_name).await;
        // Fetch the session.
        let session = self
            .get(channel_name)
            .await
            .ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
        // Process the message.
        let response = {
            let mut session_guard = session.lock().await;
            session_guard.ensure_persistent_session(chat_id)?;
            session_guard.ensure_chat_loaded(chat_id)?;
            if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? {
                return Ok(command_response);
            }
            // Add the user message to the history.
            let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
            #[cfg(debug_assertions)]
            if !media_refs.is_empty() {
                tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
            }
            let user_message = session_guard.create_user_message(content, media_refs);
            session_guard.append_persisted_message(chat_id, user_message)?;
            // Take a copy of the full history.
            let history = session_guard.get_or_create_history(chat_id).clone();
            // Compress it when needed.
            let history = session_guard.compressor
                .compress_if_needed(history)
                .await?;
            // Create an agent and run it.
            let agent = session_guard.create_agent()?;
            let result = agent.process(history).await?;
            // Persist, in their real order, the assistant tool calls, the tool
            // results, and the final assistant reply.
            session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
            result.final_response
        };
        #[cfg(debug_assertions)]
        tracing::debug!(
            channel = %channel_name,
            chat_id = %chat_id,
            response_len = response.content.len(),
            "Agent response received"
        );
        Ok(response.content)
    }

    /// Clears every history of the session for `channel_name`; does nothing
    /// when no such session is live.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Session::clear_all_history`.
    pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
        if let Some(session) = self.get(channel_name).await {
            let mut session_guard = session.lock().await;
            session_guard.clear_all_history()?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::mpsc;

    /// Provider configuration with placeholder values, used to build a
    /// `Session` in tests.
    fn test_provider_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            token_limit: 4096,
        }
    }

    /// `/new` and `/reset` are recognised only when they are the whole
    /// (trimmed) message.
    #[test]
    fn test_parse_in_chat_command_aliases() {
        let fresh = Some(InChatCommand::FreshConversation);
        assert_eq!(parse_in_chat_command("/new"), fresh);
        assert_eq!(parse_in_chat_command(" /reset \n"), fresh);
        assert_eq!(parse_in_chat_command("/new planning"), None);
        assert_eq!(parse_in_chat_command("please /reset"), None);
    }

    /// After `/reset`, the cached history and `load_messages` are empty while
    /// `load_all_messages` still returns the earlier message.
    #[tokio::test]
    async fn test_handle_in_chat_command_resets_active_history_only() {
        let chat = "chat-1";
        let store = Arc::new(SessionStore::in_memory().unwrap());
        let (user_tx, _user_rx) = mpsc::channel(4);
        let registry = Arc::new(default_tools());
        let mut session = Session::new(
            "feishu".to_string(),
            test_provider_config(),
            user_tx,
            registry,
            store.clone(),
        )
        .await
        .unwrap();

        session.ensure_persistent_session(chat).unwrap();
        session.ensure_chat_loaded(chat).unwrap();
        session
            .append_persisted_message(chat, ChatMessage::user("hello"))
            .unwrap();

        let reply = handle_in_chat_command(&mut session, chat, "/reset")
            .unwrap()
            .unwrap();
        assert_eq!(reply, "Started a fresh conversation.");

        let persistent_id = session.persistent_session_id(chat);
        assert!(session.get_history(chat).unwrap().is_empty());
        assert!(store.load_messages(&persistent_id).unwrap().is_empty());
        assert_eq!(store.load_all_messages(&persistent_id).unwrap().len(), 1);
    }
}