feat: 添加实时工具调用消息处理,优化消息格式化和传递逻辑

This commit is contained in:
ooodc 2026-04-22 09:01:56 +08:00
parent bc24a28275
commit 4725b5406e
6 changed files with 141 additions and 25 deletions

View File

@ -1,3 +1,4 @@
use async_trait::async_trait;
use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
@ -227,6 +228,7 @@ pub struct AgentLoop {
skill_event_session_id: Option<String>,
tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>,
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
max_iterations: usize,
}
@ -236,6 +238,11 @@ pub struct AgentProcessResult {
pub emitted_messages: Vec<ChatMessage>,
}
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage);
}
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
@ -250,6 +257,7 @@ impl AgentLoop {
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
@ -267,6 +275,7 @@ impl AgentLoop {
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
@ -288,6 +297,7 @@ impl AgentLoop {
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
@ -309,6 +319,11 @@ impl AgentLoop {
self
}
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
self.emitted_message_handler = Some(handler);
self
}
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
@ -388,6 +403,7 @@ impl AgentLoop {
);
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
@ -487,6 +503,16 @@ impl AgentLoop {
}
}
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
if !message.is_assistant_tool_call_message() {
return;
}
if let Some(handler) = &self.emitted_message_handler {
handler.handle(message).await;
}
}
/// Determine whether to execute tools in parallel or sequentially.
///
/// Returns true if:

View File

@ -1,5 +1,5 @@
pub mod agent_loop;
pub mod context_compressor;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
pub use context_compressor::ContextCompressor;

View File

@ -153,6 +153,15 @@ impl ChatMessage {
tool_calls: None,
}
}
pub fn is_assistant_tool_call_message(&self) -> bool {
self.role == "assistant"
&& self
.tool_calls
.as_ref()
.map(|calls| !calls.is_empty())
.unwrap_or(false)
}
}
// ============================================================================
@ -325,22 +334,33 @@ impl OutboundMessage {
}
}
fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_name,
format_json_value(tool_arguments),
)
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
let mut lines = vec![format!("### {}", tool_name)];
match tool_arguments {
serde_json::Value::Object(map) if !map.is_empty() => {
let mut entries: Vec<_> = map.iter().collect();
entries.sort_by(|(left, _), (right, _)| left.cmp(right));
for (key, value) in entries {
lines.push(format!("- {}: {}", key, format_tool_argument_value(value)));
}
}
serde_json::Value::Object(_) => {}
other => lines.push(format!("- args: {}", format_tool_argument_value(other))),
}
lines.join("\n")
}
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
format!("工具结果: {}\n\n{}", tool_name, content)
}
fn format_json_value(value: &serde_json::Value) -> String {
fn format_tool_argument_value(value: &serde_json::Value) -> String {
match value {
serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
serde_json::Value::String(text) => text.clone(),
serde_json::Value::Null => "null".to_string(),
other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
}
}
@ -392,7 +412,9 @@ mod tests {
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
assert_eq!(outbound[0].content, "### calculator\n- expression: 1 + 1");
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
assert_eq!(outbound[1].content, "### file_read\n- path: README.md");
}
#[test]

View File

@ -11,7 +11,7 @@ use crate::channels::ChannelManager;
use crate::config::Config;
use crate::logging;
use crate::skills::SkillRuntime;
use session::SessionManager;
use session::{BusToolCallEmitter, SessionManager};
pub struct GatewayState {
pub config: Config,
@ -74,12 +74,19 @@ impl GatewayState {
}
// Process via session manager
let live_emitter = Arc::new(BusToolCallEmitter::new(
bus_for_inbound.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
));
match session_manager.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
).await {
Ok(outbound_messages) => {
// Forward channel-specific metadata from inbound to outbound.

View File

@ -1,11 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::{ChatMessage, OutboundMessage};
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
@ -30,6 +31,46 @@ pub struct Session {
store: Arc<SessionStore>,
}
pub struct BusToolCallEmitter {
bus: Arc<MessageBus>,
channel_name: String,
chat_id: String,
metadata: HashMap<String, String>,
}
impl BusToolCallEmitter {
pub fn new(
bus: Arc<MessageBus>,
channel_name: impl Into<String>,
chat_id: impl Into<String>,
metadata: HashMap<String, String>,
) -> Self {
Self {
bus,
channel_name: channel_name.into(),
chat_id: chat_id.into(),
metadata,
}
}
}
#[async_trait]
impl EmittedMessageHandler for BusToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in OutboundMessage::from_chat_message(
&self.channel_name,
&self.chat_id,
None,
&self.metadata,
&message,
) {
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
}
}
}
}
impl Session {
pub async fn new(
channel_name: String,
@ -437,6 +478,7 @@ impl SessionManager {
chat_id: &str,
content: &str,
media: Vec<crate::bus::MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
@ -502,7 +544,10 @@ impl SessionManager {
session_guard.record_skill_offer(chat_id)?;
// 创建 agent 并处理
let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
if let Some(handler) = live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
let result = agent.process(history).await?;
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
@ -511,6 +556,7 @@ impl SessionManager {
result
.emitted_messages
.iter()
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
.flat_map(|message| {
OutboundMessage::from_chat_message(
channel_name,

View File

@ -1,13 +1,29 @@
use std::sync::Arc;
use async_trait::async_trait;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::agent::EmittedMessageHandler;
use crate::bus::message::format_tool_call_content;
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command}};
struct WsToolCallEmitter {
sender: mpsc::Sender<WsOutbound>,
}
#[async_trait]
impl EmittedMessageHandler for WsToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in ws_outbound_from_chat_message(&message) {
let _ = self.sender.send(outbound).await;
}
}
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async {
handle_socket(socket, state).await;
@ -141,10 +157,6 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
}
}
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
}
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
@ -156,11 +168,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_call.name,
format_tool_arguments(&tool_call.arguments),
),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
})
.collect()
@ -221,13 +229,19 @@ async fn handle_inbound(
session_guard.record_skill_offer(&chat_id)?;
let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
});
let agent = session_guard
.create_agent(&chat_id, None, Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
match agent.process(history).await {
Ok(result) => {
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result
.emitted_messages
.iter()
.filter(|message| !message.is_assistant_tool_call_message())
.flat_map(ws_outbound_from_chat_message)
{
let _ = session_guard.send(outbound).await;
@ -393,10 +407,11 @@ mod tests {
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "### calculator\n- expression: 1 + 1");
}
other => panic!("unexpected outbound variant: {:?}", other),
}