Revert "feat: 添加流式聊天支持,更新相关依赖和接口以实现流式响应"

This reverts commit cb48ef09b22ef72798a15ff3485c8e21f77e59f4.
This commit is contained in:
oudecheng 2026-05-20 09:20:26 +08:00
parent cb48ef09b2
commit 7540828397
6 changed files with 7 additions and 359 deletions

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024"
[dependencies]
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] }
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] }
dotenv = "0.15"
serde = { version = "1.0", features = ["derive"] }
regex = "1.0"

View File

@ -6,7 +6,7 @@ use crate::domain::messages::{ContentBlock, ToolCall};
use crate::observability::{
Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args,
};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, StreamingResponse, create_provider};
use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider};
use crate::text::{char_count, take_prefix_chars, take_suffix_chars};
use crate::tools::{ToolContext, ToolRegistry};
use async_trait::async_trait;
@ -360,91 +360,6 @@ fn recoverable_llm_message(error: &str) -> String {
}
}
/// 使用流式 API 调用 LLM在超时/错误时返回已收到的内容
async fn call_llm_streaming(
provider: &dyn LLMProvider,
request: ChatCompletionRequest,
provider_name: &str,
model_id: &str,
) -> Result<crate::providers::ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let mut streaming_response = StreamingResponse::new(model_id.to_string(), String::new());
let tool_calls: Vec<ToolCall> = Vec::new();
match provider.chat_streaming(request.clone()).await {
Ok(mut stream) => {
use futures_util::StreamExt;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
if chunk.is_done {
break;
}
streaming_response.add_chunk(chunk);
}
Err(e) => {
// 流式传输错误,记录并返回已累积的内容
let error_msg = e.to_string();
tracing::error!(
provider = %provider_name,
model = %model_id,
error = %error_msg,
accumulated_len = %streaming_response.accumulated_content.len(),
"Stream error during LLM request, returning partial content"
);
// 如果有累积的内容,返回部分结果
if !streaming_response.accumulated_content.is_empty() {
let partial_content = streaming_response.get_partial_content();
return Ok(crate::providers::ChatCompletionResponse {
id: format!("partial-{}", uuid::Uuid::new_v4()),
model: model_id.to_string(),
content: partial_content,
reasoning_content: streaming_response.accumulated_reasoning_content.clone(),
tool_calls: Vec::new(), // 超时时不返回 tool_calls
usage: crate::providers::Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
});
}
// 没有累积内容,返回原始错误
return Err(e);
}
}
}
// 流完成,构建完整响应
Ok(crate::providers::ChatCompletionResponse {
id: streaming_response.id.clone().is_empty()
.then(|| format!("stream-{}", uuid::Uuid::new_v4()))
.unwrap_or_else(|| streaming_response.id.clone()),
model: streaming_response.model.clone(),
content: streaming_response.accumulated_content.clone(),
reasoning_content: streaming_response.accumulated_reasoning_content.clone(),
tool_calls, // 注意:流式输出通常不包含 tool_calls
usage: crate::providers::Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
})
}
Err(e) => {
// 流式请求初始化失败,可能是 provider 不支持流式
tracing::debug!(
provider = %provider_name,
model = %model_id,
error = %e,
"Streaming not supported or failed, falling back to non-streaming"
);
// 回退到非流式调用
provider.chat(request).await
}
}
}
/// Loop detection result.
#[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult {
@ -823,12 +738,7 @@ impl AgentLoop {
tools,
};
let response = match call_llm_streaming(
&*self.provider,
request,
self.provider.name(),
self.provider.model_id(),
).await {
let response = match (*self.provider).chat(request).await {
Ok(response) => response,
Err(e) => {
tracing::error!(
@ -1037,12 +947,7 @@ impl AgentLoop {
tools: None, // No tools in final summary call
};
match call_llm_streaming(
&*self.provider,
request,
self.provider.name(),
self.provider.model_id(),
).await {
match (*self.provider).chat(request).await {
Ok(response) => {
let assistant_message = if let Some(reasoning_content) = response.reasoning_content
{

View File

@ -1,12 +1,11 @@
use async_trait::async_trait;
use futures_util::Stream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, StreamChunk, Tool, ToolCall};
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use crate::domain::messages::ContentBlock;
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
@ -295,17 +294,6 @@ impl LLMProvider for AnthropicProvider {
})
}
/// Streaming entry point for the Anthropic provider.
///
/// Streaming is not implemented for this provider yet: the request is ignored
/// and an error is always returned.
async fn chat_streaming(
    &self,
    _request: ChatCompletionRequest,
) -> Result<
    std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk, Box<dyn std::error::Error + Send + Sync>>> + Send>>,
    Box<dyn std::error::Error + Send + Sync>,
> {
    let unsupported: Box<dyn std::error::Error + Send + Sync> =
        "Anthropic streaming not yet implemented".into();
    Err(unsupported)
}
fn ptype(&self) -> &str {
"anthropic"
}

View File

@ -9,7 +9,7 @@ pub use crate::domain::messages::ToolCall;
pub use crate::domain::tools::{Tool, ToolFunction};
pub use traits::{
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, ProviderRuntimeConfig,
StreamChunk, StreamingResponse, Usage,
Usage,
};
pub fn create_provider(

View File

@ -1,5 +1,4 @@
use async_trait::async_trait;
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
@ -7,7 +6,7 @@ use std::collections::HashMap;
use std::time::Duration;
use super::traits::Usage;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, StreamChunk, ToolCall};
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use crate::domain::messages::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
@ -417,179 +416,6 @@ impl LLMProvider for OpenAIProvider {
})
}
async fn chat_streaming(
&self,
request: ChatCompletionRequest,
) -> Result<
std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk, Box<dyn std::error::Error + Send + Sync>>> + Send>>,
Box<dyn std::error::Error + Send + Sync>,
> {
let url = format!("{}/chat/completions", self.base_url);
let mut body = self.build_request_body(&request);
// 启用流式输出
body["stream"] = Value::Bool(true);
let mut req_builder = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream");
for (key, value) in &self.extra_headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let resp = req_builder.json(&body).send().await.map_err(|err| {
let error_context = format_transport_error_context(
&self.name,
&self.model_id,
&url,
self.llm_timeout_secs,
&err,
);
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
base_url = %self.base_url,
timeout_secs = self.llm_timeout_secs,
error = %error_context,
"OpenAI-compatible API streaming transport request failed"
);
error_context
})?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
tracing::error!(
provider = %self.name,
model = %self.model_id,
url = %url,
status = %status,
response_len = text.len(),
response_body = %text,
"OpenAI-compatible API streaming request failed"
);
return Err(format!("API error {}: {}", status, text).into());
}
let model_id = self.model_id.clone();
let provider_name = self.name.clone();
// 检查是否是 SSE 响应
let is_sse = resp
.headers()
.get("content-type")
.and_then(|ct| ct.to_str().ok())
.map(|ct| ct.contains("text/event-stream"))
.unwrap_or(false);
// 创建 SSE 流
let stream = resp.bytes_stream().flat_map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
let text = String::from_utf8_lossy(&chunk);
let mut chunks = Vec::new();
// 如果不是 SSE尝试解析为普通 JSON 响应
if !is_sse && !text.contains("data: ") {
// 尝试解析为非流式 JSON 响应
if let Ok(json) = serde_json::from_str::<OpenAIResponse>(&text) {
if let Some(choice) = json.choices.first() {
if let Some(content) = choice.message.content.as_ref() {
if !content.is_empty() {
chunks.push(Ok(StreamChunk {
content_delta: content.clone(),
reasoning_content_delta: choice.message.reasoning_content.clone(),
is_done: false,
}));
}
}
}
// 标记完成
chunks.push(Ok(StreamChunk {
content_delta: String::new(),
reasoning_content_delta: None,
is_done: true,
}));
return futures_util::stream::iter(chunks);
}
}
for line in text.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with(":") {
continue;
}
// SSE 格式: data: {...}
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
chunks.push(Ok(StreamChunk {
content_delta: String::new(),
reasoning_content_delta: None,
is_done: true,
}));
continue;
}
// 解析 JSON
match serde_json::from_str::<Value>(data) {
Ok(json) => {
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
if let Some(choice) = choices.first() {
// 获取 content delta
let content_delta = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
// 获取 reasoning_content delta某些模型支持
let reasoning_delta = choice
.get("delta")
.and_then(|d| d.get("reasoning_content"))
.and_then(|r| r.as_str())
.map(|s| s.to_string());
if !content_delta.is_empty() || reasoning_delta.is_some() {
chunks.push(Ok(StreamChunk {
content_delta,
reasoning_content_delta: reasoning_delta,
is_done: false,
}));
}
}
}
}
Err(e) => {
tracing::debug!(
provider = %provider_name,
model = %model_id,
error = %e,
data = %data,
"Failed to parse SSE data"
);
}
}
}
}
futures_util::stream::iter(chunks)
}
Err(e) => {
futures_util::stream::iter(vec![Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)])
}
}
});
Ok(Box::pin(stream))
}
fn ptype(&self) -> &str {
"openai"
}

View File

@ -1,70 +1,9 @@
use crate::domain::messages::{ContentBlock, ToolCall};
use crate::domain::tools::Tool;
use async_trait::async_trait;
use futures_util::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One chunk of a streamed completion.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamChunk {
    /// Content delta carried by this chunk.
    pub content_delta: String,
    /// Reasoning-content delta carried by this chunk, if any.
    pub reasoning_content_delta: Option<String>,
    /// Whether this chunk is the completion signal.
    pub is_done: bool,
}

/// Collector that accumulates streamed chunks into a full response.
#[derive(Debug, Clone)]
pub struct StreamingResponse {
    /// Content accumulated so far.
    pub accumulated_content: String,
    /// Reasoning content accumulated so far; `None` until a chunk carries some.
    pub accumulated_reasoning_content: Option<String>,
    /// Model id.
    pub model: String,
    /// Response id.
    pub id: String,
}

impl StreamingResponse {
    /// Creates an empty collector for the given model and response id.
    pub fn new(model: String, id: String) -> Self {
        Self {
            accumulated_content: String::new(),
            accumulated_reasoning_content: None,
            model,
            id,
        }
    }

    /// Appends the chunk's content delta and, when present, its reasoning
    /// delta to the accumulated text. The chunk's `is_done` flag is not
    /// inspected here.
    pub fn add_chunk(&mut self, chunk: StreamChunk) {
        self.accumulated_content.push_str(&chunk.content_delta);
        if let Some(reasoning_delta) = chunk.reasoning_content_delta {
            // The buffer is created lazily so that "no reasoning at all"
            // remains distinguishable (`None`) from "empty reasoning".
            self.accumulated_reasoning_content
                .get_or_insert_with(String::new)
                .push_str(&reasoning_delta);
        }
    }

    /// Returns the content accumulated so far with a warning appended that it
    /// may be incomplete, or a placeholder message when nothing was
    /// accumulated. Intended for preserving output when a stream fails.
    pub fn get_partial_content(&self) -> String {
        if self.accumulated_content.is_empty() {
            "[模型响应超时,无部分内容可用]".to_string()
        } else {
            format!(
                "{}\n\n[警告:模型响应超时,以上内容可能不完整]",
                self.accumulated_content
            )
        }
    }
}
#[derive(Debug, Clone)]
pub struct ProviderRuntimeConfig {
pub provider_type: String,
@ -187,16 +126,6 @@ pub trait LLMProvider: Send + Sync {
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
/// Streaming chat completion: on success returns a pinned, boxed stream whose
/// items are `StreamChunk`s or per-chunk errors.
///
/// When the connection drops or times out, the caller can still use the
/// partial content it has accumulated so far.
///
/// # Errors
/// Returns an error when the streaming request cannot be started.
async fn chat_streaming(
&self,
request: ChatCompletionRequest,
) -> Result<
std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk, Box<dyn std::error::Error + Send + Sync>>> + Send>>,
Box<dyn std::error::Error + Send + Sync>,
>;
fn ptype(&self) -> &str;
fn name(&self) -> &str;