547 lines
19 KiB
Rust
547 lines
19 KiB
Rust
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use std::time::{Duration, Instant};
|
||
use tokio::sync::{Mutex, mpsc};
|
||
use uuid::Uuid;
|
||
use crate::bus::ChatMessage;
|
||
use crate::config::LLMProviderConfig;
|
||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||
use crate::protocol::WsOutbound;
|
||
use crate::providers::{create_provider, LLMProvider};
|
||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||
use crate::tools::{
|
||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||
HttpRequestTool, ToolRegistry, WebFetchTool,
|
||
};
|
||
|
||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||
pub struct Session {
|
||
pub id: Uuid,
|
||
pub channel_name: String,
|
||
/// 按 chat_id 路由到不同会话历史,支持多用户多会话
|
||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||
provider_config: LLMProviderConfig,
|
||
provider: Arc<dyn LLMProvider>,
|
||
tools: Arc<ToolRegistry>,
|
||
compressor: ContextCompressor,
|
||
store: Arc<SessionStore>,
|
||
}
|
||
|
||
impl Session {
|
||
pub async fn new(
|
||
channel_name: String,
|
||
provider_config: LLMProviderConfig,
|
||
user_tx: mpsc::Sender<WsOutbound>,
|
||
tools: Arc<ToolRegistry>,
|
||
store: Arc<SessionStore>,
|
||
) -> Result<Self, AgentError> {
|
||
let provider_box = create_provider(provider_config.clone())
|
||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||
|
||
Ok(Self {
|
||
id: Uuid::new_v4(),
|
||
channel_name,
|
||
chat_histories: HashMap::new(),
|
||
user_tx,
|
||
provider_config: provider_config.clone(),
|
||
provider: provider.clone(),
|
||
tools,
|
||
compressor: ContextCompressor::new(provider.clone(), provider_config.token_limit),
|
||
store,
|
||
})
|
||
}
|
||
|
||
pub fn persistent_session_id(&self, chat_id: &str) -> String {
|
||
persistent_session_id(&self.channel_name, chat_id)
|
||
}
|
||
|
||
pub fn ensure_persistent_session(&self, chat_id: &str) -> Result<SessionRecord, AgentError> {
|
||
self.store
|
||
.ensure_channel_session(&self.channel_name, chat_id)
|
||
.map_err(|err| AgentError::Other(format!("session persistence error: {}", err)))
|
||
}
|
||
|
||
pub fn ensure_chat_loaded(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
if self.chat_histories.contains_key(chat_id) {
|
||
return Ok(());
|
||
}
|
||
|
||
let history = self
|
||
.store
|
||
.load_messages(&self.persistent_session_id(chat_id))
|
||
.map_err(|err| AgentError::Other(format!("session history load error: {}", err)))?;
|
||
self.chat_histories.insert(chat_id.to_string(), history);
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取或创建指定 chat_id 的会话历史
|
||
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||
self.chat_histories
|
||
.entry(chat_id.to_string())
|
||
.or_insert_with(Vec::new)
|
||
}
|
||
|
||
/// 获取指定 chat_id 的会话历史(不创建)
|
||
pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
|
||
self.chat_histories.get(chat_id)
|
||
}
|
||
|
||
/// 使用完整消息追加到历史
|
||
pub fn add_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||
let history = self.get_or_create_history(chat_id);
|
||
history.push(message);
|
||
}
|
||
|
||
pub fn remove_history(&mut self, chat_id: &str) {
|
||
self.chat_histories.remove(chat_id);
|
||
}
|
||
|
||
pub fn clear_chat_history(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||
let len = history.len();
|
||
history.clear();
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||
}
|
||
|
||
self.store
|
||
.clear_messages(&self.persistent_session_id(chat_id))
|
||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))
|
||
}
|
||
|
||
pub fn reset_chat_context(&mut self, chat_id: &str) -> Result<(), AgentError> {
|
||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||
let len = history.len();
|
||
history.clear();
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history reset in memory");
|
||
}
|
||
|
||
self.store
|
||
.reset_session(&self.persistent_session_id(chat_id))
|
||
.map_err(|err| AgentError::Other(format!("reset history persistence error: {}", err)))
|
||
}
|
||
|
||
/// 将消息写入内存与持久化层
|
||
pub fn append_persisted_message(&mut self, chat_id: &str, message: ChatMessage) -> Result<(), AgentError> {
|
||
let session_id = self.persistent_session_id(chat_id);
|
||
self.store
|
||
.append_message(&session_id, &message)
|
||
.map_err(|err| AgentError::Other(format!("append message persistence error: {}", err)))?;
|
||
self.add_message(chat_id, message);
|
||
Ok(())
|
||
}
|
||
|
||
pub fn append_persisted_messages<I>(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError>
|
||
where
|
||
I: IntoIterator<Item = ChatMessage>,
|
||
{
|
||
for message in messages {
|
||
self.append_persisted_message(chat_id, message)?;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||
if media_refs.is_empty() {
|
||
ChatMessage::user(content)
|
||
} else {
|
||
ChatMessage::user_with_media(content, media_refs)
|
||
}
|
||
}
|
||
|
||
/// 清除所有历史
|
||
pub fn clear_all_history(&mut self) -> Result<(), AgentError> {
|
||
let chat_ids: Vec<String> = self.chat_histories.keys().cloned().collect();
|
||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||
self.chat_histories.clear();
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||
|
||
for chat_id in chat_ids {
|
||
self.store
|
||
.clear_messages(&self.persistent_session_id(&chat_id))
|
||
.map_err(|err| AgentError::Other(format!("clear history persistence error: {}", err)))?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn send(&self, msg: WsOutbound) {
|
||
let _ = self.user_tx.send(msg).await;
|
||
}
|
||
|
||
/// 获取 provider_config 引用
|
||
pub fn provider_config(&self) -> &LLMProviderConfig {
|
||
&self.provider_config
|
||
}
|
||
|
||
/// 获取 compressor 引用
|
||
pub fn compressor(&self) -> &ContextCompressor {
|
||
&self.compressor
|
||
}
|
||
|
||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||
Ok(AgentLoop::with_provider_and_tools(
|
||
self.provider.clone(),
|
||
self.tools.clone(),
|
||
self.provider_config.max_tool_iterations,
|
||
))
|
||
}
|
||
}
|
||
|
||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||
#[derive(Clone)]
|
||
pub struct SessionManager {
|
||
inner: Arc<Mutex<SessionManagerInner>>,
|
||
provider_config: LLMProviderConfig,
|
||
tools: Arc<ToolRegistry>,
|
||
store: Arc<SessionStore>,
|
||
}
|
||
|
||
struct SessionManagerInner {
|
||
sessions: HashMap<String, Arc<Mutex<Session>>>,
|
||
session_timestamps: HashMap<String, Instant>,
|
||
session_ttl: Duration,
|
||
}
|
||
|
||
fn default_tools() -> ToolRegistry {
|
||
let mut registry = ToolRegistry::new();
|
||
registry.register(CalculatorTool::new());
|
||
registry.register(FileReadTool::new());
|
||
registry.register(FileWriteTool::new());
|
||
registry.register(FileEditTool::new());
|
||
registry.register(BashTool::new());
|
||
registry.register(HttpRequestTool::new(
|
||
vec!["*".to_string()], // 允许所有域名,实际使用时建议限制
|
||
1_000_000, // max_response_size
|
||
30, // timeout_secs
|
||
false, // allow_private_hosts
|
||
));
|
||
registry.register(WebFetchTool::new(50_000, 30)); // max_chars, timeout_secs
|
||
registry
|
||
}
|
||
|
||
/// Slash commands recognised inside a chat message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InChatCommand {
    /// `/new` or `/reset`: start a fresh conversation for the chat.
    FreshConversation,
}

/// Recognises an in-chat command.
///
/// After trimming surrounding whitespace, the whole message must equal a
/// command keyword (case-sensitive). A keyword followed by arguments, or one
/// embedded in other text, yields `None`.
fn parse_in_chat_command(content: &str) -> Option<InChatCommand> {
    let keyword = content.trim();
    if matches!(keyword, "/new" | "/reset") {
        Some(InChatCommand::FreshConversation)
    } else {
        None
    }
}
|
||
|
||
pub(crate) fn handle_in_chat_command(
|
||
session: &mut Session,
|
||
chat_id: &str,
|
||
content: &str,
|
||
) -> Result<Option<String>, AgentError> {
|
||
match parse_in_chat_command(content) {
|
||
Some(InChatCommand::FreshConversation) => {
|
||
session.reset_chat_context(chat_id)?;
|
||
Ok(Some("Started a fresh conversation.".to_string()))
|
||
}
|
||
None => Ok(None),
|
||
}
|
||
}
|
||
|
||
impl SessionManager {
|
||
pub fn new(session_ttl_hours: u64, provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||
let store = Arc::new(
|
||
SessionStore::new()
|
||
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
|
||
);
|
||
|
||
Ok(Self {
|
||
inner: Arc::new(Mutex::new(SessionManagerInner {
|
||
sessions: HashMap::new(),
|
||
session_timestamps: HashMap::new(),
|
||
session_ttl: Duration::from_secs(session_ttl_hours * 3600),
|
||
})),
|
||
provider_config,
|
||
tools: Arc::new(default_tools()),
|
||
store,
|
||
})
|
||
}
|
||
|
||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
||
self.tools.clone()
|
||
}
|
||
|
||
pub fn store(&self) -> Arc<SessionStore> {
|
||
self.store.clone()
|
||
}
|
||
|
||
pub fn create_cli_session(&self, title: Option<&str>) -> Result<SessionRecord, AgentError> {
|
||
self.store
|
||
.create_cli_session(title)
|
||
.map_err(|err| AgentError::Other(format!("create session error: {}", err)))
|
||
}
|
||
|
||
pub fn get_session_record(&self, session_id: &str) -> Result<Option<SessionRecord>, AgentError> {
|
||
self.store
|
||
.get_session(session_id)
|
||
.map_err(|err| AgentError::Other(format!("get session error: {}", err)))
|
||
}
|
||
|
||
pub fn list_cli_sessions(&self, include_archived: bool) -> Result<Vec<SessionRecord>, AgentError> {
|
||
self.store
|
||
.list_sessions("cli", include_archived)
|
||
.map_err(|err| AgentError::Other(format!("list sessions error: {}", err)))
|
||
}
|
||
|
||
pub fn rename_session(&self, session_id: &str, title: &str) -> Result<(), AgentError> {
|
||
self.store
|
||
.rename_session(session_id, title)
|
||
.map_err(|err| AgentError::Other(format!("rename session error: {}", err)))
|
||
}
|
||
|
||
pub fn archive_session(&self, session_id: &str) -> Result<(), AgentError> {
|
||
self.store
|
||
.archive_session(session_id)
|
||
.map_err(|err| AgentError::Other(format!("archive session error: {}", err)))
|
||
}
|
||
|
||
pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> {
|
||
self.store
|
||
.delete_session(session_id)
|
||
.map_err(|err| AgentError::Other(format!("delete session error: {}", err)))
|
||
}
|
||
|
||
pub fn clear_session_messages(&self, session_id: &str) -> Result<(), AgentError> {
|
||
self.store
|
||
.clear_messages(session_id)
|
||
.map_err(|err| AgentError::Other(format!("clear session error: {}", err)))
|
||
}
|
||
|
||
pub fn load_session_messages(&self, session_id: &str) -> Result<Vec<ChatMessage>, AgentError> {
|
||
self.store
|
||
.load_messages(session_id)
|
||
.map_err(|err| AgentError::Other(format!("load messages error: {}", err)))
|
||
}
|
||
|
||
/// 确保 session 存在且未超时,超时则重建
|
||
pub async fn ensure_session(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
let mut inner = self.inner.lock().await;
|
||
|
||
let should_recreate = if let Some(last_active) = inner.session_timestamps.get(channel_name) {
|
||
let elapsed = last_active.elapsed();
|
||
if elapsed > inner.session_ttl {
|
||
tracing::info!(channel = %channel_name, elapsed_hours = elapsed.as_secs() / 3600, "Session expired, recreating");
|
||
true
|
||
} else {
|
||
false
|
||
}
|
||
} else {
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(channel = %channel_name, "Creating new session");
|
||
true
|
||
};
|
||
|
||
if should_recreate {
|
||
// 移除旧 session
|
||
inner.sessions.remove(channel_name);
|
||
|
||
// 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS)
|
||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||
let session = Session::new(
|
||
channel_name.to_string(),
|
||
self.provider_config.clone(),
|
||
user_tx,
|
||
self.tools.clone(),
|
||
self.store.clone(),
|
||
)
|
||
.await?;
|
||
let arc = Arc::new(Mutex::new(session));
|
||
|
||
inner.sessions.insert(channel_name.to_string(), arc.clone());
|
||
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取 session(不检查超时)
|
||
pub async fn get(&self, channel_name: &str) -> Option<Arc<Mutex<Session>>> {
|
||
let inner = self.inner.lock().await;
|
||
inner.sessions.get(channel_name).cloned()
|
||
}
|
||
|
||
/// 更新最后活跃时间
|
||
pub async fn touch(&self, channel_name: &str) {
|
||
let mut inner = self.inner.lock().await;
|
||
inner.session_timestamps.insert(channel_name.to_string(), Instant::now());
|
||
}
|
||
|
||
/// 处理消息:路由到对应 session 的 agent
|
||
pub async fn handle_message(
|
||
&self,
|
||
channel_name: &str,
|
||
_sender_id: &str,
|
||
chat_id: &str,
|
||
content: &str,
|
||
media: Vec<crate::bus::MediaItem>,
|
||
) -> Result<String, AgentError> {
|
||
#[cfg(debug_assertions)]
|
||
{
|
||
tracing::debug!(
|
||
channel = %channel_name,
|
||
chat_id = %chat_id,
|
||
content_len = content.len(),
|
||
media_count = %media.len(),
|
||
"Routing message to agent"
|
||
);
|
||
for (i, m) in media.iter().enumerate() {
|
||
tracing::debug!(media_index = i, media_type = %m.media_type, path = %m.path, "Media in handle_message");
|
||
}
|
||
}
|
||
|
||
// 确保 session 存在(可能需要重建)
|
||
self.ensure_session(channel_name).await?;
|
||
|
||
// 更新活跃时间
|
||
self.touch(channel_name).await;
|
||
|
||
// 获取 session
|
||
let session = self
|
||
.get(channel_name)
|
||
.await
|
||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||
|
||
// 处理消息
|
||
let response = {
|
||
let mut session_guard = session.lock().await;
|
||
|
||
session_guard.ensure_persistent_session(chat_id)?;
|
||
session_guard.ensure_chat_loaded(chat_id)?;
|
||
|
||
if let Some(command_response) = handle_in_chat_command(&mut session_guard, chat_id, content)? {
|
||
return Ok(command_response);
|
||
}
|
||
|
||
// 添加用户消息到历史
|
||
let media_refs: Vec<String> = media.iter().map(|m| m.path.clone()).collect();
|
||
#[cfg(debug_assertions)]
|
||
if !media_refs.is_empty() {
|
||
tracing::debug!(media_count = %media.len(), media_refs = ?media_refs, "Adding user message with media");
|
||
}
|
||
let user_message = session_guard.create_user_message(content, media_refs);
|
||
session_guard.append_persisted_message(chat_id, user_message)?;
|
||
|
||
// 获取完整历史
|
||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||
|
||
// 压缩历史(如果需要)
|
||
let history = session_guard.compressor
|
||
.compress_if_needed(history)
|
||
.await?;
|
||
|
||
// 创建 agent 并处理
|
||
let agent = session_guard.create_agent()?;
|
||
let result = agent.process(history).await?;
|
||
|
||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||
|
||
result.final_response
|
||
};
|
||
|
||
#[cfg(debug_assertions)]
|
||
tracing::debug!(
|
||
channel = %channel_name,
|
||
chat_id = %chat_id,
|
||
response_len = response.content.len(),
|
||
"Agent response received"
|
||
);
|
||
|
||
Ok(response.content)
|
||
}
|
||
|
||
/// 清除指定 session 的所有历史
|
||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
||
if let Some(session) = self.get(channel_name).await {
|
||
let mut session_guard = session.lock().await;
|
||
session_guard.clear_all_history()?;
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::mpsc;

    /// Provider configuration for tests: a dummy local endpoint, temperature
    /// 0.0, and a single tool iteration.
    fn test_provider_config() -> LLMProviderConfig {
        LLMProviderConfig {
            provider_type: "openai".to_string(),
            name: "test".to_string(),
            base_url: "http://localhost".to_string(),
            api_key: "test-key".to_string(),
            extra_headers: HashMap::new(),
            model_id: "test-model".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(32),
            model_extra: HashMap::new(),
            max_tool_iterations: 1,
            token_limit: 4096,
        }
    }

    /// Only an exact (whitespace-trimmed) `/new` or `/reset` counts as a command.
    #[test]
    fn test_parse_in_chat_command_aliases() {
        let cases = [
            ("/new", Some(InChatCommand::FreshConversation)),
            (" /reset \n", Some(InChatCommand::FreshConversation)),
            ("/new planning", None),
            ("please /reset", None),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_in_chat_command(input), expected, "input: {:?}", input);
        }
    }

    /// `/reset` empties the in-memory history and the active persisted view,
    /// while the store still retains the earlier message via `load_all_messages`.
    #[tokio::test]
    async fn test_handle_in_chat_command_resets_active_history_only() {
        const CHAT: &str = "chat-1";

        let store = Arc::new(SessionStore::in_memory().unwrap());
        let (user_tx, _user_rx) = mpsc::channel(4);
        let mut session = Session::new(
            "feishu".to_string(),
            test_provider_config(),
            user_tx,
            Arc::new(default_tools()),
            Arc::clone(&store),
        )
        .await
        .unwrap();

        session.ensure_persistent_session(CHAT).unwrap();
        session.ensure_chat_loaded(CHAT).unwrap();
        session
            .append_persisted_message(CHAT, ChatMessage::user("hello"))
            .unwrap();

        let reply = handle_in_chat_command(&mut session, CHAT, "/reset").unwrap();
        assert_eq!(reply.as_deref(), Some("Started a fresh conversation."));
        assert!(session.get_history(CHAT).unwrap().is_empty());

        let persisted_id = session.persistent_session_id(CHAT);
        assert!(store.load_messages(&persisted_id).unwrap().is_empty());
        assert_eq!(store.load_all_messages(&persisted_id).unwrap().len(), 1);
    }
}
|