diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index c790c5d..eac4bb9 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use crate::bus::message::ContentBlock; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; @@ -227,6 +228,7 @@ pub struct AgentLoop { skill_event_session_id: Option, tool_context: ToolContext, observer: Option>, + emitted_message_handler: Option>, max_iterations: usize, } @@ -236,6 +238,11 @@ pub struct AgentProcessResult { pub emitted_messages: Vec, } +#[async_trait] +pub trait EmittedMessageHandler: Send + Sync + 'static { + async fn handle(&self, message: ChatMessage); +} + impl AgentLoop { pub fn new(provider_config: LLMProviderConfig) -> Result { let max_iterations = provider_config.max_tool_iterations; @@ -250,6 +257,7 @@ impl AgentLoop { skill_event_session_id: None, tool_context: ToolContext::default(), observer: None, + emitted_message_handler: None, max_iterations, }) } @@ -267,6 +275,7 @@ impl AgentLoop { skill_event_session_id: None, tool_context: ToolContext::default(), observer: None, + emitted_message_handler: None, max_iterations, }) } @@ -288,6 +297,7 @@ impl AgentLoop { skill_event_session_id: None, tool_context: ToolContext::default(), observer: None, + emitted_message_handler: None, max_iterations, }) } @@ -309,6 +319,11 @@ impl AgentLoop { self } + pub fn with_emitted_message_handler(mut self, handler: Arc) -> Self { + self.emitted_message_handler = Some(handler); + self + } + pub fn tools(&self) -> &Arc { &self.tools } @@ -388,6 +403,7 @@ impl AgentLoop { ); messages.push(assistant_message.clone()); emitted_messages.push(assistant_message); + self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await; // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; @@ -487,6 +503,16 @@ impl AgentLoop { } } + async fn emit_live_tool_call_message(&self, message: ChatMessage) { + if !message.is_assistant_tool_call_message() { + return; + } + + if let Some(handler) = &self.emitted_message_handler { + handler.handle(message).await; + } + } + /// Determine whether to execute tools in parallel or sequentially. /// /// Returns true if: diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 4dd5762..cbd5daf 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,5 @@ pub mod agent_loop; pub mod context_compressor; -pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; +pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler}; pub use context_compressor::ContextCompressor; diff --git a/src/bus/message.rs b/src/bus/message.rs index 0d99009..bdb82f4 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -153,6 +153,15 @@ impl ChatMessage { tool_calls: None, } } + + pub fn is_assistant_tool_call_message(&self) -> bool { + self.role == "assistant" + && self + .tool_calls + .as_ref() + .map(|calls| !calls.is_empty()) + .unwrap_or(false) + } } // ============================================================================ @@ -325,22 +334,33 @@ impl OutboundMessage { } } -fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String { - format!( - "调用工具: {}\n\n输入参数:\n{}", - tool_name, - format_json_value(tool_arguments), - ) +pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String { + let mut lines = vec![format!("### {}", tool_name)]; + + match tool_arguments { + serde_json::Value::Object(map) if !map.is_empty() => { + let mut entries: Vec<_> = map.iter().collect(); + entries.sort_by(|(left, _), (right, _)| left.cmp(right)); + for (key, value) in entries { + lines.push(format!("- {}: {}", key, format_tool_argument_value(value))); + } + } + serde_json::Value::Object(_) => {} + other => lines.push(format!("- args: {}", format_tool_argument_value(other))), + } + + lines.join("\n") } fn format_tool_result_content(tool_name: &str, content: &str) -> String { format!("工具结果: {}\n\n{}", tool_name, content) } -fn format_json_value(value: &serde_json::Value) -> String { +fn format_tool_argument_value(value: &serde_json::Value) -> String { match value { - serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(), - other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()), + serde_json::Value::String(text) => text.clone(), + serde_json::Value::Null => "null".to_string(), + other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()), } } @@ -392,7 +412,9 @@ mod tests { assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall); assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator")); assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1"); + assert_eq!(outbound[0].content, "### calculator\n- expression: 1 + 1"); assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read")); + assert_eq!(outbound[1].content, "### file_read\n- path: README.md"); } #[test] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 5ab681d..4f0bba0 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -11,7 +11,7 @@ use crate::channels::ChannelManager; use crate::config::Config; use crate::logging; use crate::skills::SkillRuntime; -use session::SessionManager; +use session::{BusToolCallEmitter, SessionManager}; pub struct GatewayState { pub config: Config, @@ -74,12 +74,19 @@ impl GatewayState { } // Process via session manager + let live_emitter = Arc::new(BusToolCallEmitter::new( + bus_for_inbound.clone(), + inbound.channel.clone(), + inbound.chat_id.clone(), + inbound.forwarded_metadata.clone(), + )); match session_manager.handle_message( &inbound.channel, &inbound.sender_id, &inbound.chat_id, &inbound.content, inbound.media, + Some(live_emitter), ).await { Ok(outbound_messages) => { // Forward channel-specific metadata from inbound to outbound. diff --git a/src/gateway/session.rs b/src/gateway/session.rs index c4ad5b1..ee575a6 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1,11 +1,12 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; +use async_trait::async_trait; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; -use crate::bus::{ChatMessage, OutboundMessage}; +use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::config::LLMProviderConfig; -use crate::agent::{AgentLoop, AgentError, ContextCompressor}; +use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler}; use crate::protocol::WsOutbound; use crate::skills::SkillRuntime; use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; @@ -30,6 +31,46 @@ pub struct Session { store: Arc, } +pub struct BusToolCallEmitter { + bus: Arc, + channel_name: String, + chat_id: String, + metadata: HashMap, +} + +impl BusToolCallEmitter { + pub fn new( + bus: Arc, + channel_name: impl Into, + chat_id: impl Into, + metadata: HashMap, + ) -> Self { + Self { + bus, + channel_name: channel_name.into(), + chat_id: chat_id.into(), + metadata, + } + } +} + +#[async_trait] +impl EmittedMessageHandler for BusToolCallEmitter { + async fn handle(&self, message: ChatMessage) { + for outbound in OutboundMessage::from_chat_message( + &self.channel_name, + &self.chat_id, + None, + &self.metadata, + &message, + ) { + if let Err(error) = self.bus.publish_outbound(outbound).await { + tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call"); + } + } + } +} + impl Session { pub async fn new( channel_name: String, @@ -437,6 +478,7 @@ impl SessionManager { chat_id: &str, content: &str, media: Vec, + live_emitter: Option>, ) -> Result, AgentError> { #[cfg(debug_assertions)] { @@ -502,7 +544,10 @@ impl SessionManager { session_guard.record_skill_offer(chat_id)?; // 创建 agent 并处理 - let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?; + let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?; + if let Some(handler) = live_emitter.clone() { + agent = agent.with_emitted_message_handler(handler); + } let result = agent.process(history).await?; // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 @@ -511,6 +556,7 @@ impl SessionManager { result .emitted_messages .iter() + .filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none()) .flat_map(|message| { OutboundMessage::from_chat_message( channel_name, diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 1c17c1d..b77cd2c 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -1,13 +1,29 @@ use std::sync::Arc; +use async_trait::async_trait; use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use axum::extract::State; use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; +use crate::agent::EmittedMessageHandler; +use crate::bus::message::format_tool_call_content; use crate::bus::ChatMessage; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use super::{GatewayState, session::{Session, handle_in_chat_command}}; +struct WsToolCallEmitter { + sender: mpsc::Sender, +} + +#[async_trait] +impl EmittedMessageHandler for WsToolCallEmitter { + async fn handle(&self, message: ChatMessage) { + for outbound in ws_outbound_from_chat_message(&message) { + let _ = self.sender.send(outbound).await; + } + } +} + pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> Response { ws.on_upgrade(|socket| async { handle_socket(socket, state).await; @@ -141,10 +157,6 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary { } } -fn format_tool_arguments(arguments: &serde_json::Value) -> String { - serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string()) -} - fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { match message.role.as_str() { "assistant" => { @@ -156,11 +168,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { tool_call_id: tool_call.id.clone(), tool_name: tool_call.name.clone(), arguments: tool_call.arguments.clone(), - content: format!( - "调用工具: {}\n\n输入参数:\n{}", - tool_call.name, - format_tool_arguments(&tool_call.arguments), - ), + content: format_tool_call_content(&tool_call.name, &tool_call.arguments), role: message.role.clone(), }) .collect() @@ -221,13 +229,19 @@ async fn handle_inbound( session_guard.record_skill_offer(&chat_id)?; - let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?; + let live_emitter = Arc::new(WsToolCallEmitter { + sender: session_guard.user_tx.clone(), + }); + let agent = session_guard + .create_agent(&chat_id, None, Some(&user_message_id))? + .with_emitted_message_handler(live_emitter); match agent.process(history).await { Ok(result) => { session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?; for outbound in result .emitted_messages .iter() + .filter(|message| !message.is_assistant_tool_call_message()) .flat_map(ws_outbound_from_chat_message) { let _ = session_guard.send(outbound).await; @@ -393,10 +407,11 @@ mod tests { assert_eq!(outbound.len(), 1); match &outbound[0] { - WsOutbound::ToolCall { tool_call_id, tool_name, arguments, .. } => { + WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => { assert_eq!(tool_call_id, "call-1"); assert_eq!(tool_name, "calculator"); assert_eq!(arguments["expression"], "1 + 1"); + assert_eq!(content, "### calculator\n- expression: 1 + 1"); } other => panic!("unexpected outbound variant: {:?}", other), }