feat: 添加实时工具调用消息处理,优化消息格式化和传递逻辑
This commit is contained in:
parent
bc24a28275
commit
4725b5406e
@ -1,3 +1,4 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
use crate::bus::message::ContentBlock;
|
use crate::bus::message::ContentBlock;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
@ -227,6 +228,7 @@ pub struct AgentLoop {
|
|||||||
skill_event_session_id: Option<String>,
|
skill_event_session_id: Option<String>,
|
||||||
tool_context: ToolContext,
|
tool_context: ToolContext,
|
||||||
observer: Option<Arc<dyn Observer>>,
|
observer: Option<Arc<dyn Observer>>,
|
||||||
|
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
|
||||||
max_iterations: usize,
|
max_iterations: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,6 +238,11 @@ pub struct AgentProcessResult {
|
|||||||
pub emitted_messages: Vec<ChatMessage>,
|
pub emitted_messages: Vec<ChatMessage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait EmittedMessageHandler: Send + Sync + 'static {
|
||||||
|
async fn handle(&self, message: ChatMessage);
|
||||||
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
@ -250,6 +257,7 @@ impl AgentLoop {
|
|||||||
skill_event_session_id: None,
|
skill_event_session_id: None,
|
||||||
tool_context: ToolContext::default(),
|
tool_context: ToolContext::default(),
|
||||||
observer: None,
|
observer: None,
|
||||||
|
emitted_message_handler: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -267,6 +275,7 @@ impl AgentLoop {
|
|||||||
skill_event_session_id: None,
|
skill_event_session_id: None,
|
||||||
tool_context: ToolContext::default(),
|
tool_context: ToolContext::default(),
|
||||||
observer: None,
|
observer: None,
|
||||||
|
emitted_message_handler: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -288,6 +297,7 @@ impl AgentLoop {
|
|||||||
skill_event_session_id: None,
|
skill_event_session_id: None,
|
||||||
tool_context: ToolContext::default(),
|
tool_context: ToolContext::default(),
|
||||||
observer: None,
|
observer: None,
|
||||||
|
emitted_message_handler: None,
|
||||||
max_iterations,
|
max_iterations,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -309,6 +319,11 @@ impl AgentLoop {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
|
||||||
|
self.emitted_message_handler = Some(handler);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||||||
&self.tools
|
&self.tools
|
||||||
}
|
}
|
||||||
@ -388,6 +403,7 @@ impl AgentLoop {
|
|||||||
);
|
);
|
||||||
messages.push(assistant_message.clone());
|
messages.push(assistant_message.clone());
|
||||||
emitted_messages.push(assistant_message);
|
emitted_messages.push(assistant_message);
|
||||||
|
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
|
||||||
|
|
||||||
// Execute tools and add results to messages
|
// Execute tools and add results to messages
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
@ -487,6 +503,16 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
|
||||||
|
if !message.is_assistant_tool_call_message() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(handler) = &self.emitted_message_handler {
|
||||||
|
handler.handle(message).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Determine whether to execute tools in parallel or sequentially.
|
/// Determine whether to execute tools in parallel or sequentially.
|
||||||
///
|
///
|
||||||
/// Returns true if:
|
/// Returns true if:
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
pub mod agent_loop;
|
pub mod agent_loop;
|
||||||
pub mod context_compressor;
|
pub mod context_compressor;
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
|
||||||
pub use context_compressor::ContextCompressor;
|
pub use context_compressor::ContextCompressor;
|
||||||
|
|||||||
@ -153,6 +153,15 @@ impl ChatMessage {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_assistant_tool_call_message(&self) -> bool {
|
||||||
|
self.role == "assistant"
|
||||||
|
&& self
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map(|calls| !calls.is_empty())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@ -325,22 +334,33 @@ impl OutboundMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
|
||||||
format!(
|
let mut lines = vec![format!("### {}", tool_name)];
|
||||||
"调用工具: {}\n\n输入参数:\n{}",
|
|
||||||
tool_name,
|
match tool_arguments {
|
||||||
format_json_value(tool_arguments),
|
serde_json::Value::Object(map) if !map.is_empty() => {
|
||||||
)
|
let mut entries: Vec<_> = map.iter().collect();
|
||||||
|
entries.sort_by(|(left, _), (right, _)| left.cmp(right));
|
||||||
|
for (key, value) in entries {
|
||||||
|
lines.push(format!("- {}: {}", key, format_tool_argument_value(value)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
serde_json::Value::Object(_) => {}
|
||||||
|
other => lines.push(format!("- args: {}", format_tool_argument_value(other))),
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.join("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
|
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
|
||||||
format!("工具结果: {}\n\n{}", tool_name, content)
|
format!("工具结果: {}\n\n{}", tool_name, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_json_value(value: &serde_json::Value) -> String {
|
fn format_tool_argument_value(value: &serde_json::Value) -> String {
|
||||||
match value {
|
match value {
|
||||||
serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
|
serde_json::Value::String(text) => text.clone(),
|
||||||
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
|
serde_json::Value::Null => "null".to_string(),
|
||||||
|
other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,7 +412,9 @@ mod tests {
|
|||||||
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
|
||||||
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
|
||||||
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
|
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
|
||||||
|
assert_eq!(outbound[0].content, "### calculator\n- expression: 1 + 1");
|
||||||
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
|
||||||
|
assert_eq!(outbound[1].content, "### file_read\n- path: README.md");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@ -11,7 +11,7 @@ use crate::channels::ChannelManager;
|
|||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::logging;
|
use crate::logging;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use session::SessionManager;
|
use session::{BusToolCallEmitter, SessionManager};
|
||||||
|
|
||||||
pub struct GatewayState {
|
pub struct GatewayState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
@ -74,12 +74,19 @@ impl GatewayState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process via session manager
|
// Process via session manager
|
||||||
|
let live_emitter = Arc::new(BusToolCallEmitter::new(
|
||||||
|
bus_for_inbound.clone(),
|
||||||
|
inbound.channel.clone(),
|
||||||
|
inbound.chat_id.clone(),
|
||||||
|
inbound.forwarded_metadata.clone(),
|
||||||
|
));
|
||||||
match session_manager.handle_message(
|
match session_manager.handle_message(
|
||||||
&inbound.channel,
|
&inbound.channel,
|
||||||
&inbound.sender_id,
|
&inbound.sender_id,
|
||||||
&inbound.chat_id,
|
&inbound.chat_id,
|
||||||
&inbound.content,
|
&inbound.content,
|
||||||
inbound.media,
|
inbound.media,
|
||||||
|
Some(live_emitter),
|
||||||
).await {
|
).await {
|
||||||
Ok(outbound_messages) => {
|
Ok(outbound_messages) => {
|
||||||
// Forward channel-specific metadata from inbound to outbound.
|
// Forward channel-specific metadata from inbound to outbound.
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use async_trait::async_trait;
|
||||||
use tokio::sync::{Mutex, mpsc};
|
use tokio::sync::{Mutex, mpsc};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use crate::bus::{ChatMessage, OutboundMessage};
|
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
|
||||||
@ -30,6 +31,46 @@ pub struct Session {
|
|||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct BusToolCallEmitter {
|
||||||
|
bus: Arc<MessageBus>,
|
||||||
|
channel_name: String,
|
||||||
|
chat_id: String,
|
||||||
|
metadata: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BusToolCallEmitter {
|
||||||
|
pub fn new(
|
||||||
|
bus: Arc<MessageBus>,
|
||||||
|
channel_name: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
metadata: HashMap<String, String>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
bus,
|
||||||
|
channel_name: channel_name.into(),
|
||||||
|
chat_id: chat_id.into(),
|
||||||
|
metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmittedMessageHandler for BusToolCallEmitter {
|
||||||
|
async fn handle(&self, message: ChatMessage) {
|
||||||
|
for outbound in OutboundMessage::from_chat_message(
|
||||||
|
&self.channel_name,
|
||||||
|
&self.chat_id,
|
||||||
|
None,
|
||||||
|
&self.metadata,
|
||||||
|
&message,
|
||||||
|
) {
|
||||||
|
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||||||
|
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
pub async fn new(
|
pub async fn new(
|
||||||
channel_name: String,
|
channel_name: String,
|
||||||
@ -437,6 +478,7 @@ impl SessionManager {
|
|||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
media: Vec<crate::bus::MediaItem>,
|
media: Vec<crate::bus::MediaItem>,
|
||||||
|
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
|
||||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
@ -502,7 +544,10 @@ impl SessionManager {
|
|||||||
session_guard.record_skill_offer(chat_id)?;
|
session_guard.record_skill_offer(chat_id)?;
|
||||||
|
|
||||||
// 创建 agent 并处理
|
// 创建 agent 并处理
|
||||||
let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
|
let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
|
||||||
|
if let Some(handler) = live_emitter.clone() {
|
||||||
|
agent = agent.with_emitted_message_handler(handler);
|
||||||
|
}
|
||||||
let result = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||||
@ -511,6 +556,7 @@ impl SessionManager {
|
|||||||
result
|
result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
|
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
|
||||||
.flat_map(|message| {
|
.flat_map(|message| {
|
||||||
OutboundMessage::from_chat_message(
|
OutboundMessage::from_chat_message(
|
||||||
channel_name,
|
channel_name,
|
||||||
|
|||||||
@ -1,13 +1,29 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use async_trait::async_trait;
|
||||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
|
use crate::agent::EmittedMessageHandler;
|
||||||
|
use crate::bus::message::format_tool_call_content;
|
||||||
use crate::bus::ChatMessage;
|
use crate::bus::ChatMessage;
|
||||||
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
|
||||||
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
use super::{GatewayState, session::{Session, handle_in_chat_command}};
|
||||||
|
|
||||||
|
struct WsToolCallEmitter {
|
||||||
|
sender: mpsc::Sender<WsOutbound>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmittedMessageHandler for WsToolCallEmitter {
|
||||||
|
async fn handle(&self, message: ChatMessage) {
|
||||||
|
for outbound in ws_outbound_from_chat_message(&message) {
|
||||||
|
let _ = self.sender.send(outbound).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
|
||||||
ws.on_upgrade(|socket| async {
|
ws.on_upgrade(|socket| async {
|
||||||
handle_socket(socket, state).await;
|
handle_socket(socket, state).await;
|
||||||
@ -141,10 +157,6 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
|
|
||||||
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
||||||
match message.role.as_str() {
|
match message.role.as_str() {
|
||||||
"assistant" => {
|
"assistant" => {
|
||||||
@ -156,11 +168,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
|
|||||||
tool_call_id: tool_call.id.clone(),
|
tool_call_id: tool_call.id.clone(),
|
||||||
tool_name: tool_call.name.clone(),
|
tool_name: tool_call.name.clone(),
|
||||||
arguments: tool_call.arguments.clone(),
|
arguments: tool_call.arguments.clone(),
|
||||||
content: format!(
|
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
|
||||||
"调用工具: {}\n\n输入参数:\n{}",
|
|
||||||
tool_call.name,
|
|
||||||
format_tool_arguments(&tool_call.arguments),
|
|
||||||
),
|
|
||||||
role: message.role.clone(),
|
role: message.role.clone(),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
@ -221,13 +229,19 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
session_guard.record_skill_offer(&chat_id)?;
|
session_guard.record_skill_offer(&chat_id)?;
|
||||||
|
|
||||||
let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
|
let live_emitter = Arc::new(WsToolCallEmitter {
|
||||||
|
sender: session_guard.user_tx.clone(),
|
||||||
|
});
|
||||||
|
let agent = session_guard
|
||||||
|
.create_agent(&chat_id, None, Some(&user_message_id))?
|
||||||
|
.with_emitted_message_handler(live_emitter);
|
||||||
match agent.process(history).await {
|
match agent.process(history).await {
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||||
for outbound in result
|
for outbound in result
|
||||||
.emitted_messages
|
.emitted_messages
|
||||||
.iter()
|
.iter()
|
||||||
|
.filter(|message| !message.is_assistant_tool_call_message())
|
||||||
.flat_map(ws_outbound_from_chat_message)
|
.flat_map(ws_outbound_from_chat_message)
|
||||||
{
|
{
|
||||||
let _ = session_guard.send(outbound).await;
|
let _ = session_guard.send(outbound).await;
|
||||||
@ -393,10 +407,11 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(outbound.len(), 1);
|
assert_eq!(outbound.len(), 1);
|
||||||
match &outbound[0] {
|
match &outbound[0] {
|
||||||
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => {
|
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
|
||||||
assert_eq!(tool_call_id, "call-1");
|
assert_eq!(tool_call_id, "call-1");
|
||||||
assert_eq!(tool_name, "calculator");
|
assert_eq!(tool_name, "calculator");
|
||||||
assert_eq!(arguments["expression"], "1 + 1");
|
assert_eq!(arguments["expression"], "1 + 1");
|
||||||
|
assert_eq!(content, "### calculator\n- expression: 1 + 1");
|
||||||
}
|
}
|
||||||
other => panic!("unexpected outbound variant: {:?}", other),
|
other => panic!("unexpected outbound variant: {:?}", other),
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user