Refactor AgentLoop to manage history externally via SessionManager
- Removed internal history management from AgentLoop. - Updated process method to accept conversation history as a parameter. - Adjusted continue_with_tool_results to work with external history. - Added OutboundDispatcher for handling outbound messages from MessageBus. - Introduced InboundMessage and OutboundMessage structs for message handling. - Updated Channel trait to include message handling and publishing to MessageBus. - Refactored Session to manage chat histories instead of AgentLoop instances. - Enhanced GatewayState to start message processing loops for inbound and outbound messages.
This commit is contained in:
parent
9834bd75cf
commit
a051f83050
@ -4,9 +4,9 @@ use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Mess
|
||||
use crate::tools::ToolRegistry;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Stateless AgentLoop - history is managed externally by SessionManager
|
||||
pub struct AgentLoop {
|
||||
provider: Box<dyn LLMProvider>,
|
||||
history: Vec<ChatMessage>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@ impl AgentLoop {
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
history: Vec::new(),
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
})
|
||||
}
|
||||
@ -28,7 +27,6 @@ impl AgentLoop {
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
history: Vec::new(),
|
||||
tools,
|
||||
})
|
||||
}
|
||||
@ -37,10 +35,10 @@ impl AgentLoop {
|
||||
&self.tools
|
||||
}
|
||||
|
||||
pub async fn process(&mut self, user_message: ChatMessage) -> Result<ChatMessage, AgentError> {
|
||||
self.history.push(user_message.clone());
|
||||
|
||||
let messages: Vec<Message> = self.history
|
||||
/// Process a message using the provided conversation history.
|
||||
/// History management is handled externally by SessionManager.
|
||||
pub async fn process(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
||||
let messages_for_llm: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
@ -50,7 +48,7 @@ impl AgentLoop {
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::debug!(history_len = self.history.len(), "Sending request to LLM");
|
||||
tracing::debug!(history_len = messages.len(), "Sending request to LLM");
|
||||
|
||||
let tools = if self.tools.has_tools() {
|
||||
Some(self.tools.get_definitions())
|
||||
@ -59,7 +57,7 @@ impl AgentLoop {
|
||||
};
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages,
|
||||
messages: messages_for_llm,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools,
|
||||
@ -71,12 +69,18 @@ impl AgentLoop {
|
||||
AgentError::LlmError(e.to_string())
|
||||
})?;
|
||||
|
||||
tracing::debug!(response_len = response.content.len(), tool_calls_len = response.tool_calls.len(), "LLM response received");
|
||||
tracing::debug!(
|
||||
response_len = response.content.len(),
|
||||
tool_calls_len = response.tool_calls.len(),
|
||||
"LLM response received"
|
||||
);
|
||||
|
||||
if !response.tool_calls.is_empty() {
|
||||
tracing::info!(count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
||||
|
||||
let mut updated_messages = messages.clone();
|
||||
let assistant_message = ChatMessage::assistant(response.content.clone());
|
||||
self.history.push(assistant_message.clone());
|
||||
updated_messages.push(assistant_message.clone());
|
||||
|
||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||
|
||||
@ -86,20 +90,18 @@ impl AgentLoop {
|
||||
tool_call.name.clone(),
|
||||
result.clone(),
|
||||
);
|
||||
self.history.push(tool_message);
|
||||
updated_messages.push(tool_message);
|
||||
}
|
||||
|
||||
return self.continue_with_tool_results(response.content).await;
|
||||
return self.continue_with_tool_results(updated_messages).await;
|
||||
}
|
||||
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
self.history.push(assistant_message.clone());
|
||||
|
||||
Ok(assistant_message)
|
||||
}
|
||||
|
||||
async fn continue_with_tool_results(&mut self, _original_content: String) -> Result<ChatMessage, AgentError> {
|
||||
let messages: Vec<Message> = self.history
|
||||
async fn continue_with_tool_results(&self, messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
||||
let messages_for_llm: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
@ -116,7 +118,7 @@ impl AgentLoop {
|
||||
};
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages,
|
||||
messages: messages_for_llm,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools,
|
||||
@ -129,8 +131,6 @@ impl AgentLoop {
|
||||
})?;
|
||||
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
self.history.push(assistant_message.clone());
|
||||
|
||||
Ok(assistant_message)
|
||||
}
|
||||
|
||||
@ -168,16 +168,6 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_history(&mut self) {
|
||||
let len = self.history.len();
|
||||
self.history.clear();
|
||||
tracing::debug!(previous_len = len, "Chat history cleared");
|
||||
}
|
||||
|
||||
pub fn history(&self) -> &[ChatMessage] {
|
||||
&self.history
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
82
src/bus/dispatcher.rs
Normal file
82
src/bus/dispatcher.rs
Normal file
@ -0,0 +1,82 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
|
||||
/// OutboundDispatcher consumes outbound messages from the MessageBus
|
||||
/// and dispatches them to the appropriate Channel
|
||||
pub struct OutboundDispatcher {
|
||||
bus: Arc<MessageBus>,
|
||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
||||
}
|
||||
|
||||
impl OutboundDispatcher {
|
||||
pub fn new(bus: Arc<MessageBus>) -> Self {
|
||||
Self {
|
||||
bus,
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a channel with the dispatcher
|
||||
pub async fn register_channel(&self, name: &str, channel: Arc<dyn Channel + Send + Sync>) {
|
||||
self.channels.write().await.insert(name.to_string(), channel);
|
||||
}
|
||||
|
||||
/// Run the dispatcher loop - consumes from bus and dispatches to channels
|
||||
pub async fn run(&self) {
|
||||
tracing::info!("OutboundDispatcher started");
|
||||
|
||||
loop {
|
||||
let msg = self.bus.consume_outbound().await;
|
||||
tracing::debug!(
|
||||
channel = %msg.channel,
|
||||
chat_id = %msg.chat_id,
|
||||
content_len = msg.content.len(),
|
||||
"OutboundDispatcher received message"
|
||||
);
|
||||
|
||||
let channel_name = msg.channel.clone();
|
||||
let channel = self.channels.read().await.get(&channel_name).cloned();
|
||||
|
||||
match channel {
|
||||
Some(ch) => {
|
||||
if let Err(e) = self.send_with_retry(&*ch, msg).await {
|
||||
tracing::error!(channel = %channel_name, error = %e, "Failed to send message after retries");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(channel = %channel_name, "No channel found for message");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message with exponential retry
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
channel: &dyn Channel,
|
||||
msg: OutboundMessage,
|
||||
) -> Result<(), ChannelError> {
|
||||
const DELAYS: [u64; 3] = [1, 2, 4];
|
||||
|
||||
for (i, delay) in DELAYS.iter().enumerate() {
|
||||
match channel.send(msg.clone()).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) if i < DELAYS.len() - 1 => {
|
||||
tracing::warn!(
|
||||
attempt = i + 1,
|
||||
delay = delay,
|
||||
error = %e,
|
||||
"Send failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(*delay)).await;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ============================================================================
|
||||
// ChatMessage - Legacy type used by AgentLoop for LLM conversation history
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub id: String,
|
||||
@ -58,6 +63,51 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// InboundMessage - Message from Channel to Bus (user input)
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InboundMessage {
|
||||
pub channel: String,
|
||||
pub sender_id: String,
|
||||
pub chat_id: String,
|
||||
pub content: String,
|
||||
pub timestamp: i64,
|
||||
pub media: Vec<String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl InboundMessage {
|
||||
pub fn session_key(&self) -> String {
|
||||
format!("{}:{}", self.channel, self.chat_id)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OutboundMessage - Message from Agent to Channel (bot response)
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OutboundMessage {
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub content: String,
|
||||
pub reply_to: Option<String>,
|
||||
pub media: Vec<String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl OutboundMessage {
|
||||
pub fn is_stream_delta(&self) -> bool {
|
||||
self.metadata.get("_stream_delta").is_some()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
|
||||
@ -1,42 +1,86 @@
|
||||
pub mod dispatcher;
|
||||
pub mod message;
|
||||
|
||||
pub use message::ChatMessage;
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{ChatMessage, InboundMessage, OutboundMessage};
|
||||
|
||||
use tokio::sync::{mpsc, broadcast};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
// ============================================================================
|
||||
// MessageBus - Async message queue for Channel <-> Agent communication
|
||||
// ============================================================================
|
||||
|
||||
pub struct MessageBus {
|
||||
user_tx: mpsc::Sender<ChatMessage>,
|
||||
llm_tx: broadcast::Sender<ChatMessage>,
|
||||
inbound_tx: mpsc::Sender<InboundMessage>,
|
||||
outbound_tx: mpsc::Sender<OutboundMessage>,
|
||||
inbound_rx: Mutex<mpsc::Receiver<InboundMessage>>,
|
||||
outbound_rx: Mutex<mpsc::Receiver<OutboundMessage>>,
|
||||
}
|
||||
|
||||
impl MessageBus {
|
||||
pub fn new(buffer_size: usize) -> Self {
|
||||
let (user_tx, _) = mpsc::channel(buffer_size);
|
||||
let (llm_tx, _) = broadcast::channel(buffer_size);
|
||||
|
||||
Self { user_tx, llm_tx }
|
||||
/// Create a new MessageBus with the given channel capacity
|
||||
pub fn new(capacity: usize) -> Arc<Self> {
|
||||
let (inbound_tx, inbound_rx) = mpsc::channel(capacity);
|
||||
let (outbound_tx, outbound_rx) = mpsc::channel(capacity);
|
||||
Arc::new(Self {
|
||||
inbound_tx,
|
||||
outbound_tx,
|
||||
inbound_rx: Mutex::new(inbound_rx),
|
||||
outbound_rx: Mutex::new(outbound_rx),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn send_user_input(&self, msg: ChatMessage) -> Result<(), BusError> {
|
||||
self.user_tx.send(msg).await.map_err(|_| BusError::ChannelClosed)
|
||||
/// Publish an inbound message (Channel -> Bus)
|
||||
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<(), BusError> {
|
||||
self.inbound_tx
|
||||
.send(msg)
|
||||
.await
|
||||
.map_err(|_| BusError::Closed)
|
||||
}
|
||||
|
||||
pub fn send_llm_output(&self, msg: ChatMessage) -> Result<usize, BusError> {
|
||||
self.llm_tx.send(msg).map_err(|_| BusError::ChannelClosed)
|
||||
/// Consume an inbound message (Agent -> Bus)
|
||||
pub async fn consume_inbound(&self) -> InboundMessage {
|
||||
self.inbound_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.expect("bus inbound closed")
|
||||
}
|
||||
|
||||
/// Publish an outbound message (Agent -> Bus)
|
||||
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<(), BusError> {
|
||||
self.outbound_tx
|
||||
.send(msg)
|
||||
.await
|
||||
.map_err(|_| BusError::Closed)
|
||||
}
|
||||
|
||||
/// Consume an outbound message (Dispatcher -> Bus)
|
||||
pub async fn consume_outbound(&self) -> OutboundMessage {
|
||||
self.outbound_rx
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.expect("bus outbound closed")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// BusError
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BusError {
|
||||
ChannelClosed,
|
||||
SendError(usize),
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BusError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
BusError::ChannelClosed => write!(f, "Channel closed"),
|
||||
BusError::SendError(n) => write!(f, "Send error, {} receivers", n),
|
||||
BusError::Closed => write!(f, "Bus channel closed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,30 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InboundMessage {
|
||||
pub channel: String,
|
||||
pub sender_id: String,
|
||||
pub chat_id: String,
|
||||
pub content: String,
|
||||
pub media: Vec<String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OutboundMessage {
|
||||
pub channel: String,
|
||||
pub chat_id: String,
|
||||
pub content: String,
|
||||
pub media: Vec<String>,
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
use crate::bus::{BusError, InboundMessage, MessageBus, OutboundMessage};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ChannelError {
|
||||
ConfigError(String),
|
||||
ConnectionError(String),
|
||||
SendError(String),
|
||||
BusError(String),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
@ -34,6 +18,7 @@ impl std::fmt::Display for ChannelError {
|
||||
ChannelError::ConfigError(s) => write!(f, "Config error: {}", s),
|
||||
ChannelError::ConnectionError(s) => write!(f, "Connection error: {}", s),
|
||||
ChannelError::SendError(s) => write!(f, "Send error: {}", s),
|
||||
ChannelError::BusError(s) => write!(f, "Bus error: {}", s),
|
||||
ChannelError::Other(s) => write!(f, "Error: {}", s),
|
||||
}
|
||||
}
|
||||
@ -41,10 +26,73 @@ impl std::fmt::Display for ChannelError {
|
||||
|
||||
impl std::error::Error for ChannelError {}
|
||||
|
||||
impl From<BusError> for ChannelError {
|
||||
fn from(e: BusError) -> Self {
|
||||
ChannelError::BusError(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Channel: Send + Sync + 'static {
|
||||
fn name(&self) -> &str;
|
||||
async fn start(&self) -> Result<(), ChannelError>;
|
||||
async fn stop(&self) -> Result<(), ChannelError>;
|
||||
fn is_running(&self) -> bool;
|
||||
|
||||
/// Start the channel with a reference to the MessageBus
|
||||
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError>;
|
||||
|
||||
/// Stop the channel
|
||||
async fn stop(&self) -> Result<(), ChannelError>;
|
||||
|
||||
/// Send a message to the channel (called by OutboundDispatcher)
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError>;
|
||||
|
||||
/// Send a streaming delta (optional, for channels that support it)
|
||||
async fn send_delta(&self, chat_id: &str, delta: &str) -> Result<(), ChannelError> {
|
||||
let _ = chat_id;
|
||||
let _ = delta;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a sender is allowed to use this channel
|
||||
fn is_allowed(&self, _sender_id: &str) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Handle an inbound message: check permissions and publish to bus
|
||||
async fn handle_and_publish(
|
||||
&self,
|
||||
bus: &Arc<MessageBus>,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<(), ChannelError> {
|
||||
if !self.is_allowed(sender_id) {
|
||||
tracing::warn!(
|
||||
channel = %self.name(),
|
||||
sender = %sender_id,
|
||||
"Access denied"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let msg = InboundMessage {
|
||||
channel: self.name().to_string(),
|
||||
sender_id: sender_id.to_string(),
|
||||
chat_id: chat_id.to_string(),
|
||||
content: content.to_string(),
|
||||
timestamp: current_timestamp(),
|
||||
media: vec![],
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
bus.publish_inbound(msg).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as i64
|
||||
}
|
||||
|
||||
@ -5,8 +5,8 @@ use serde::Deserialize;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::{Message as ProstMessage, bytes::Bytes};
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::manager::MessageHandler;
|
||||
use crate::config::{FeishuChannelConfig, LLMProviderConfig};
|
||||
|
||||
const FEISHU_API_BASE: &str = "https://open.feishu.cn/open-apis";
|
||||
@ -131,8 +131,6 @@ pub struct FeishuChannel {
|
||||
running: Arc<RwLock<bool>>,
|
||||
shutdown_tx: Arc<RwLock<Option<broadcast::Sender<()>>>>,
|
||||
connected: Arc<RwLock<bool>>,
|
||||
/// Message handler for routing messages to Gateway
|
||||
message_handler: Arc<dyn MessageHandler>,
|
||||
}
|
||||
|
||||
/// Parsed message data from a Feishu frame
|
||||
@ -146,7 +144,6 @@ struct ParsedMessage {
|
||||
impl FeishuChannel {
|
||||
pub fn new(
|
||||
config: FeishuChannelConfig,
|
||||
message_handler: Arc<dyn MessageHandler>,
|
||||
_provider_config: LLMProviderConfig,
|
||||
) -> Result<Self, ChannelError> {
|
||||
Ok(Self {
|
||||
@ -155,7 +152,6 @@ impl FeishuChannel {
|
||||
running: Arc::new(RwLock::new(false)),
|
||||
shutdown_tx: Arc::new(RwLock::new(None)),
|
||||
connected: Arc::new(RwLock::new(false)),
|
||||
message_handler,
|
||||
})
|
||||
}
|
||||
|
||||
@ -224,11 +220,10 @@ impl FeishuChannel {
|
||||
.ok_or_else(|| ChannelError::Other("No token in response".to_string()))
|
||||
}
|
||||
|
||||
/// Send a text message to Feishu chat
|
||||
async fn send_message(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
||||
/// Send a text message to Feishu chat (implements Channel trait)
|
||||
async fn send_message_to_feishu(&self, receive_id: &str, receive_id_type: &str, content: &str) -> Result<(), ChannelError> {
|
||||
let token = self.get_tenant_token().await?;
|
||||
|
||||
// For text message, content should be a JSON string: "{\"text\":\"hello\"}"
|
||||
let text_content = serde_json::json!({ "text": content }).to_string();
|
||||
|
||||
let resp = self.http_client
|
||||
@ -262,26 +257,6 @@ impl FeishuChannel {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle incoming message - delegate to message handler and send response
|
||||
async fn handle_message(&self, open_id: &str, chat_id: &str, content: &str) -> Result<(), ChannelError> {
|
||||
tracing::info!(open_id, chat_id, "Processing message from Feishu");
|
||||
|
||||
// Delegate to message handler (Gateway)
|
||||
let response = self.message_handler
|
||||
.handle_message("feishu", open_id, chat_id, content)
|
||||
.await?;
|
||||
|
||||
// Send response to the chat
|
||||
// Use open_id for p2p chats, chat_id for group chats
|
||||
let receive_id = if chat_id.starts_with("oc_") { chat_id } else { open_id };
|
||||
let receive_id_type = if chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||
|
||||
self.send_message(receive_id, receive_id_type, &response).await?;
|
||||
tracing::info!(receive_id, "Sent response to Feishu");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract service_id from WebSocket URL query params
|
||||
fn extract_service_id(url: &str) -> i32 {
|
||||
url.split('?')
|
||||
@ -310,7 +285,6 @@ impl FeishuChannel {
|
||||
let payload = frame.payload.as_deref()
|
||||
.ok_or_else(|| ChannelError::Other("No payload in frame".to_string()))?;
|
||||
|
||||
// Parse the event JSON to get event_type from payload header
|
||||
let event: LarkEvent = serde_json::from_slice(payload)
|
||||
.map_err(|e| ChannelError::Other(format!("Parse event error: {}", e)))?;
|
||||
|
||||
@ -359,7 +333,7 @@ impl FeishuChannel {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_ws_loop(&self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> {
|
||||
async fn run_ws_loop(&self, bus: Arc<MessageBus>, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), ChannelError> {
|
||||
let (wss_url, client_config) = self.get_ws_endpoint(&self.http_client).await?;
|
||||
|
||||
let service_id = Self::extract_service_id(&wss_url);
|
||||
@ -404,7 +378,6 @@ impl FeishuChannel {
|
||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
|
||||
let bytes: Bytes = data;
|
||||
if let Ok(frame) = PbFrame::decode(bytes.as_ref()) {
|
||||
// Parse the frame first
|
||||
match self.handle_frame(&frame).await {
|
||||
Ok(Some(parsed)) => {
|
||||
// Send ACK immediately (Feishu requires within 3 s)
|
||||
@ -412,11 +385,12 @@ impl FeishuChannel {
|
||||
tracing::error!(error = %e, "Failed to send ACK to Feishu");
|
||||
}
|
||||
|
||||
// Then process message asynchronously (don't await)
|
||||
// Publish to bus asynchronously
|
||||
let channel = self.clone();
|
||||
let bus = bus.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = channel.handle_message(&parsed.open_id, &parsed.chat_id, &parsed.content).await {
|
||||
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to handle Feishu message");
|
||||
if let Err(e) = channel.handle_and_publish(&bus, &parsed.open_id, &parsed.chat_id, &parsed.content).await {
|
||||
tracing::error!(error = %e, open_id = %parsed.open_id, chat_id = %parsed.chat_id, "Failed to publish Feishu message to bus");
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -528,7 +502,7 @@ impl Channel for FeishuChannel {
|
||||
"feishu"
|
||||
}
|
||||
|
||||
async fn start(&self) -> Result<(), ChannelError> {
|
||||
async fn start(&self, bus: Arc<MessageBus>) -> Result<(), ChannelError> {
|
||||
if self.config.app_id.is_empty() || self.config.app_secret.is_empty() {
|
||||
return Err(ChannelError::ConfigError(
|
||||
"Feishu app_id or app_secret is not configured".to_string()
|
||||
@ -541,6 +515,7 @@ impl Channel for FeishuChannel {
|
||||
*self.shutdown_tx.write().await = Some(shutdown_tx.clone());
|
||||
|
||||
let channel = self.clone();
|
||||
let bus = bus.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut consecutive_failures = 0;
|
||||
let max_failures = 3;
|
||||
@ -551,7 +526,7 @@ impl Channel for FeishuChannel {
|
||||
}
|
||||
|
||||
let shutdown_rx = shutdown_tx.subscribe();
|
||||
match channel.run_ws_loop(shutdown_rx).await {
|
||||
match channel.run_ws_loop(bus.clone(), shutdown_rx).await {
|
||||
Ok(_) => {
|
||||
tracing::info!("Feishu WebSocket disconnected");
|
||||
}
|
||||
@ -592,7 +567,13 @@ impl Channel for FeishuChannel {
|
||||
}
|
||||
|
||||
fn is_running(&self) -> bool {
|
||||
// Note: blocking read, acceptable for this use case
|
||||
self.running.try_read().map(|r| *r).unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let receive_id = if msg.chat_id.starts_with("oc_") { &msg.chat_id } else { &msg.reply_to.as_ref().unwrap_or(&msg.chat_id) };
|
||||
let receive_id_type = if msg.chat_id.starts_with("oc_") { "chat_id" } else { "open_id" };
|
||||
|
||||
self.send_message_to_feishu(receive_id, receive_id_type, &msg.content).await
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,51 +1,38 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::channels::feishu::FeishuChannel;
|
||||
use crate::config::{Config, FeishuChannelConfig};
|
||||
use crate::config::Config;
|
||||
|
||||
/// MessageHandler trait - Channel 通过这个 trait 与业务逻辑解耦
|
||||
#[async_trait]
|
||||
pub trait MessageHandler: Send + Sync {
|
||||
async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<String, ChannelError>;
|
||||
}
|
||||
|
||||
/// ChannelManager 管理所有 Channel
|
||||
/// ChannelManager manages all Channel instances and the MessageBus
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelManager {
|
||||
channels: Arc<RwLock<HashMap<String, Arc<dyn Channel + Send + Sync>>>>,
|
||||
message_handler: Arc<dyn MessageHandler>,
|
||||
bus: Arc<MessageBus>,
|
||||
}
|
||||
|
||||
impl ChannelManager {
|
||||
pub fn new(message_handler: Arc<dyn MessageHandler>) -> Self {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
message_handler,
|
||||
bus: MessageBus::new(100),
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取 MessageHandler 用于让 Channel 调用
|
||||
pub fn get_handler(&self) -> Arc<dyn MessageHandler> {
|
||||
self.message_handler.clone()
|
||||
/// Get a reference to the MessageBus
|
||||
pub fn bus(&self) -> Arc<MessageBus> {
|
||||
self.bus.clone()
|
||||
}
|
||||
|
||||
/// 初始化所有 Channel
|
||||
pub async fn init(&self, config: &Config, provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
|
||||
/// Initialize all Channel instances from config
|
||||
pub async fn init(&self, config: &Config, _provider_config: crate::config::LLMProviderConfig) -> Result<(), ChannelError> {
|
||||
// Initialize Feishu channel if enabled
|
||||
if let Some(feishu_config) = config.channels.get("feishu") {
|
||||
if feishu_config.enabled {
|
||||
let handler = self.get_handler();
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), handler, provider_config)
|
||||
let channel = FeishuChannel::new(feishu_config.clone(), _provider_config)
|
||||
.map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?;
|
||||
|
||||
self.channels
|
||||
@ -62,9 +49,10 @@ impl ChannelManager {
|
||||
|
||||
pub async fn start_all(&self) -> Result<(), ChannelError> {
|
||||
let channels = self.channels.read().await;
|
||||
let bus = self.bus.clone();
|
||||
for (name, channel) in channels.iter() {
|
||||
tracing::info!(channel = %name, "Starting channel");
|
||||
if let Err(e) = channel.start().await {
|
||||
if let Err(e) = channel.start(bus.clone()).await {
|
||||
tracing::error!(channel = %name, error = %e, "Failed to start channel");
|
||||
}
|
||||
}
|
||||
@ -86,32 +74,14 @@ impl ChannelManager {
|
||||
pub async fn get_channel(&self, name: &str) -> Option<Arc<dyn Channel + Send + Sync>> {
|
||||
self.channels.read().await.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
/// Gateway 实现 MessageHandler trait
|
||||
#[derive(Clone)]
|
||||
pub struct GatewayMessageHandler {
|
||||
session_manager: crate::gateway::session::SessionManager,
|
||||
}
|
||||
|
||||
impl GatewayMessageHandler {
|
||||
pub fn new(session_manager: crate::gateway::session::SessionManager) -> Self {
|
||||
Self { session_manager }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageHandler for GatewayMessageHandler {
|
||||
async fn handle_message(
|
||||
&self,
|
||||
channel_name: &str,
|
||||
sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<String, ChannelError> {
|
||||
self.session_manager
|
||||
.handle_message(channel_name, sender_id, chat_id, content)
|
||||
.await
|
||||
.map_err(|e| ChannelError::Other(e.to_string()))
|
||||
/// Dispatch an outbound message to the appropriate channel
|
||||
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let channel_name = &msg.channel;
|
||||
if let Some(channel) = self.get_channel(channel_name).await {
|
||||
channel.send(msg).await
|
||||
} else {
|
||||
Err(ChannelError::Other(format!("Channel not found: {}", channel_name)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,6 +2,6 @@ pub mod base;
|
||||
pub mod feishu;
|
||||
pub mod manager;
|
||||
|
||||
pub use base::{Channel, ChannelError, InboundMessage, OutboundMessage};
|
||||
pub use base::{Channel, ChannelError};
|
||||
pub use manager::ChannelManager;
|
||||
pub use feishu::FeishuChannel;
|
||||
|
||||
@ -6,7 +6,8 @@ use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::channels::{ChannelManager, manager::GatewayMessageHandler};
|
||||
use crate::bus::{MessageBus, OutboundDispatcher};
|
||||
use crate::channels::ChannelManager;
|
||||
use crate::config::Config;
|
||||
use crate::logging;
|
||||
use session::SessionManager;
|
||||
@ -15,6 +16,7 @@ pub struct GatewayState {
|
||||
pub config: Config,
|
||||
pub session_manager: SessionManager,
|
||||
pub channel_manager: ChannelManager,
|
||||
pub bus: Arc<MessageBus>,
|
||||
}
|
||||
|
||||
impl GatewayState {
|
||||
@ -28,15 +30,76 @@ impl GatewayState {
|
||||
let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4);
|
||||
|
||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config);
|
||||
let message_handler = Arc::new(GatewayMessageHandler::new(session_manager.clone()));
|
||||
let channel_manager = ChannelManager::new(message_handler);
|
||||
let channel_manager = ChannelManager::new();
|
||||
let bus = channel_manager.bus();
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
session_manager,
|
||||
channel_manager,
|
||||
bus,
|
||||
})
|
||||
}
|
||||
|
||||
/// Start the message processing loops
|
||||
pub async fn start_message_processing(&self) {
|
||||
let bus_for_inbound = self.bus.clone();
|
||||
let bus_for_outbound = self.bus.clone();
|
||||
let session_manager = self.session_manager.clone();
|
||||
|
||||
// Spawn inbound message processor
|
||||
// This consumes from bus.inbound, processes via SessionManager, publishes to bus.outbound
|
||||
tokio::spawn(async move {
|
||||
tracing::info!("Inbound processor started");
|
||||
loop {
|
||||
let inbound = bus_for_inbound.consume_inbound().await;
|
||||
tracing::debug!(
|
||||
channel = %inbound.channel,
|
||||
chat_id = %inbound.chat_id,
|
||||
"Processing inbound message"
|
||||
);
|
||||
|
||||
// Process via session manager
|
||||
match session_manager.handle_message(
|
||||
&inbound.channel,
|
||||
&inbound.sender_id,
|
||||
&inbound.chat_id,
|
||||
&inbound.content,
|
||||
).await {
|
||||
Ok(response_content) => {
|
||||
let outbound = crate::bus::OutboundMessage {
|
||||
channel: inbound.channel,
|
||||
chat_id: inbound.chat_id,
|
||||
content: response_content,
|
||||
reply_to: None,
|
||||
media: vec![],
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
if let Err(e) = bus_for_inbound.publish_outbound(outbound).await {
|
||||
tracing::error!(error = %e, "Failed to publish outbound");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to handle message");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn outbound dispatcher
|
||||
let dispatcher = OutboundDispatcher::new(bus_for_outbound);
|
||||
let channel_manager = self.channel_manager.clone();
|
||||
|
||||
// Register channels with dispatcher
|
||||
if let Some(channel) = channel_manager.get_channel("feishu").await {
|
||||
dispatcher.register_channel("feishu", channel).await;
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
tracing::info!("Outbound dispatcher started");
|
||||
dispatcher.run().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
@ -50,9 +113,12 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
|
||||
let provider_config = state.config.get_provider_config("default")?;
|
||||
|
||||
// Initialize and start channels
|
||||
state.channel_manager.init(&state.config, provider_config).await?;
|
||||
state.channel_manager.init(&state.config, provider_config.clone()).await?;
|
||||
state.channel_manager.start_all().await?;
|
||||
|
||||
// Start message processing (inbound processor + outbound dispatcher)
|
||||
state.start_message_processing().await;
|
||||
|
||||
// CLI args override config file values
|
||||
let bind_host = host.unwrap_or_else(|| state.config.gateway.host.clone());
|
||||
let bind_port = port.unwrap_or(state.config.gateway.port);
|
||||
|
||||
@ -10,11 +10,12 @@ use crate::protocol::WsOutbound;
|
||||
use crate::tools::{CalculatorTool, ToolRegistry};
|
||||
|
||||
/// Session 按 channel 隔离,每个 channel 一个 Session
|
||||
/// History 按 chat_id 隔离,由 Session 统一管理
|
||||
pub struct Session {
|
||||
pub id: Uuid,
|
||||
pub channel_name: String,
|
||||
/// 按 chat_id 路由到不同 AgentLoop,支持多用户多会话
|
||||
chat_agents: HashMap<String, Arc<Mutex<AgentLoop>>>,
|
||||
/// 按 chat_id 路由到不同会话历史,支持多用户多会话
|
||||
chat_histories: HashMap<String, Vec<ChatMessage>>,
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
tools: Arc<ToolRegistry>,
|
||||
@ -30,52 +31,65 @@ impl Session {
|
||||
Ok(Self {
|
||||
id: Uuid::new_v4(),
|
||||
channel_name,
|
||||
chat_agents: HashMap::new(),
|
||||
chat_histories: HashMap::new(),
|
||||
user_tx,
|
||||
provider_config,
|
||||
tools,
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取或创建指定 chat_id 的 AgentLoop
|
||||
pub async fn get_or_create_agent(&mut self, chat_id: &str) -> Result<Arc<Mutex<AgentLoop>>, AgentError> {
|
||||
if let Some(agent) = self.chat_agents.get(chat_id) {
|
||||
tracing::trace!(chat_id = %chat_id, "Reusing existing agent");
|
||||
return Ok(agent.clone());
|
||||
}
|
||||
tracing::debug!(chat_id = %chat_id, "Creating new agent for chat");
|
||||
let agent = AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())?;
|
||||
let arc = Arc::new(Mutex::new(agent));
|
||||
self.chat_agents.insert(chat_id.to_string(), arc.clone());
|
||||
Ok(arc)
|
||||
/// 获取或创建指定 chat_id 的会话历史
|
||||
pub fn get_or_create_history(&mut self, chat_id: &str) -> &mut Vec<ChatMessage> {
|
||||
self.chat_histories
|
||||
.entry(chat_id.to_string())
|
||||
.or_insert_with(Vec::new)
|
||||
}
|
||||
|
||||
/// 获取指定 chat_id 的 AgentLoop(不创建)
|
||||
pub fn get_agent(&self, chat_id: &str) -> Option<Arc<Mutex<AgentLoop>>> {
|
||||
self.chat_agents.get(chat_id).cloned()
|
||||
/// 获取指定 chat_id 的会话历史(不创建)
|
||||
pub fn get_history(&self, chat_id: &str) -> Option<&Vec<ChatMessage>> {
|
||||
self.chat_histories.get(chat_id)
|
||||
}
|
||||
|
||||
/// 添加用户消息到指定 chat_id 的历史
|
||||
pub fn add_user_message(&mut self, chat_id: &str, content: &str) {
|
||||
let history = self.get_or_create_history(chat_id);
|
||||
history.push(ChatMessage::user(content));
|
||||
}
|
||||
|
||||
/// 添加助手响应到指定 chat_id 的历史
|
||||
pub fn add_assistant_message(&mut self, chat_id: &str, message: ChatMessage) {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
history.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
/// 清除指定 chat_id 的历史
|
||||
pub async fn clear_chat_history(&mut self, chat_id: &str) {
|
||||
if let Some(agent) = self.chat_agents.get(chat_id) {
|
||||
agent.lock().await.clear_history();
|
||||
pub fn clear_chat_history(&mut self, chat_id: &str) {
|
||||
if let Some(history) = self.chat_histories.get_mut(chat_id) {
|
||||
let len = history.len();
|
||||
history.clear();
|
||||
tracing::debug!(chat_id = %chat_id, previous_len = len, "Chat history cleared");
|
||||
}
|
||||
}
|
||||
|
||||
/// 清除所有历史
|
||||
pub async fn clear_all_history(&mut self) {
|
||||
for agent in self.chat_agents.values() {
|
||||
agent.lock().await.clear_history();
|
||||
}
|
||||
pub fn clear_all_history(&mut self) {
|
||||
let total: usize = self.chat_histories.values().map(|h| h.len()).sum();
|
||||
self.chat_histories.clear();
|
||||
tracing::debug!(previous_total = total, "All chat histories cleared");
|
||||
}
|
||||
|
||||
pub async fn send(&self, msg: WsOutbound) {
|
||||
let _ = self.user_tx.send(msg).await;
|
||||
}
|
||||
|
||||
/// 创建一个临时的 AgentLoop 实例来处理消息
|
||||
pub fn create_agent(&self) -> Result<AgentLoop, AgentError> {
|
||||
AgentLoop::with_tools(self.provider_config.clone(), self.tools.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// SessionManager 管理所有 Session,按 channel_name 路由
|
||||
/// 使用 Arc<Mutex<SessionManager>> 以从 Arc 获取可变访问
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
inner: Arc<Mutex<SessionManagerInner>>,
|
||||
@ -135,7 +149,13 @@ impl SessionManager {
|
||||
|
||||
// 创建新 session(使用临时 user_tx,因为 Feishu 不通过 WS)
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(channel_name.to_string(), self.provider_config.clone(), user_tx, self.tools.clone()).await?;
|
||||
let session = Session::new(
|
||||
channel_name.to_string(),
|
||||
self.provider_config.clone(),
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
)
|
||||
.await?;
|
||||
let arc = Arc::new(Mutex::new(session));
|
||||
|
||||
inner.sessions.insert(channel_name.to_string(), arc.clone());
|
||||
@ -165,7 +185,12 @@ impl SessionManager {
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<String, AgentError> {
|
||||
tracing::debug!(channel = %channel_name, chat_id = %chat_id, content_len = content.len(), "Routing message to agent");
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
content_len = content.len(),
|
||||
"Routing message to agent"
|
||||
);
|
||||
|
||||
// 确保 session 存在(可能需要重建)
|
||||
self.ensure_session(channel_name).await?;
|
||||
@ -174,21 +199,37 @@ impl SessionManager {
|
||||
self.touch(channel_name).await;
|
||||
|
||||
// 获取 session
|
||||
let session = self.get(channel_name).await
|
||||
let session = self
|
||||
.get(channel_name)
|
||||
.await
|
||||
.ok_or_else(|| AgentError::Other("Session not found".to_string()))?;
|
||||
|
||||
// 获取或创建 chat_id 对应的 agent
|
||||
let mut session_guard = session.lock().await;
|
||||
let agent = session_guard.get_or_create_agent(chat_id).await?;
|
||||
drop(session_guard);
|
||||
|
||||
let mut agent = agent.lock().await;
|
||||
|
||||
// 处理消息
|
||||
let user_msg = ChatMessage::user(content);
|
||||
let response = agent.process(user_msg).await?;
|
||||
let response = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
tracing::debug!(channel = %channel_name, chat_id = %chat_id, response_len = response.content.len(), "Agent response received");
|
||||
// 添加用户消息到历史
|
||||
session_guard.add_user_message(chat_id, content);
|
||||
|
||||
// 获取完整历史
|
||||
let history = session_guard.get_or_create_history(chat_id).clone();
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = session_guard.create_agent()?;
|
||||
let response = agent.process(history).await?;
|
||||
|
||||
// 添加助手响应到历史
|
||||
session_guard.add_assistant_message(chat_id, response.clone());
|
||||
|
||||
response
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
channel = %channel_name,
|
||||
chat_id = %chat_id,
|
||||
response_len = response.content.len(),
|
||||
"Agent response received"
|
||||
);
|
||||
|
||||
Ok(response.content)
|
||||
}
|
||||
@ -197,7 +238,7 @@ impl SessionManager {
|
||||
pub async fn clear_session_history(&self, channel_name: &str) -> Result<(), AgentError> {
|
||||
if let Some(session) = self.get(channel_name).await {
|
||||
let mut session_guard = session.lock().await;
|
||||
session_guard.clear_all_history().await;
|
||||
session_guard.clear_all_history();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -4,7 +4,6 @@ use axum::extract::State;
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::Session};
|
||||
|
||||
@ -29,7 +28,14 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
let channel_name = format!("cli-{}", uuid::Uuid::new_v4());
|
||||
|
||||
// 创建 CLI session
|
||||
let session = match Session::new(channel_name.clone(), provider_config, sender, state.session_manager.tools()).await {
|
||||
let session = match Session::new(
|
||||
channel_name.clone(),
|
||||
provider_config,
|
||||
sender,
|
||||
state.session_manager.tools(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(s) => Arc::new(Mutex::new(s)),
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to create session");
|
||||
@ -40,9 +46,13 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
let session_id = session.lock().await.id;
|
||||
tracing::info!(session_id = %session_id, "CLI session established");
|
||||
|
||||
let _ = session.lock().await.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.to_string(),
|
||||
}).await;
|
||||
let _ = session
|
||||
.lock()
|
||||
.await
|
||||
.send(WsOutbound::SessionEstablished {
|
||||
session_id: session_id.to_string(),
|
||||
})
|
||||
.await;
|
||||
|
||||
let (mut ws_sender, mut ws_receiver) = ws.split();
|
||||
|
||||
@ -69,10 +79,14 @@ async fn handle_socket(ws: WebSocket, state: Arc<GatewayState>) {
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to parse inbound message");
|
||||
let _ = session.lock().await.send(WsOutbound::Error {
|
||||
code: "PARSE_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
let _ = session
|
||||
.lock()
|
||||
.await
|
||||
.send(WsOutbound::Error {
|
||||
code: "PARSE_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -92,7 +106,12 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
||||
|
||||
// 提取 content 和 chat_id(CLI 使用 session id 作为 chat_id)
|
||||
let (content, chat_id) = match inbound_clone {
|
||||
WsInbound::UserInput { content, channel: _, chat_id, sender_id: _ } => {
|
||||
WsInbound::UserInput {
|
||||
content,
|
||||
channel: _,
|
||||
chat_id,
|
||||
sender_id: _,
|
||||
} => {
|
||||
// CLI 使用 session 中的 channel_name 作为标识
|
||||
// chat_id 使用传入的或使用默认
|
||||
let chat_id = chat_id.unwrap_or_else(|| "default".to_string());
|
||||
@ -101,38 +120,50 @@ async fn handle_inbound(session: &Arc<Mutex<Session>>, inbound: WsInbound) {
|
||||
_ => return,
|
||||
};
|
||||
|
||||
let user_msg = ChatMessage::user(content);
|
||||
|
||||
let mut session_guard = session.lock().await;
|
||||
let agent = match session_guard.get_or_create_agent(&chat_id).await {
|
||||
|
||||
// 添加用户消息到历史
|
||||
session_guard.add_user_message(&chat_id, &content);
|
||||
|
||||
// 获取完整历史
|
||||
let history = session_guard.get_or_create_history(&chat_id).clone();
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = match session_guard.create_agent() {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
tracing::error!(chat_id = %chat_id, error = %e, "Failed to get or create agent");
|
||||
let _ = session_guard.send(WsOutbound::Error {
|
||||
code: "AGENT_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
tracing::error!(chat_id = %chat_id, error = %e, "Failed to create agent");
|
||||
let _ = session_guard
|
||||
.send(WsOutbound::Error {
|
||||
code: "AGENT_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
drop(session_guard);
|
||||
|
||||
let mut agent = agent.lock().await;
|
||||
match agent.process(user_msg).await {
|
||||
match agent.process(history).await {
|
||||
Ok(response) => {
|
||||
tracing::debug!(chat_id = %chat_id, "Agent response sent");
|
||||
let _ = session.lock().await.send(WsOutbound::AssistantResponse {
|
||||
id: response.id,
|
||||
content: response.content,
|
||||
role: response.role,
|
||||
}).await;
|
||||
// 添加助手响应到历史
|
||||
session_guard.add_assistant_message(&chat_id, response.clone());
|
||||
let _ = session_guard
|
||||
.send(WsOutbound::AssistantResponse {
|
||||
id: response.id,
|
||||
content: response.content,
|
||||
role: response.role,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(chat_id = %chat_id, error = %e, "Agent process error");
|
||||
let _ = session.lock().await.send(WsOutbound::Error {
|
||||
code: "LLM_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
}).await;
|
||||
let _ = session_guard
|
||||
.send(WsOutbound::Error {
|
||||
code: "LLM_ERROR".to_string(),
|
||||
message: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user