feat(streaming): 支持流式文本增量与结束信号功能
- 新增 StreamDelta 和 StreamEnd 类型,支持流式数据增量传输 - 扩展 LLMProvider trait,添加带回调的 chat_with_streaming 接口 - 修改 OpenAI Provider 实现,支持流式聊天回调传输增量数据 - Agent 流处理改为异步消费增量消息并传递给前端 - 保证流式增量和最终消息使用相同消息 ID 以便前端替换 - 修改消息总线和协议层,支持携带和识别流式消息的消息 ID - 客户端 CLI 通过增量输出实现交互式流式响应显示 - Web 前端接收流式增量消息,追加到对应消息,实现实时显示 - 各通道(飞书、微信)支持转发流式增量和结束消息 - 任务工具运行时添加消息 ID 支持,保持消息一致性 - 统一消息构造函数新增流式增量和结束信号的构建方法
This commit is contained in:
parent
def6df50da
commit
fc7df67474
@ -7,7 +7,7 @@ use crate::domain::messages::{ContentBlock, ToolCall};
|
|||||||
use crate::observability::{
|
use crate::observability::{
|
||||||
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
||||||
};
|
};
|
||||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
|
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, StreamDelta, StreamCallback, create_provider};
|
||||||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||||||
use crate::tools::{ToolContext, ToolRegistry};
|
use crate::tools::{ToolContext, ToolRegistry};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@ -663,6 +663,17 @@ pub trait EmittedMessageHandler: Send + Sync + 'static {
|
|||||||
async fn handle_tool_result(&self, message: ChatMessage, _duration_ms: Option<u64>) {
|
async fn handle_tool_result(&self, message: ChatMessage, _duration_ms: Option<u64>) {
|
||||||
self.handle(message).await;
|
self.handle(message).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle a streaming delta. Default is no-op.
|
||||||
|
async fn handle_stream_delta(&self, _delta: &StreamDelta) {
|
||||||
|
// Non-streaming channels ignore this
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the message ID to use for stream deltas (so the final assistant message
|
||||||
|
/// can share the same ID, enabling front-end replacement).
|
||||||
|
async fn set_stream_message_id(&self, _id: &str) {
|
||||||
|
// Default: no-op for handlers that don't stream
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 装饰器:在内部 emitter 广播前,先将消息持久化到 DB
|
/// 装饰器:在内部 emitter 广播前,先将消息持久化到 DB
|
||||||
@ -706,6 +717,15 @@ impl<H: EmittedMessageHandler> EmittedMessageHandler for PersistingEmittedMessag
|
|||||||
}
|
}
|
||||||
self.inner.handle_tool_result(message, duration_ms).await;
|
self.inner.handle_tool_result(message, duration_ms).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn set_stream_message_id(&self, id: &str) {
|
||||||
|
self.inner.set_stream_message_id(id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_stream_delta(&self, delta: &StreamDelta) {
|
||||||
|
// Deltas are transient — do NOT persist, just forward to inner handler
|
||||||
|
self.inner.handle_stream_delta(delta).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait SkillProvider: Send + Sync + 'static {
|
pub trait SkillProvider: Send + Sync + 'static {
|
||||||
@ -958,9 +978,34 @@ impl AgentLoop {
|
|||||||
tools,
|
tools,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = match (*self.provider).chat(request).await {
|
// Set up streaming delta consumer
|
||||||
|
// Pre-generate the message ID so stream deltas and the final assistant
|
||||||
|
// message share the same ID — this lets the front-end replace the
|
||||||
|
// streamed message with the authoritative response.
|
||||||
|
let streaming_message_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
if let Some(handler) = &self.emitted_message_handler {
|
||||||
|
handler.set_stream_message_id(&streaming_message_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (delta_tx, mut delta_rx) = tokio::sync::mpsc::channel::<StreamDelta>(256);
|
||||||
|
let consumer_handler = self.emitted_message_handler.clone();
|
||||||
|
let consumer_task = tokio::spawn(async move {
|
||||||
|
while let Some(delta) = delta_rx.recv().await {
|
||||||
|
if let Some(ref handler) = consumer_handler {
|
||||||
|
handler.handle_stream_delta(&delta).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let stream_callback: StreamCallback = std::sync::Arc::new(move |delta: StreamDelta| {
|
||||||
|
// try_send is non-blocking and safe to call from within a tokio runtime
|
||||||
|
let _ = delta_tx.try_send(delta);
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = match (*self.provider).chat_with_streaming(request, stream_callback).await {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
// delta_tx is dropped with the callback; await consumer to finish
|
||||||
|
let _ = consumer_task.await;
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
provider = %self.provider.name(),
|
provider = %self.provider.name(),
|
||||||
model = %self.provider.model_id(),
|
model = %self.provider.model_id(),
|
||||||
@ -979,6 +1024,22 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Close delta channel and wait for consumer to finish processing
|
||||||
|
// (delta_tx is dropped when the callback closure is dropped)
|
||||||
|
let _ = consumer_task.await;
|
||||||
|
|
||||||
|
// Signal stream end if handler exists
|
||||||
|
let had_streaming = self.emitted_message_handler.is_some();
|
||||||
|
if had_streaming {
|
||||||
|
let end_delta = StreamDelta {
|
||||||
|
content: String::new(),
|
||||||
|
reasoning_content: None,
|
||||||
|
};
|
||||||
|
if let Some(handler) = &self.emitted_message_handler {
|
||||||
|
handler.handle_stream_delta(&end_delta).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
iteration,
|
iteration,
|
||||||
@ -989,12 +1050,17 @@ impl AgentLoop {
|
|||||||
|
|
||||||
// If no tool calls, this is the final response
|
// If no tool calls, this is the final response
|
||||||
if response.tool_calls.is_empty() {
|
if response.tool_calls.is_empty() {
|
||||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
let mut assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||||
{
|
{
|
||||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||||
} else {
|
} else {
|
||||||
ChatMessage::assistant(response.content)
|
ChatMessage::assistant(response.content)
|
||||||
};
|
};
|
||||||
|
// Use the same ID as the stream deltas so the front-end can replace
|
||||||
|
// the streamed message with this authoritative response.
|
||||||
|
if had_streaming {
|
||||||
|
assistant_message.id = streaming_message_id;
|
||||||
|
}
|
||||||
emitted_messages.push(assistant_message.clone());
|
emitted_messages.push(assistant_message.clone());
|
||||||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||||||
return Ok(AgentProcessResult {
|
return Ok(AgentProcessResult {
|
||||||
@ -1011,7 +1077,7 @@ impl AgentLoop {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Add assistant message with tool calls
|
// Add assistant message with tool calls
|
||||||
let assistant_message =
|
let mut assistant_message =
|
||||||
if let Some(reasoning_content) = response.reasoning_content.clone() {
|
if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||||||
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
||||||
response.content.clone(),
|
response.content.clone(),
|
||||||
@ -1024,6 +1090,11 @@ impl AgentLoop {
|
|||||||
response.tool_calls.clone(),
|
response.tool_calls.clone(),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
// Use the same ID as stream deltas so the front-end can replace
|
||||||
|
// the streamed message with this authoritative response.
|
||||||
|
if had_streaming {
|
||||||
|
assistant_message.id = streaming_message_id;
|
||||||
|
}
|
||||||
messages.push(assistant_message.clone());
|
messages.push(assistant_message.clone());
|
||||||
emitted_messages.push(assistant_message);
|
emitted_messages.push(assistant_message);
|
||||||
self.emit_live_tool_call_message(
|
self.emit_live_tool_call_message(
|
||||||
|
|||||||
@ -401,6 +401,10 @@ pub struct OutboundMessage {
|
|||||||
pub tool_name: Option<String>,
|
pub tool_name: Option<String>,
|
||||||
pub tool_arguments: Option<serde_json::Value>,
|
pub tool_arguments: Option<serde_json::Value>,
|
||||||
pub reasoning_content: Option<String>,
|
pub reasoning_content: Option<String>,
|
||||||
|
/// Carry the originating ChatMessage.id so the WS layer can use it
|
||||||
|
/// instead of generating a random UUID. Critical for stream delta → assistant_response
|
||||||
|
/// ID matching on the front-end.
|
||||||
|
pub message_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
@ -412,11 +416,15 @@ pub enum OutboundEventKind {
|
|||||||
SchedulerNotification,
|
SchedulerNotification,
|
||||||
ErrorNotification,
|
ErrorNotification,
|
||||||
TaskStarted,
|
TaskStarted,
|
||||||
|
/// 流式文本增量
|
||||||
|
StreamDelta,
|
||||||
|
/// 流式结束信号
|
||||||
|
StreamEnd,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OutboundMessage {
|
impl OutboundMessage {
|
||||||
pub fn is_stream_delta(&self) -> bool {
|
pub fn is_stream_delta(&self) -> bool {
|
||||||
self.metadata.get("_stream_delta").is_some()
|
matches!(self.event_kind, OutboundEventKind::StreamDelta | OutboundEventKind::StreamEnd)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn assistant(
|
pub fn assistant(
|
||||||
@ -441,6 +449,7 @@ impl OutboundMessage {
|
|||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_arguments: None,
|
tool_arguments: None,
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
|
message_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -496,6 +505,7 @@ impl OutboundMessage {
|
|||||||
tool_name: Some(tool_name),
|
tool_name: Some(tool_name),
|
||||||
tool_arguments: Some(tool_arguments),
|
tool_arguments: Some(tool_arguments),
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
|
message_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -526,6 +536,7 @@ impl OutboundMessage {
|
|||||||
tool_name: Some(tool_name),
|
tool_name: Some(tool_name),
|
||||||
tool_arguments: None,
|
tool_arguments: None,
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
|
message_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -556,6 +567,61 @@ impl OutboundMessage {
|
|||||||
tool_name: Some(tool_name),
|
tool_name: Some(tool_name),
|
||||||
tool_arguments: None,
|
tool_arguments: None,
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
|
message_id: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构造流式文本增量消息
|
||||||
|
pub fn stream_delta(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
session_id: Option<String>,
|
||||||
|
message_id: impl Into<String>,
|
||||||
|
delta: impl Into<String>,
|
||||||
|
reasoning_delta: Option<String>,
|
||||||
|
metadata: HashMap<String, String>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
channel: channel.into(),
|
||||||
|
chat_id: chat_id.into(),
|
||||||
|
session_id,
|
||||||
|
content: delta.into(),
|
||||||
|
reply_to: None,
|
||||||
|
media: Vec::new(),
|
||||||
|
metadata,
|
||||||
|
event_kind: OutboundEventKind::StreamDelta,
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_call_id: Some(message_id.into()),
|
||||||
|
tool_name: None,
|
||||||
|
tool_arguments: None,
|
||||||
|
reasoning_content: reasoning_delta,
|
||||||
|
message_id: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构造流式结束信号
|
||||||
|
pub fn stream_end(
|
||||||
|
channel: impl Into<String>,
|
||||||
|
chat_id: impl Into<String>,
|
||||||
|
session_id: Option<String>,
|
||||||
|
message_id: impl Into<String>,
|
||||||
|
metadata: HashMap<String, String>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
channel: channel.into(),
|
||||||
|
chat_id: chat_id.into(),
|
||||||
|
session_id,
|
||||||
|
content: String::new(),
|
||||||
|
reply_to: None,
|
||||||
|
media: Vec::new(),
|
||||||
|
metadata,
|
||||||
|
event_kind: OutboundEventKind::StreamEnd,
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
tool_call_id: Some(message_id.into()),
|
||||||
|
tool_name: None,
|
||||||
|
tool_arguments: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
message_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -582,6 +648,7 @@ impl OutboundMessage {
|
|||||||
metadata.clone(),
|
metadata.clone(),
|
||||||
);
|
);
|
||||||
resp.reasoning_content = message.reasoning_content.clone();
|
resp.reasoning_content = message.reasoning_content.clone();
|
||||||
|
resp.message_id = Some(message.id.clone());
|
||||||
outbound.push(resp);
|
outbound.push(resp);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -613,6 +680,7 @@ impl OutboundMessage {
|
|||||||
metadata.clone(),
|
metadata.clone(),
|
||||||
);
|
);
|
||||||
resp.reasoning_content = message.reasoning_content.clone();
|
resp.reasoning_content = message.reasoning_content.clone();
|
||||||
|
resp.message_id = Some(message.id.clone());
|
||||||
vec![resp]
|
vec![resp]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2406,7 +2406,7 @@ impl Channel for FeishuChannel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||||
if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending)
|
if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending | OutboundEventKind::StreamDelta | OutboundEventKind::StreamEnd)
|
||||||
|| msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
|| msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||||
{
|
{
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|||||||
@ -313,6 +313,8 @@ impl Channel for WechatChannel {
|
|||||||
OutboundEventKind::ToolResult
|
OutboundEventKind::ToolResult
|
||||||
| OutboundEventKind::ToolPending
|
| OutboundEventKind::ToolPending
|
||||||
| OutboundEventKind::ToolCall
|
| OutboundEventKind::ToolCall
|
||||||
|
| OutboundEventKind::StreamDelta
|
||||||
|
| OutboundEventKind::StreamEnd
|
||||||
) || msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
) || msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||||
{
|
{
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|||||||
@ -24,6 +24,9 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let mut input = InputHandler::new();
|
let mut input = InputHandler::new();
|
||||||
let mut current_session_id: Option<String> = None;
|
let mut current_session_id: Option<String> = None;
|
||||||
|
// Track message IDs that were already streamed so we can skip
|
||||||
|
// the duplicate AssistantResponse that arrives afterwards.
|
||||||
|
let mut streamed_message_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||||
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?;
|
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?;
|
||||||
|
|
||||||
// Main loop: poll both stdin and WebSocket
|
// Main loop: poll both stdin and WebSocket
|
||||||
@ -36,9 +39,12 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let text = text.to_string();
|
let text = text.to_string();
|
||||||
if let Ok(outbound) = parse_message(&text) {
|
if let Ok(outbound) = parse_message(&text) {
|
||||||
match outbound {
|
match outbound {
|
||||||
WsOutbound::AssistantResponse { content, .. } => {
|
WsOutbound::AssistantResponse { id, content, .. } => {
|
||||||
|
// Skip if already fully streamed via StreamDelta
|
||||||
|
if !streamed_message_ids.remove(&id) {
|
||||||
input.write_response(&content).await?;
|
input.write_response(&content).await?;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
WsOutbound::ToolCall { tool_name, arguments, .. } => {
|
WsOutbound::ToolCall { tool_name, arguments, .. } => {
|
||||||
input.write_output(&format!("Tool call: {}\n{}\n", tool_name, format_json(&arguments))).await?;
|
input.write_output(&format!("Tool call: {}\n{}\n", tool_name, format_json(&arguments))).await?;
|
||||||
}
|
}
|
||||||
@ -64,6 +70,18 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
WsOutbound::SessionSaved { session_id, filepath } => {
|
WsOutbound::SessionSaved { session_id, filepath } => {
|
||||||
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
||||||
}
|
}
|
||||||
|
WsOutbound::StreamDelta { id, delta, .. } => {
|
||||||
|
// Track that this message is being streamed
|
||||||
|
streamed_message_ids.insert(id.clone());
|
||||||
|
// 在终端直接输出流式增量文本
|
||||||
|
if !delta.is_empty() {
|
||||||
|
input.write_output(&delta).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
WsOutbound::StreamEnd { .. } => {
|
||||||
|
// 流式结束,输出换行
|
||||||
|
input.write_output("\n").await?;
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,7 @@ use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandl
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
||||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||||
|
use crate::providers::StreamDelta;
|
||||||
use crate::config::LLMProviderConfig;
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::protocol::WsOutbound;
|
use crate::protocol::WsOutbound;
|
||||||
use crate::scheduler::ScheduledAgentTaskOptions;
|
use crate::scheduler::ScheduledAgentTaskOptions;
|
||||||
@ -56,6 +57,7 @@ pub struct BusToolCallEmitter {
|
|||||||
chat_id: String,
|
chat_id: String,
|
||||||
metadata: HashMap<String, String>,
|
metadata: HashMap<String, String>,
|
||||||
store: Arc<SessionStore>,
|
store: Arc<SessionStore>,
|
||||||
|
stream_message_id: std::sync::Mutex<Option<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BusToolCallEmitter {
|
impl BusToolCallEmitter {
|
||||||
@ -72,6 +74,7 @@ impl BusToolCallEmitter {
|
|||||||
chat_id: chat_id.into(),
|
chat_id: chat_id.into(),
|
||||||
metadata,
|
metadata,
|
||||||
store,
|
store,
|
||||||
|
stream_message_id: std::sync::Mutex::new(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,6 +119,43 @@ impl EmittedMessageHandler for BusToolCallEmitter {
|
|||||||
self.persist_todo_write_result(&message);
|
self.persist_todo_write_result(&message);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_stream_delta(&self, delta: &StreamDelta) {
|
||||||
|
// Get or create the stream message ID
|
||||||
|
let message_id = {
|
||||||
|
let mut guard = self.stream_message_id.lock().unwrap();
|
||||||
|
guard.get_or_insert_with(|| Uuid::new_v4().to_string()).clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Empty content + no reasoning = stream end signal
|
||||||
|
let outbound = if delta.content.is_empty() && delta.reasoning_content.is_none() {
|
||||||
|
OutboundMessage::stream_end(
|
||||||
|
&self.channel_name,
|
||||||
|
&self.chat_id,
|
||||||
|
None,
|
||||||
|
&message_id,
|
||||||
|
self.metadata.clone(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
OutboundMessage::stream_delta(
|
||||||
|
&self.channel_name,
|
||||||
|
&self.chat_id,
|
||||||
|
None,
|
||||||
|
&message_id,
|
||||||
|
&delta.content,
|
||||||
|
delta.reasoning_content.clone(),
|
||||||
|
self.metadata.clone(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||||||
|
tracing::error!(error = %error, channel = %self.channel_name, "Failed to publish stream delta");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_stream_message_id(&self, id: &str) {
|
||||||
|
*self.stream_message_id.lock().unwrap() = Some(id.to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BusToolCallEmitter {
|
impl BusToolCallEmitter {
|
||||||
|
|||||||
@ -265,6 +265,25 @@ pub enum WsOutbound {
|
|||||||
},
|
},
|
||||||
#[serde(rename = "execution_cancelled")]
|
#[serde(rename = "execution_cancelled")]
|
||||||
ExecutionCancelled { message: String },
|
ExecutionCancelled { message: String },
|
||||||
|
#[serde(rename = "stream_delta")]
|
||||||
|
StreamDelta {
|
||||||
|
id: String,
|
||||||
|
delta: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_delta: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
subagent_task_id: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
topic_id: Option<String>,
|
||||||
|
},
|
||||||
|
#[serde(rename = "stream_end")]
|
||||||
|
StreamEnd {
|
||||||
|
id: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
subagent_task_id: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
topic_id: Option<String>,
|
||||||
|
},
|
||||||
#[serde(rename = "todo_list")]
|
#[serde(rename = "todo_list")]
|
||||||
TodoList {
|
TodoList {
|
||||||
todos: Vec<TodoItemSummary>,
|
todos: Vec<TodoItemSummary>,
|
||||||
|
|||||||
@ -104,7 +104,7 @@ pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Ve
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
vec![WsOutbound::AssistantResponse {
|
vec![WsOutbound::AssistantResponse {
|
||||||
id: uuid::Uuid::new_v4().to_string(),
|
id: message.message_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
content: message.content.clone(),
|
content: message.content.clone(),
|
||||||
role: message.role.clone(),
|
role: message.role.clone(),
|
||||||
attachments,
|
attachments,
|
||||||
@ -174,6 +174,18 @@ pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Ve
|
|||||||
subagent_type: message.metadata.get("task_subagent_type").cloned().unwrap_or_default(),
|
subagent_type: message.metadata.get("task_subagent_type").cloned().unwrap_or_default(),
|
||||||
topic_id: message.metadata.get("topic_id").cloned(),
|
topic_id: message.metadata.get("topic_id").cloned(),
|
||||||
}],
|
}],
|
||||||
|
OutboundEventKind::StreamDelta => vec![WsOutbound::StreamDelta {
|
||||||
|
id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
delta: message.content.clone(),
|
||||||
|
reasoning_delta: message.reasoning_content.clone(),
|
||||||
|
subagent_task_id: message.metadata.get("subagent_task_id").cloned(),
|
||||||
|
topic_id: message.metadata.get("topic_id").cloned(),
|
||||||
|
}],
|
||||||
|
OutboundEventKind::StreamEnd => vec![WsOutbound::StreamEnd {
|
||||||
|
id: message.tool_call_id.clone().unwrap_or_default(),
|
||||||
|
subagent_task_id: message.metadata.get("subagent_task_id").cloned(),
|
||||||
|
topic_id: message.metadata.get("topic_id").cloned(),
|
||||||
|
}],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ pub use crate::domain::messages::ToolCall;
|
|||||||
pub use crate::domain::tools::{Tool, ToolFunction};
|
pub use crate::domain::tools::{Tool, ToolFunction};
|
||||||
pub use traits::{
|
pub use traits::{
|
||||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, ProviderRuntimeConfig,
|
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, ProviderRuntimeConfig,
|
||||||
Usage,
|
StreamCallback, StreamDelta, Usage,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn create_provider(
|
pub fn create_provider(
|
||||||
|
|||||||
@ -6,7 +6,7 @@ use serde_json::{Value, json};
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use super::traits::Usage;
|
use super::traits::{StreamCallback, StreamDelta, Usage};
|
||||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||||
use crate::domain::messages::ContentBlock;
|
use crate::domain::messages::ContentBlock;
|
||||||
|
|
||||||
@ -352,10 +352,11 @@ impl OpenAIProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 内部流式聊天实现
|
/// 内部流式聊天实现,可选传入流式回调
|
||||||
async fn chat_streaming(
|
async fn chat_streaming_internal(
|
||||||
&self,
|
&self,
|
||||||
request: &ChatCompletionRequest,
|
request: &ChatCompletionRequest,
|
||||||
|
stream_callback: Option<&StreamCallback>,
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
|
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
|
||||||
|
|
||||||
@ -444,11 +445,23 @@ impl OpenAIProvider {
|
|||||||
// 提取内容增量
|
// 提取内容增量
|
||||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||||
accumulator.add_content(content);
|
accumulator.add_content(content);
|
||||||
|
if let Some(cb) = &stream_callback {
|
||||||
|
cb(StreamDelta {
|
||||||
|
content: content.to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取推理内容增量
|
// 提取推理内容增量
|
||||||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||||
accumulator.add_reasoning_content(reasoning);
|
accumulator.add_reasoning_content(reasoning);
|
||||||
|
if let Some(cb) = &stream_callback {
|
||||||
|
cb(StreamDelta {
|
||||||
|
content: String::new(),
|
||||||
|
reasoning_content: Some(reasoning.to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取工具调用增量
|
// 提取工具调用增量
|
||||||
@ -523,9 +536,21 @@ impl OpenAIProvider {
|
|||||||
if let Some(delta) = choice.get("delta") {
|
if let Some(delta) = choice.get("delta") {
|
||||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||||
accumulator.add_content(content);
|
accumulator.add_content(content);
|
||||||
|
if let Some(cb) = &stream_callback {
|
||||||
|
cb(StreamDelta {
|
||||||
|
content: content.to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||||
accumulator.add_reasoning_content(reasoning);
|
accumulator.add_reasoning_content(reasoning);
|
||||||
|
if let Some(cb) = &stream_callback {
|
||||||
|
cb(StreamDelta {
|
||||||
|
content: String::new(),
|
||||||
|
reasoning_content: Some(reasoning.to_string()),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
@ -737,8 +762,8 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
// 检查是否启用流式输出
|
// 检查是否启用流式输出
|
||||||
if self.is_streaming_enabled() {
|
if self.is_streaming_enabled() {
|
||||||
// 优先尝试流式输出
|
// 优先尝试流式输出(无回调)
|
||||||
match self.chat_streaming(&request).await {
|
match self.chat_streaming_internal(&request, None).await {
|
||||||
Ok(response) => return Ok(response),
|
Ok(response) => return Ok(response),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
@ -890,6 +915,28 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_streaming(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
callback: StreamCallback,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
if self.is_streaming_enabled() {
|
||||||
|
match self.chat_streaming_internal(&request, Some(&callback)).await {
|
||||||
|
Ok(response) => return Ok(response),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
provider = %self.name,
|
||||||
|
model = %self.model_id,
|
||||||
|
error = %e,
|
||||||
|
"Streaming (with callback) failed, falling back to non-streaming"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 回退到非流式
|
||||||
|
self.chat(request).await
|
||||||
|
}
|
||||||
|
|
||||||
fn ptype(&self) -> &str {
|
fn ptype(&self) -> &str {
|
||||||
"openai"
|
"openai"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ use crate::config::LLMProviderConfig;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ProviderRuntimeConfig {
|
pub struct ProviderRuntimeConfig {
|
||||||
@ -137,6 +138,18 @@ pub struct Usage {
|
|||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 流式响应中的增量事件
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StreamDelta {
|
||||||
|
/// 文本内容增量
|
||||||
|
pub content: String,
|
||||||
|
/// 推理内容增量
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式回调类型:每收到一个 delta 就调用一次
|
||||||
|
pub type StreamCallback = Arc<dyn Fn(StreamDelta) + Send + Sync>;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait LLMProvider: Send + Sync {
|
pub trait LLMProvider: Send + Sync {
|
||||||
async fn chat(
|
async fn chat(
|
||||||
@ -144,6 +157,17 @@ pub trait LLMProvider: Send + Sync {
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
|
/// 带流式回调的 chat:每收到一个 SSE delta 就调用 callback。
|
||||||
|
/// 返回值与 `chat()` 相同(完整的 ChatCompletionResponse)。
|
||||||
|
/// 默认实现忽略 callback,直接调用 chat()。
|
||||||
|
async fn chat_with_streaming(
|
||||||
|
&self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
_callback: StreamCallback,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
self.chat(request).await
|
||||||
|
}
|
||||||
|
|
||||||
fn ptype(&self) -> &str;
|
fn ptype(&self) -> &str;
|
||||||
|
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|||||||
@ -499,6 +499,7 @@ impl SubAgentRuntime for DefaultSubAgentRuntime {
|
|||||||
tool_name: None,
|
tool_name: None,
|
||||||
tool_arguments: None,
|
tool_arguments: None,
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
|
message_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = bus.publish_outbound(event).await {
|
if let Err(e) = bus.publish_outbound(event).await {
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import type {
|
|||||||
SchedulerJobSessionLookup,
|
SchedulerJobSessionLookup,
|
||||||
Channel,
|
Channel,
|
||||||
ChannelList,
|
ChannelList,
|
||||||
|
StreamDelta,
|
||||||
} from '../types/protocol'
|
} from '../types/protocol'
|
||||||
|
|
||||||
// 简化后的层级状态
|
// 简化后的层级状态
|
||||||
@ -446,14 +447,55 @@ export function useChat(): UseChatReturn {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case 'stream_delta': {
|
||||||
|
const msg = message as StreamDelta
|
||||||
|
if (msg.topic_id && msg.topic_id !== selectedTopicRef.current) return
|
||||||
|
setMessages((prev) => {
|
||||||
|
const existingIdx = prev.findIndex(m => m.id === msg.id && m.type === 'message')
|
||||||
|
if (existingIdx >= 0) {
|
||||||
|
// 追加到已有消息
|
||||||
|
const updated = [...prev]
|
||||||
|
const existing = updated[existingIdx]
|
||||||
|
updated[existingIdx] = {
|
||||||
|
...existing,
|
||||||
|
content: existing.content + msg.delta,
|
||||||
|
reasoningContent: msg.reasoning_delta
|
||||||
|
? (existing.reasoningContent || '') + msg.reasoning_delta
|
||||||
|
: existing.reasoningContent,
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
// 创建新消息
|
||||||
|
return [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: msg.id,
|
||||||
|
role: 'assistant' as const,
|
||||||
|
content: msg.delta,
|
||||||
|
timestamp: Math.floor(Date.now() / 1000),
|
||||||
|
type: 'message' as const,
|
||||||
|
reasoningContent: msg.reasoning_delta,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
setIsLoading(false)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'stream_end': {
|
||||||
|
// 流式结束,无需额外操作,后续 assistant_response 会替换完整内容
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
case 'assistant_response': {
|
case 'assistant_response': {
|
||||||
const msg = message as AssistantResponse
|
const msg = message as AssistantResponse
|
||||||
// 按 topic_id 隔离:如果消息属于其他话题则丢弃
|
// 按 topic_id 隔离:如果消息属于其他话题则丢弃
|
||||||
if (msg.topic_id && msg.topic_id !== selectedTopicRef.current) return
|
if (msg.topic_id && msg.topic_id !== selectedTopicRef.current) return
|
||||||
const role = msg.role === 'user' || msg.role === 'tool' ? msg.role : 'assistant'
|
const role = msg.role === 'user' || msg.role === 'tool' ? msg.role : 'assistant'
|
||||||
setMessages((prev) => [
|
setMessages((prev) => {
|
||||||
...prev,
|
// 如果流式消息已存在(相同 id),替换它
|
||||||
{
|
const existingIdx = prev.findIndex(m => m.id === msg.id && m.type === 'message')
|
||||||
|
const newMsg: ChatMessage = {
|
||||||
id: msg.id,
|
id: msg.id,
|
||||||
role,
|
role,
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
@ -461,8 +503,14 @@ export function useChat(): UseChatReturn {
|
|||||||
type: 'message',
|
type: 'message',
|
||||||
attachments: msg.attachments,
|
attachments: msg.attachments,
|
||||||
reasoningContent: msg.reasoning_content,
|
reasoningContent: msg.reasoning_content,
|
||||||
},
|
}
|
||||||
])
|
if (existingIdx >= 0) {
|
||||||
|
const updated = [...prev]
|
||||||
|
updated[existingIdx] = newMsg
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
return [...prev, newMsg]
|
||||||
|
})
|
||||||
setIsLoading(false)
|
setIsLoading(false)
|
||||||
|
|
||||||
// 当前话题无描述时,可能刚触发了异步生成,标记需要刷新
|
// 当前话题无描述时,可能刚触发了异步生成,标记需要刷新
|
||||||
|
|||||||
@ -256,6 +256,22 @@ export interface ExecutionCancelled {
|
|||||||
message: string
|
message: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface StreamDelta {
|
||||||
|
type: 'stream_delta'
|
||||||
|
id: string
|
||||||
|
delta: string
|
||||||
|
reasoning_delta?: string
|
||||||
|
subagent_task_id?: string
|
||||||
|
topic_id?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface StreamEnd {
|
||||||
|
type: 'stream_end'
|
||||||
|
id: string
|
||||||
|
subagent_task_id?: string
|
||||||
|
topic_id?: string
|
||||||
|
}
|
||||||
|
|
||||||
export type WsOutbound =
|
export type WsOutbound =
|
||||||
| AssistantResponse
|
| AssistantResponse
|
||||||
| ToolCall
|
| ToolCall
|
||||||
@ -263,6 +279,8 @@ export type WsOutbound =
|
|||||||
| ToolPending
|
| ToolPending
|
||||||
| WsError
|
| WsError
|
||||||
| TaskStarted
|
| TaskStarted
|
||||||
|
| StreamDelta
|
||||||
|
| StreamEnd
|
||||||
| SessionEstablished
|
| SessionEstablished
|
||||||
| SessionCreated
|
| SessionCreated
|
||||||
| SessionList
|
| SessionList
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user