feat(streaming): 支持流式文本增量与结束信号功能
- 新增 StreamDelta 和 StreamEnd 类型,支持流式数据增量传输 - 扩展 LLMProvider trait,添加带回调的 chat_with_streaming 接口 - 修改 OpenAI Provider 实现,支持流式聊天回调传输增量数据 - Agent 流处理改为异步消费增量消息并传递给前端 - 保证流式增量和最终消息使用相同消息 ID 以便前端替换 - 修改消息总线和协议层,支持携带和识别流式消息的消息 ID - 客户端 CLI 通过增量输出实现交互式流式响应显示 - Web 前端接收流式增量消息,追加到对应消息,实现实时显示 - 各通道(飞书、微信)支持转发流式增量和结束消息 - 任务工具运行时添加消息 ID 支持,保持消息一致性 - 统一消息构造函数新增流式增量和结束信号的构建方法
This commit is contained in:
parent
def6df50da
commit
fc7df67474
@ -7,7 +7,7 @@ use crate::domain::messages::{ContentBlock, ToolCall};
|
||||
use crate::observability::{
|
||||
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
|
||||
};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
|
||||
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, StreamDelta, StreamCallback, create_provider};
|
||||
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
use async_trait::async_trait;
|
||||
@ -663,6 +663,17 @@ pub trait EmittedMessageHandler: Send + Sync + 'static {
|
||||
async fn handle_tool_result(&self, message: ChatMessage, _duration_ms: Option<u64>) {
|
||||
self.handle(message).await;
|
||||
}
|
||||
|
||||
/// Handle a streaming delta. Default is no-op.
|
||||
async fn handle_stream_delta(&self, _delta: &StreamDelta) {
|
||||
// Non-streaming channels ignore this
|
||||
}
|
||||
|
||||
/// Set the message ID to use for stream deltas (so the final assistant message
|
||||
/// can share the same ID, enabling front-end replacement).
|
||||
async fn set_stream_message_id(&self, _id: &str) {
|
||||
// Default: no-op for handlers that don't stream
|
||||
}
|
||||
}
|
||||
|
||||
/// 装饰器:在内部 emitter 广播前,先将消息持久化到 DB
|
||||
@ -706,6 +717,15 @@ impl<H: EmittedMessageHandler> EmittedMessageHandler for PersistingEmittedMessag
|
||||
}
|
||||
self.inner.handle_tool_result(message, duration_ms).await;
|
||||
}
|
||||
|
||||
async fn set_stream_message_id(&self, id: &str) {
|
||||
self.inner.set_stream_message_id(id).await;
|
||||
}
|
||||
|
||||
async fn handle_stream_delta(&self, delta: &StreamDelta) {
|
||||
// Deltas are transient — do NOT persist, just forward to inner handler
|
||||
self.inner.handle_stream_delta(delta).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SkillProvider: Send + Sync + 'static {
|
||||
@ -958,9 +978,34 @@ impl AgentLoop {
|
||||
tools,
|
||||
};
|
||||
|
||||
let response = match (*self.provider).chat(request).await {
|
||||
// Set up streaming delta consumer
|
||||
// Pre-generate the message ID so stream deltas and the final assistant
|
||||
// message share the same ID — this lets the front-end replace the
|
||||
// streamed message with the authoritative response.
|
||||
let streaming_message_id = uuid::Uuid::new_v4().to_string();
|
||||
if let Some(handler) = &self.emitted_message_handler {
|
||||
handler.set_stream_message_id(&streaming_message_id).await;
|
||||
}
|
||||
|
||||
let (delta_tx, mut delta_rx) = tokio::sync::mpsc::channel::<StreamDelta>(256);
|
||||
let consumer_handler = self.emitted_message_handler.clone();
|
||||
let consumer_task = tokio::spawn(async move {
|
||||
while let Some(delta) = delta_rx.recv().await {
|
||||
if let Some(ref handler) = consumer_handler {
|
||||
handler.handle_stream_delta(&delta).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
let stream_callback: StreamCallback = std::sync::Arc::new(move |delta: StreamDelta| {
|
||||
// try_send is non-blocking and safe to call from within a tokio runtime
|
||||
let _ = delta_tx.try_send(delta);
|
||||
});
|
||||
|
||||
let response = match (*self.provider).chat_with_streaming(request, stream_callback).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
// delta_tx is dropped with the callback; await consumer to finish
|
||||
let _ = consumer_task.await;
|
||||
tracing::error!(
|
||||
provider = %self.provider.name(),
|
||||
model = %self.provider.model_id(),
|
||||
@ -979,6 +1024,22 @@ impl AgentLoop {
|
||||
}
|
||||
};
|
||||
|
||||
// Close delta channel and wait for consumer to finish processing
|
||||
// (delta_tx is dropped when the callback closure is dropped)
|
||||
let _ = consumer_task.await;
|
||||
|
||||
// Signal stream end if handler exists
|
||||
let had_streaming = self.emitted_message_handler.is_some();
|
||||
if had_streaming {
|
||||
let end_delta = StreamDelta {
|
||||
content: String::new(),
|
||||
reasoning_content: None,
|
||||
};
|
||||
if let Some(handler) = &self.emitted_message_handler {
|
||||
handler.handle_stream_delta(&end_delta).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
tracing::debug!(
|
||||
iteration,
|
||||
@ -989,12 +1050,17 @@ impl AgentLoop {
|
||||
|
||||
// If no tool calls, this is the final response
|
||||
if response.tool_calls.is_empty() {
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||
let mut assistant_message = if let Some(reasoning_content) = response.reasoning_content
|
||||
{
|
||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||
} else {
|
||||
ChatMessage::assistant(response.content)
|
||||
};
|
||||
// Use the same ID as the stream deltas so the front-end can replace
|
||||
// the streamed message with this authoritative response.
|
||||
if had_streaming {
|
||||
assistant_message.id = streaming_message_id;
|
||||
}
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
self.emit_live_tool_call_message(assistant_message.clone()).await;
|
||||
return Ok(AgentProcessResult {
|
||||
@ -1011,7 +1077,7 @@ impl AgentLoop {
|
||||
);
|
||||
|
||||
// Add assistant message with tool calls
|
||||
let assistant_message =
|
||||
let mut assistant_message =
|
||||
if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||||
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
||||
response.content.clone(),
|
||||
@ -1024,6 +1090,11 @@ impl AgentLoop {
|
||||
response.tool_calls.clone(),
|
||||
)
|
||||
};
|
||||
// Use the same ID as stream deltas so the front-end can replace
|
||||
// the streamed message with this authoritative response.
|
||||
if had_streaming {
|
||||
assistant_message.id = streaming_message_id;
|
||||
}
|
||||
messages.push(assistant_message.clone());
|
||||
emitted_messages.push(assistant_message);
|
||||
self.emit_live_tool_call_message(
|
||||
|
||||
@ -401,6 +401,10 @@ pub struct OutboundMessage {
|
||||
pub tool_name: Option<String>,
|
||||
pub tool_arguments: Option<serde_json::Value>,
|
||||
pub reasoning_content: Option<String>,
|
||||
/// Carry the originating ChatMessage.id so the WS layer can use it
|
||||
/// instead of generating a random UUID. Critical for stream delta → assistant_response
|
||||
/// ID matching on the front-end.
|
||||
pub message_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@ -412,11 +416,15 @@ pub enum OutboundEventKind {
|
||||
SchedulerNotification,
|
||||
ErrorNotification,
|
||||
TaskStarted,
|
||||
/// 流式文本增量
|
||||
StreamDelta,
|
||||
/// 流式结束信号
|
||||
StreamEnd,
|
||||
}
|
||||
|
||||
impl OutboundMessage {
|
||||
pub fn is_stream_delta(&self) -> bool {
|
||||
self.metadata.get("_stream_delta").is_some()
|
||||
matches!(self.event_kind, OutboundEventKind::StreamDelta | OutboundEventKind::StreamEnd)
|
||||
}
|
||||
|
||||
pub fn assistant(
|
||||
@ -441,6 +449,7 @@ impl OutboundMessage {
|
||||
tool_name: None,
|
||||
tool_arguments: None,
|
||||
reasoning_content: None,
|
||||
message_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -496,6 +505,7 @@ impl OutboundMessage {
|
||||
tool_name: Some(tool_name),
|
||||
tool_arguments: Some(tool_arguments),
|
||||
reasoning_content: None,
|
||||
message_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -526,6 +536,7 @@ impl OutboundMessage {
|
||||
tool_name: Some(tool_name),
|
||||
tool_arguments: None,
|
||||
reasoning_content: None,
|
||||
message_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -556,6 +567,61 @@ impl OutboundMessage {
|
||||
tool_name: Some(tool_name),
|
||||
tool_arguments: None,
|
||||
reasoning_content: None,
|
||||
message_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 构造流式文本增量消息
|
||||
pub fn stream_delta(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
session_id: Option<String>,
|
||||
message_id: impl Into<String>,
|
||||
delta: impl Into<String>,
|
||||
reasoning_delta: Option<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
session_id,
|
||||
content: delta.into(),
|
||||
reply_to: None,
|
||||
media: Vec::new(),
|
||||
metadata,
|
||||
event_kind: OutboundEventKind::StreamDelta,
|
||||
role: "assistant".to_string(),
|
||||
tool_call_id: Some(message_id.into()),
|
||||
tool_name: None,
|
||||
tool_arguments: None,
|
||||
reasoning_content: reasoning_delta,
|
||||
message_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 构造流式结束信号
|
||||
pub fn stream_end(
|
||||
channel: impl Into<String>,
|
||||
chat_id: impl Into<String>,
|
||||
session_id: Option<String>,
|
||||
message_id: impl Into<String>,
|
||||
metadata: HashMap<String, String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel: channel.into(),
|
||||
chat_id: chat_id.into(),
|
||||
session_id,
|
||||
content: String::new(),
|
||||
reply_to: None,
|
||||
media: Vec::new(),
|
||||
metadata,
|
||||
event_kind: OutboundEventKind::StreamEnd,
|
||||
role: "assistant".to_string(),
|
||||
tool_call_id: Some(message_id.into()),
|
||||
tool_name: None,
|
||||
tool_arguments: None,
|
||||
reasoning_content: None,
|
||||
message_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -582,6 +648,7 @@ impl OutboundMessage {
|
||||
metadata.clone(),
|
||||
);
|
||||
resp.reasoning_content = message.reasoning_content.clone();
|
||||
resp.message_id = Some(message.id.clone());
|
||||
outbound.push(resp);
|
||||
}
|
||||
|
||||
@ -613,6 +680,7 @@ impl OutboundMessage {
|
||||
metadata.clone(),
|
||||
);
|
||||
resp.reasoning_content = message.reasoning_content.clone();
|
||||
resp.message_id = Some(message.id.clone());
|
||||
vec![resp]
|
||||
}
|
||||
}
|
||||
|
||||
@ -2406,7 +2406,7 @@ impl Channel for FeishuChannel {
|
||||
}
|
||||
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending)
|
||||
if matches!(msg.event_kind, OutboundEventKind::ToolResult | OutboundEventKind::ToolPending | OutboundEventKind::StreamDelta | OutboundEventKind::StreamEnd)
|
||||
|| msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||
{
|
||||
return Ok(());
|
||||
|
||||
@ -313,6 +313,8 @@ impl Channel for WechatChannel {
|
||||
OutboundEventKind::ToolResult
|
||||
| OutboundEventKind::ToolPending
|
||||
| OutboundEventKind::ToolCall
|
||||
| OutboundEventKind::StreamDelta
|
||||
| OutboundEventKind::StreamEnd
|
||||
) || msg.metadata.get("is_subagent_event").map(|v| v == "true").unwrap_or(false)
|
||||
{
|
||||
return Ok(());
|
||||
|
||||
@ -24,6 +24,9 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let mut input = InputHandler::new();
|
||||
let mut current_session_id: Option<String> = None;
|
||||
// Track message IDs that were already streamed so we can skip
|
||||
// the duplicate AssistantResponse that arrives afterwards.
|
||||
let mut streamed_message_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||
input.write_output("picobot CLI - Commands: /new [title], /save [filepath], /quit\n").await?;
|
||||
|
||||
// Main loop: poll both stdin and WebSocket
|
||||
@ -36,8 +39,11 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = text.to_string();
|
||||
if let Ok(outbound) = parse_message(&text) {
|
||||
match outbound {
|
||||
WsOutbound::AssistantResponse { content, .. } => {
|
||||
input.write_response(&content).await?;
|
||||
WsOutbound::AssistantResponse { id, content, .. } => {
|
||||
// Skip if already fully streamed via StreamDelta
|
||||
if !streamed_message_ids.remove(&id) {
|
||||
input.write_response(&content).await?;
|
||||
}
|
||||
}
|
||||
WsOutbound::ToolCall { tool_name, arguments, .. } => {
|
||||
input.write_output(&format!("Tool call: {}\n{}\n", tool_name, format_json(&arguments))).await?;
|
||||
@ -64,6 +70,18 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
WsOutbound::SessionSaved { session_id, filepath } => {
|
||||
input.write_output(&format!("Saved session {} to: {}\n", session_id, filepath)).await?;
|
||||
}
|
||||
WsOutbound::StreamDelta { id, delta, .. } => {
|
||||
// Track that this message is being streamed
|
||||
streamed_message_ids.insert(id.clone());
|
||||
// 在终端直接输出流式增量文本
|
||||
if !delta.is_empty() {
|
||||
input.write_output(&delta).await?;
|
||||
}
|
||||
}
|
||||
WsOutbound::StreamEnd { .. } => {
|
||||
// 流式结束,输出换行
|
||||
input.write_output("\n").await?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ use crate::agent::{AgentError, AgentLoop, ContextCompressor, EmittedMessageHandl
|
||||
#[cfg(test)]
|
||||
use crate::bus::SYSTEM_CONTEXT_SCHEDULED_PROMPT;
|
||||
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
|
||||
use crate::providers::StreamDelta;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::scheduler::ScheduledAgentTaskOptions;
|
||||
@ -56,6 +57,7 @@ pub struct BusToolCallEmitter {
|
||||
chat_id: String,
|
||||
metadata: HashMap<String, String>,
|
||||
store: Arc<SessionStore>,
|
||||
stream_message_id: std::sync::Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
impl BusToolCallEmitter {
|
||||
@ -72,6 +74,7 @@ impl BusToolCallEmitter {
|
||||
chat_id: chat_id.into(),
|
||||
metadata,
|
||||
store,
|
||||
stream_message_id: std::sync::Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -116,6 +119,43 @@ impl EmittedMessageHandler for BusToolCallEmitter {
|
||||
self.persist_todo_write_result(&message);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stream_delta(&self, delta: &StreamDelta) {
|
||||
// Get or create the stream message ID
|
||||
let message_id = {
|
||||
let mut guard = self.stream_message_id.lock().unwrap();
|
||||
guard.get_or_insert_with(|| Uuid::new_v4().to_string()).clone()
|
||||
};
|
||||
|
||||
// Empty content + no reasoning = stream end signal
|
||||
let outbound = if delta.content.is_empty() && delta.reasoning_content.is_none() {
|
||||
OutboundMessage::stream_end(
|
||||
&self.channel_name,
|
||||
&self.chat_id,
|
||||
None,
|
||||
&message_id,
|
||||
self.metadata.clone(),
|
||||
)
|
||||
} else {
|
||||
OutboundMessage::stream_delta(
|
||||
&self.channel_name,
|
||||
&self.chat_id,
|
||||
None,
|
||||
&message_id,
|
||||
&delta.content,
|
||||
delta.reasoning_content.clone(),
|
||||
self.metadata.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
if let Err(error) = self.bus.publish_outbound(outbound).await {
|
||||
tracing::error!(error = %error, channel = %self.channel_name, "Failed to publish stream delta");
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_stream_message_id(&self, id: &str) {
|
||||
*self.stream_message_id.lock().unwrap() = Some(id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
impl BusToolCallEmitter {
|
||||
|
||||
@ -265,6 +265,25 @@ pub enum WsOutbound {
|
||||
},
|
||||
#[serde(rename = "execution_cancelled")]
|
||||
ExecutionCancelled { message: String },
|
||||
#[serde(rename = "stream_delta")]
|
||||
StreamDelta {
|
||||
id: String,
|
||||
delta: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
reasoning_delta: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
subagent_task_id: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
topic_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "stream_end")]
|
||||
StreamEnd {
|
||||
id: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
subagent_task_id: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
topic_id: Option<String>,
|
||||
},
|
||||
#[serde(rename = "todo_list")]
|
||||
TodoList {
|
||||
todos: Vec<TodoItemSummary>,
|
||||
|
||||
@ -104,7 +104,7 @@ pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Ve
|
||||
})
|
||||
.collect();
|
||||
vec![WsOutbound::AssistantResponse {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
id: message.message_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
content: message.content.clone(),
|
||||
role: message.role.clone(),
|
||||
attachments,
|
||||
@ -174,6 +174,18 @@ pub(crate) fn ws_outbound_from_outbound_message(message: &OutboundMessage) -> Ve
|
||||
subagent_type: message.metadata.get("task_subagent_type").cloned().unwrap_or_default(),
|
||||
topic_id: message.metadata.get("topic_id").cloned(),
|
||||
}],
|
||||
OutboundEventKind::StreamDelta => vec![WsOutbound::StreamDelta {
|
||||
id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
delta: message.content.clone(),
|
||||
reasoning_delta: message.reasoning_content.clone(),
|
||||
subagent_task_id: message.metadata.get("subagent_task_id").cloned(),
|
||||
topic_id: message.metadata.get("topic_id").cloned(),
|
||||
}],
|
||||
OutboundEventKind::StreamEnd => vec![WsOutbound::StreamEnd {
|
||||
id: message.tool_call_id.clone().unwrap_or_default(),
|
||||
subagent_task_id: message.metadata.get("subagent_task_id").cloned(),
|
||||
topic_id: message.metadata.get("topic_id").cloned(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ pub use crate::domain::messages::ToolCall;
|
||||
pub use crate::domain::tools::{Tool, ToolFunction};
|
||||
pub use traits::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, ProviderRuntimeConfig,
|
||||
Usage,
|
||||
StreamCallback, StreamDelta, Usage,
|
||||
};
|
||||
|
||||
pub fn create_provider(
|
||||
|
||||
@ -6,7 +6,7 @@ use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::traits::Usage;
|
||||
use super::traits::{StreamCallback, StreamDelta, Usage};
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use crate::domain::messages::ContentBlock;
|
||||
|
||||
@ -352,10 +352,11 @@ impl OpenAIProvider {
|
||||
})
|
||||
}
|
||||
|
||||
/// 内部流式聊天实现
|
||||
async fn chat_streaming(
|
||||
/// 内部流式聊天实现,可选传入流式回调
|
||||
async fn chat_streaming_internal(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
stream_callback: Option<&StreamCallback>,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
|
||||
|
||||
@ -444,11 +445,23 @@ impl OpenAIProvider {
|
||||
// 提取内容增量
|
||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
accumulator.add_content(content);
|
||||
if let Some(cb) = &stream_callback {
|
||||
cb(StreamDelta {
|
||||
content: content.to_string(),
|
||||
reasoning_content: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 提取推理内容增量
|
||||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||
accumulator.add_reasoning_content(reasoning);
|
||||
if let Some(cb) = &stream_callback {
|
||||
cb(StreamDelta {
|
||||
content: String::new(),
|
||||
reasoning_content: Some(reasoning.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 提取工具调用增量
|
||||
@ -523,9 +536,21 @@ impl OpenAIProvider {
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||
accumulator.add_content(content);
|
||||
if let Some(cb) = &stream_callback {
|
||||
cb(StreamDelta {
|
||||
content: content.to_string(),
|
||||
reasoning_content: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||
accumulator.add_reasoning_content(reasoning);
|
||||
if let Some(cb) = &stream_callback {
|
||||
cb(StreamDelta {
|
||||
content: String::new(),
|
||||
reasoning_content: Some(reasoning.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||
for tool_call in tool_calls {
|
||||
@ -737,8 +762,8 @@ impl LLMProvider for OpenAIProvider {
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// 检查是否启用流式输出
|
||||
if self.is_streaming_enabled() {
|
||||
// 优先尝试流式输出
|
||||
match self.chat_streaming(&request).await {
|
||||
// 优先尝试流式输出(无回调)
|
||||
match self.chat_streaming_internal(&request, None).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
@ -890,6 +915,28 @@ impl LLMProvider for OpenAIProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn chat_with_streaming(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
callback: StreamCallback,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if self.is_streaming_enabled() {
|
||||
match self.chat_streaming_internal(&request, Some(&callback)).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
provider = %self.name,
|
||||
model = %self.model_id,
|
||||
error = %e,
|
||||
"Streaming (with callback) failed, falling back to non-streaming"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// 回退到非流式
|
||||
self.chat(request).await
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::config::LLMProviderConfig;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderRuntimeConfig {
|
||||
@ -137,6 +138,18 @@ pub struct Usage {
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
/// 流式响应中的增量事件
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamDelta {
|
||||
/// 文本内容增量
|
||||
pub content: String,
|
||||
/// 推理内容增量
|
||||
pub reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
/// 流式回调类型:每收到一个 delta 就调用一次
|
||||
pub type StreamCallback = Arc<dyn Fn(StreamDelta) + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait LLMProvider: Send + Sync {
|
||||
async fn chat(
|
||||
@ -144,6 +157,17 @@ pub trait LLMProvider: Send + Sync {
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// 带流式回调的 chat:每收到一个 SSE delta 就调用 callback。
|
||||
/// 返回值与 `chat()` 相同(完整的 ChatCompletionResponse)。
|
||||
/// 默认实现忽略 callback,直接调用 chat()。
|
||||
async fn chat_with_streaming(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
_callback: StreamCallback,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.chat(request).await
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str;
|
||||
|
||||
fn name(&self) -> &str;
|
||||
|
||||
@ -499,6 +499,7 @@ impl SubAgentRuntime for DefaultSubAgentRuntime {
|
||||
tool_name: None,
|
||||
tool_arguments: None,
|
||||
reasoning_content: None,
|
||||
message_id: None,
|
||||
};
|
||||
|
||||
if let Err(e) = bus.publish_outbound(event).await {
|
||||
|
||||
@ -27,6 +27,7 @@ import type {
|
||||
SchedulerJobSessionLookup,
|
||||
Channel,
|
||||
ChannelList,
|
||||
StreamDelta,
|
||||
} from '../types/protocol'
|
||||
|
||||
// 简化后的层级状态
|
||||
@ -446,14 +447,55 @@ export function useChat(): UseChatReturn {
|
||||
break
|
||||
}
|
||||
|
||||
case 'stream_delta': {
|
||||
const msg = message as StreamDelta
|
||||
if (msg.topic_id && msg.topic_id !== selectedTopicRef.current) return
|
||||
setMessages((prev) => {
|
||||
const existingIdx = prev.findIndex(m => m.id === msg.id && m.type === 'message')
|
||||
if (existingIdx >= 0) {
|
||||
// 追加到已有消息
|
||||
const updated = [...prev]
|
||||
const existing = updated[existingIdx]
|
||||
updated[existingIdx] = {
|
||||
...existing,
|
||||
content: existing.content + msg.delta,
|
||||
reasoningContent: msg.reasoning_delta
|
||||
? (existing.reasoningContent || '') + msg.reasoning_delta
|
||||
: existing.reasoningContent,
|
||||
}
|
||||
return updated
|
||||
}
|
||||
// 创建新消息
|
||||
return [
|
||||
...prev,
|
||||
{
|
||||
id: msg.id,
|
||||
role: 'assistant' as const,
|
||||
content: msg.delta,
|
||||
timestamp: Math.floor(Date.now() / 1000),
|
||||
type: 'message' as const,
|
||||
reasoningContent: msg.reasoning_delta,
|
||||
},
|
||||
]
|
||||
})
|
||||
setIsLoading(false)
|
||||
break
|
||||
}
|
||||
|
||||
case 'stream_end': {
|
||||
// 流式结束,无需额外操作,后续 assistant_response 会替换完整内容
|
||||
break
|
||||
}
|
||||
|
||||
case 'assistant_response': {
|
||||
const msg = message as AssistantResponse
|
||||
// 按 topic_id 隔离:如果消息属于其他话题则丢弃
|
||||
if (msg.topic_id && msg.topic_id !== selectedTopicRef.current) return
|
||||
const role = msg.role === 'user' || msg.role === 'tool' ? msg.role : 'assistant'
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
setMessages((prev) => {
|
||||
// 如果流式消息已存在(相同 id),替换它
|
||||
const existingIdx = prev.findIndex(m => m.id === msg.id && m.type === 'message')
|
||||
const newMsg: ChatMessage = {
|
||||
id: msg.id,
|
||||
role,
|
||||
content: msg.content,
|
||||
@ -461,8 +503,14 @@ export function useChat(): UseChatReturn {
|
||||
type: 'message',
|
||||
attachments: msg.attachments,
|
||||
reasoningContent: msg.reasoning_content,
|
||||
},
|
||||
])
|
||||
}
|
||||
if (existingIdx >= 0) {
|
||||
const updated = [...prev]
|
||||
updated[existingIdx] = newMsg
|
||||
return updated
|
||||
}
|
||||
return [...prev, newMsg]
|
||||
})
|
||||
setIsLoading(false)
|
||||
|
||||
// 当前话题无描述时,可能刚触发了异步生成,标记需要刷新
|
||||
|
||||
@ -256,6 +256,22 @@ export interface ExecutionCancelled {
|
||||
message: string
|
||||
}
|
||||
|
||||
export interface StreamDelta {
|
||||
type: 'stream_delta'
|
||||
id: string
|
||||
delta: string
|
||||
reasoning_delta?: string
|
||||
subagent_task_id?: string
|
||||
topic_id?: string
|
||||
}
|
||||
|
||||
export interface StreamEnd {
|
||||
type: 'stream_end'
|
||||
id: string
|
||||
subagent_task_id?: string
|
||||
topic_id?: string
|
||||
}
|
||||
|
||||
export type WsOutbound =
|
||||
| AssistantResponse
|
||||
| ToolCall
|
||||
@ -263,6 +279,8 @@ export type WsOutbound =
|
||||
| ToolPending
|
||||
| WsError
|
||||
| TaskStarted
|
||||
| StreamDelta
|
||||
| StreamEnd
|
||||
| SessionEstablished
|
||||
| SessionCreated
|
||||
| SessionList
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user