feat: 更新 reqwest 依赖以支持流式响应,添加流式处理相关结构和实现

This commit is contained in:
oudecheng 2026-05-20 11:32:46 +08:00
parent 7540828397
commit 018c104592
2 changed files with 315 additions and 1 deletions

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] } reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] }
dotenv = "0.15" dotenv = "0.15"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
regex = "1.0" regex = "1.0"

View File

@ -1,4 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{Value, json}; use serde_json::{Value, json};
@ -11,6 +12,98 @@ use crate::domain::messages::ContentBlock;
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
/// 流式响应中的工具调用增量
#[derive(Debug, Default)]
struct StreamingToolCall {
id: String,
name: String,
arguments: String,
}
/// 流式响应累积器
#[derive(Debug, Default)]
struct StreamingAccumulator {
content: String,
reasoning_content: Option<String>,
tool_calls: HashMap<usize, StreamingToolCall>,
response_id: String,
}
impl StreamingAccumulator {
fn new() -> Self {
Self::default()
}
/// 添加内容增量
fn add_content(&mut self, delta: &str) {
self.content.push_str(delta);
}
/// 添加推理内容增量
fn add_reasoning_content(&mut self, delta: &str) {
if self.reasoning_content.is_none() {
self.reasoning_content = Some(String::new());
}
self.reasoning_content.as_mut().unwrap().push_str(delta);
}
/// 添加工具调用增量
fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) {
let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default);
if let Some(id) = id {
entry.id = id.to_string();
}
if let Some(name) = name {
entry.name = name.to_string();
}
if let Some(args) = arguments {
entry.arguments.push_str(args);
}
}
/// 设置响应 ID
fn set_response_id(&mut self, id: String) {
if self.response_id.is_empty() {
self.response_id = id;
}
}
/// 构建最终的 ChatCompletionResponse
fn build_response(self, model: String) -> ChatCompletionResponse {
let tool_calls: Vec<ToolCall> = self.tool_calls
.into_iter()
.filter(|(_, call)| !call.id.is_empty() && !call.name.is_empty())
.map(|(_, call)| {
let arguments = serde_json::from_str(&call.arguments)
.unwrap_or_else(|_| serde_json::Value::Null);
ToolCall {
id: call.id,
name: call.name,
arguments,
}
})
.collect();
ChatCompletionResponse {
id: if self.response_id.is_empty() {
format!("stream-{}", uuid::Uuid::new_v4())
} else {
self.response_id
},
model,
content: self.content,
reasoning_content: self.reasoning_content,
tool_calls,
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
}
}
}
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
let mut details = vec![error.to_string()]; let mut details = vec![error.to_string()];
let mut current = error.source(); let mut current = error.source();
@ -117,6 +210,14 @@ impl OpenAIProvider {
.unwrap_or(false) .unwrap_or(false)
} }
/// 检查是否启用流式输出,默认启用
fn is_streaming_enabled(&self) -> bool {
self.model_extra
.get("enable_streaming")
.and_then(|value| value.as_bool())
.unwrap_or(true)
}
fn normalize_tool_arguments(&self, arguments: &Value) -> Value { fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
match arguments { match arguments {
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
@ -159,6 +260,199 @@ impl OpenAIProvider {
}) })
} }
/// 内部流式聊天实现
async fn chat_streaming(
&self,
request: &ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
let url = format!("{}/chat/completions", self.base_url);
let mut body = self.build_request_body(request);
// 启用流式输出
body["stream"] = json!(true);
let mut req_builder = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream");
for (key, value) in &self.extra_headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
let resp = req_builder.json(&body).send().await.map_err(|err| {
format_transport_error_context(
&self.name,
&self.model_id,
&url,
self.llm_timeout_secs,
&err,
)
})?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(format!("API error {}: {}", status, text).into());
}
let mut accumulator = StreamingAccumulator::new();
// 读取 SSE 流
let mut stream = resp.bytes_stream();
let mut buffer = String::new();
let mut done_received = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result?;
let text = String::from_utf8_lossy(&chunk);
buffer.push_str(&text);
// 处理缓冲区中的完整行
while let Some(newline_pos) = buffer.find('\n') {
let line = buffer[..newline_pos].to_string();
buffer = buffer[newline_pos + 1..].to_string();
let line = line.trim();
if line.is_empty() || line.starts_with(':') {
continue;
}
// SSE 格式: data: {...}
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
// 流结束
done_received = true;
break;
}
// 解析 JSON
match serde_json::from_str::<Value>(data) {
Ok(json) => {
// 提取响应 ID
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
accumulator.set_response_id(id.to_string());
}
// 提取 choices
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
if let Some(delta) = choice.get("delta") {
// 提取内容增量
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
// 提取推理内容增量
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
// 提取工具调用增量
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta");
for tool_call in tool_calls {
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
let id = tool_call.get("id").and_then(|v| v.as_str());
let name = tool_call.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str());
let arguments = tool_call.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str());
tracing::debug!(
index = index,
id = ?id,
name = ?name,
arguments = ?arguments,
"Tool call delta received"
);
accumulator.add_tool_call(index, id, name, arguments);
}
}
}
}
}
}
Err(e) => {
tracing::debug!(
error = %e,
data = %data,
"Failed to parse SSE data"
);
}
}
}
}
if done_received {
break;
}
}
// 处理缓冲区中剩余的内容
for line in buffer.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with(':') {
continue;
}
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
if let Ok(json) = serde_json::from_str::<Value>(data) {
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
accumulator.set_response_id(id.to_string());
}
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulator.add_content(content);
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
accumulator.add_reasoning_content(reasoning);
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
for tool_call in tool_calls {
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
let id = tool_call.get("id").and_then(|v| v.as_str());
let name = tool_call.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str());
let arguments = tool_call.get("function")
.and_then(|f| f.get("arguments"))
.and_then(|a| a.as_str());
accumulator.add_tool_call(index, id, name, arguments);
}
}
}
}
}
}
}
}
let response = accumulator.build_response(self.model_id.clone());
tracing::debug!(
content_len = response.content.len(),
tool_calls_count = response.tool_calls.len(),
has_reasoning = response.reasoning_content.is_some(),
"Streaming response built"
);
Ok(response)
}
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let mut body = json!({ let mut body = json!({
"model": self.model_id, "model": self.model_id,
@ -281,6 +575,26 @@ impl LLMProvider for OpenAIProvider {
&self, &self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
// 检查是否启用流式输出
if self.is_streaming_enabled() {
// 优先尝试流式输出
match self.chat_streaming(&request).await {
Ok(response) => return Ok(response),
Err(e) => {
tracing::debug!(
provider = %self.name,
model = %self.model_id,
error = %e,
"Streaming failed, falling back to non-streaming"
);
// 流式失败,回退到非流式
}
}
} else {
tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming");
}
// 非流式回退实现
let url = format!("{}/chat/completions", self.base_url); let url = format!("{}/chat/completions", self.base_url);
let body = self.build_request_body(&request); let body = self.build_request_body(&request);