feat: 添加实时工具调用消息处理,优化消息格式化和传递逻辑
This commit is contained in:
parent
bc24a28275
commit
4725b5406e
@ -1,3 +1,4 @@
|
||||
use async_trait::async_trait;
|
||||
use crate::bus::message::ContentBlock;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
@ -227,6 +228,7 @@ pub struct AgentLoop {
|
||||
skill_event_session_id: Option<String>,
|
||||
tool_context: ToolContext,
|
||||
observer: Option<Arc<dyn Observer>>,
|
||||
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
|
||||
max_iterations: usize,
|
||||
}
|
||||
|
||||
@ -236,6 +238,11 @@ pub struct AgentProcessResult {
|
||||
pub emitted_messages: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait EmittedMessageHandler: Send + Sync + 'static {
|
||||
async fn handle(&self, message: ChatMessage);
|
||||
}
|
||||
|
||||
impl AgentLoop {
|
||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||
let max_iterations = provider_config.max_tool_iterations;
|
||||
@ -250,6 +257,7 @@ impl AgentLoop {
|
||||
skill_event_session_id: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
emitted_message_handler: None,
|
||||
max_iterations,
|
||||
})
|
||||
}
|
||||
@ -267,6 +275,7 @@ impl AgentLoop {
|
||||
skill_event_session_id: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
emitted_message_handler: None,
|
||||
max_iterations,
|
||||
})
|
||||
}
|
||||
@ -288,6 +297,7 @@ impl AgentLoop {
|
||||
skill_event_session_id: None,
|
||||
tool_context: ToolContext::default(),
|
||||
observer: None,
|
||||
emitted_message_handler: None,
|
||||
max_iterations,
|
||||
})
|
||||
}
|
||||
@ -309,6 +319,11 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
|
||||
self.emitted_message_handler = Some(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||||
&self.tools
|
||||
}
|
||||
@ -388,6 +403,7 @@ impl AgentLoop {
|
||||
);
|
||||
messages.push(assistant_message.clone());
|
||||
emitted_messages.push(assistant_message);
|
||||
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
|
||||
|
||||
// Execute tools and add results to messages
|
||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||
@ -487,6 +503,16 @@ impl AgentLoop {
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
|
||||
if !message.is_assistant_tool_call_message() {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(handler) = &self.emitted_message_handler {
|
||||
handler.handle(message).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine whether to execute tools in parallel or sequentially.
|
||||
///
|
||||
/// Returns true if:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
pub mod agent_loop;
|
||||
pub mod context_compressor;
|
||||
|
||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
|
||||
pub use context_compressor::ContextCompressor;
|
||||
|
||||
@ -153,6 +153,15 @@ impl ChatMessage {
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_assistant_tool_call_message(&self) -> bool {
|
||||
self.role == "assistant"
|
||||
&& self
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map(|calls| !calls.is_empty())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@ -325,22 +334,33 @@ impl OutboundMessage {
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
||||
format!(
|
||||
"调用工具: {}\n\n输入参数:\n{}",
|
||||
tool_name,
|
||||
format_json_value(tool_arguments),
|
||||
)
|
||||
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
||||
let mut lines = vec![format!("### {}", tool_name)];
|
||||
|
||||
match tool_arguments {
|
||||
serde_json::Value::Object(map) if !map.is_empty() => {
|
||||
let mut entries: Vec<_> = map.iter().collect();
|
||||
entries.sort_by(|(left, _), (right, _)| left.cmp(right));
|
||||
for (key, value) in entries {
|
||||
lines.push(format!("- {}: {}", key, format_tool_argument_value(value)));
|
||||
}
|
||||
}
|
||||
serde_json::Value::Object(_) => {}
|
||||
other => lines.push(format!("- args: {}", format_tool_argument_value(other))),
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
|
||||
format!("工具结果: {}\n\n{}", tool_name, content)
|
||||
}
|
||||
|
||||
fn format_json_value(value: &serde_json::Value) -> String {
|
||||
fn format_tool_argument_value(value: &serde_json::Value) -> String {
|
||||
match value {
|
||||
serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
|
||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
||||
serde_json::Value::String(text) => text.clone(),
|
||||
serde_json::Value::Null => "null".to_string(),
|
||||
other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -392,7 +412,9 @@ mod tests {
|
||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
||||
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
||||
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
|
||||
assert_eq!(outbound[0].content, "### calculator\n- expression: 1 + 1");
|
||||
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
||||
assert_eq!(outbound[1].content, "### file_read\n- path: README.md");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -11,7 +11,7 @@ use crate::channels::ChannelManager;
|
||||
use crate::config::Config;
|
||||
use crate::logging;
|
||||
use crate::skills::SkillRuntime;
|
||||
use session::SessionManager;
|
||||
use session::{BusToolCallEmitter, SessionManager};
|
||||
|
||||
pub struct GatewayState {
|
||||
pub config: Config,
|
||||
@ -74,12 +74,19 @@ impl GatewayState {
|
||||
}
|
||||
|
||||
// Process via session manager
|
||||
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
||||
bus_for_inbound.clone(),
|
||||
inbound.channel.clone(),
|
||||
inbound.chat_id.clone(),
|
||||
inbound.forwarded_metadata.clone(),
|
||||
));
|
||||
match session_manager.handle_message(
|
||||
&inbound.channel,
|
||||
&inbound.sender_id,
|
||||
&inbound.chat_id,
|
||||
&inbound.content,
|
||||
inbound.media,
|
||||
Some(live_emitter),
|
||||
).await {
|
||||
Ok(outbound_messages) => {
|
||||
// Forward channel-specific metadata from inbound to outbound.
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
use crate::bus::{ChatMessage, OutboundMessage};
|
||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||
@ -30,6 +31,46 @@ pub struct Session {
|
||||
store: Arc<SessionStore>,
|
||||
}
|
||||
|
||||
pub struct BusToolCallEmitter {
|
||||
bus: Arc<MessageBus>,
|
||||
channel_name: String,
|
||||
chat_id: String,
|
||||
metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl BusToolCallEmitter {
|
||||
pub fn new(
|
||||
bus: Arc<MessageBus>,
|
||||
channel_name: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
bus,
|
||||
channel_name: channel_name.into(),
|
||||
chat_id: chat_id.into(),
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmittedMessageHandler for BusToolCallEmitter {
|
||||
async fn handle(&self, message: ChatMessage) {
|
||||
for outbound in OutboundMessage::from_chat_message(
|
||||
&self.channel_name,
|
||||
&self.chat_id,
|
||||
None,
|
||||
&self.metadata,
|
||||
&message,
|
||||
) {
|
||||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||||
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn new(
|
||||
channel_name: String,
|
||||
@ -437,6 +478,7 @@ impl SessionManager {
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
media: Vec<crate::bus::MediaItem>,
|
||||
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
|
||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@ -502,7 +544,10 @@ impl SessionManager {
|
||||
session_guard.record_skill_offer(chat_id)?;
|
||||
|
||||
// 创建 agent 并处理
|
||||
let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
|
||||
let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
|
||||
if let Some(handler) = live_emitter.clone() {
|
||||
agent = agent.with_emitted_message_handler(handler);
|
||||
}
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||
@ -511,6 +556,7 @@ impl SessionManager {
|
||||
result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
|
||||
.flat_map(|message| {
|
||||
OutboundMessage::from_chat_message(
|
||||
channel_name,
|
||||
|
||||
@ -1,13 +1,29 @@
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||
use axum::extract::State;
|
||||
use axum::response::Response;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use crate::agent::EmittedMessageHandler;
|
||||
use crate::bus::message::format_tool_call_content;
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
||||
|
||||
struct WsToolCallEmitter {
|
||||
sender: mpsc::Sender<WsOutbound>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmittedMessageHandler for WsToolCallEmitter {
|
||||
async fn handle(&self, message: ChatMessage) {
|
||||
for outbound in ws_outbound_from_chat_message(&message) {
|
||||
let _ = self.sender.send(outbound).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||
ws.on_upgrade(|socket| async {
|
||||
handle_socket(socket, state).await;
|
||||
@ -141,10 +157,6 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
|
||||
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
|
||||
}
|
||||
|
||||
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
match message.role.as_str() {
|
||||
"assistant" => {
|
||||
@ -156,11 +168,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_call.name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
content: format!(
|
||||
"调用工具: {}\n\n输入参数:\n{}",
|
||||
tool_call.name,
|
||||
format_tool_arguments(&tool_call.arguments),
|
||||
),
|
||||
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||
role: message.role.clone(),
|
||||
})
|
||||
.collect()
|
||||
@ -221,13 +229,19 @@ async fn handle_inbound(
|
||||
|
||||
session_guard.record_skill_offer(&chat_id)?;
|
||||
|
||||
let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
|
||||
let live_emitter = Arc::new(WsToolCallEmitter {
|
||||
sender: session_guard.user_tx.clone(),
|
||||
});
|
||||
let agent = session_guard
|
||||
.create_agent(&chat_id, None, Some(&user_message_id))?
|
||||
.with_emitted_message_handler(live_emitter);
|
||||
match agent.process(history).await {
|
||||
Ok(result) => {
|
||||
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||
for outbound in result
|
||||
.emitted_messages
|
||||
.iter()
|
||||
.filter(|message| !message.is_assistant_tool_call_message())
|
||||
.flat_map(ws_outbound_from_chat_message)
|
||||
{
|
||||
let _ = session_guard.send(outbound).await;
|
||||
@ -393,10 +407,11 @@ mod tests {
|
||||
|
||||
assert_eq!(outbound.len(), 1);
|
||||
match &outbound[0] {
|
||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
|
||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
|
||||
assert_eq!(tool_call_id, "call-1");
|
||||
assert_eq!(tool_name, "calculator");
|
||||
assert_eq!(arguments["expression"], "1 + 1");
|
||||
assert_eq!(content, "### calculator\n- expression: 1 + 1");
|
||||
}
|
||||
other => panic!("unexpected outbound variant: {:?}", other),
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user