feat: 更新 reqwest 依赖以支持流式响应,添加流式处理相关结构和实现
This commit is contained in:
parent
7540828397
commit
018c104592
@ -4,7 +4,7 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] }
|
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] }
|
||||||
dotenv = "0.15"
|
dotenv = "0.15"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
@ -11,6 +12,98 @@ use crate::domain::messages::ContentBlock;
|
|||||||
|
|
||||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||||
|
|
||||||
|
/// 流式响应中的工具调用增量
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct StreamingToolCall {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式响应累积器
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct StreamingAccumulator {
|
||||||
|
content: String,
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
tool_calls: HashMap<usize, StreamingToolCall>,
|
||||||
|
response_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingAccumulator {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加内容增量
|
||||||
|
fn add_content(&mut self, delta: &str) {
|
||||||
|
self.content.push_str(delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加推理内容增量
|
||||||
|
fn add_reasoning_content(&mut self, delta: &str) {
|
||||||
|
if self.reasoning_content.is_none() {
|
||||||
|
self.reasoning_content = Some(String::new());
|
||||||
|
}
|
||||||
|
self.reasoning_content.as_mut().unwrap().push_str(delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加工具调用增量
|
||||||
|
fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) {
|
||||||
|
let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default);
|
||||||
|
|
||||||
|
if let Some(id) = id {
|
||||||
|
entry.id = id.to_string();
|
||||||
|
}
|
||||||
|
if let Some(name) = name {
|
||||||
|
entry.name = name.to_string();
|
||||||
|
}
|
||||||
|
if let Some(args) = arguments {
|
||||||
|
entry.arguments.push_str(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置响应 ID
|
||||||
|
fn set_response_id(&mut self, id: String) {
|
||||||
|
if self.response_id.is_empty() {
|
||||||
|
self.response_id = id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构建最终的 ChatCompletionResponse
|
||||||
|
fn build_response(self, model: String) -> ChatCompletionResponse {
|
||||||
|
let tool_calls: Vec<ToolCall> = self.tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.filter(|(_, call)| !call.id.is_empty() && !call.name.is_empty())
|
||||||
|
.map(|(_, call)| {
|
||||||
|
let arguments = serde_json::from_str(&call.arguments)
|
||||||
|
.unwrap_or_else(|_| serde_json::Value::Null);
|
||||||
|
ToolCall {
|
||||||
|
id: call.id,
|
||||||
|
name: call.name,
|
||||||
|
arguments,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
ChatCompletionResponse {
|
||||||
|
id: if self.response_id.is_empty() {
|
||||||
|
format!("stream-{}", uuid::Uuid::new_v4())
|
||||||
|
} else {
|
||||||
|
self.response_id
|
||||||
|
},
|
||||||
|
model,
|
||||||
|
content: self.content,
|
||||||
|
reasoning_content: self.reasoning_content,
|
||||||
|
tool_calls,
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||||
let mut details = vec![error.to_string()];
|
let mut details = vec![error.to_string()];
|
||||||
let mut current = error.source();
|
let mut current = error.source();
|
||||||
@ -117,6 +210,14 @@ impl OpenAIProvider {
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 检查是否启用流式输出,默认启用
|
||||||
|
fn is_streaming_enabled(&self) -> bool {
|
||||||
|
self.model_extra
|
||||||
|
.get("enable_streaming")
|
||||||
|
.and_then(|value| value.as_bool())
|
||||||
|
.unwrap_or(true)
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
match arguments {
|
match arguments {
|
||||||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||||
@ -159,6 +260,199 @@ impl OpenAIProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 内部流式聊天实现
|
||||||
|
async fn chat_streaming(
|
||||||
|
&self,
|
||||||
|
request: &ChatCompletionRequest,
|
||||||
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
|
||||||
|
|
||||||
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
|
let mut body = self.build_request_body(request);
|
||||||
|
// 启用流式输出
|
||||||
|
body["stream"] = json!(true);
|
||||||
|
|
||||||
|
let mut req_builder = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Accept", "text/event-stream");
|
||||||
|
|
||||||
|
for (key, value) in &self.extra_headers {
|
||||||
|
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = req_builder.json(&body).send().await.map_err(|err| {
|
||||||
|
format_transport_error_context(
|
||||||
|
&self.name,
|
||||||
|
&self.model_id,
|
||||||
|
&url,
|
||||||
|
self.llm_timeout_secs,
|
||||||
|
&err,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let text = resp.text().await.unwrap_or_default();
|
||||||
|
return Err(format!("API error {}: {}", status, text).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut accumulator = StreamingAccumulator::new();
|
||||||
|
|
||||||
|
// 读取 SSE 流
|
||||||
|
let mut stream = resp.bytes_stream();
|
||||||
|
let mut buffer = String::new();
|
||||||
|
let mut done_received = false;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
let chunk = chunk_result?;
|
||||||
|
let text = String::from_utf8_lossy(&chunk);
|
||||||
|
buffer.push_str(&text);
|
||||||
|
|
||||||
|
// 处理缓冲区中的完整行
|
||||||
|
while let Some(newline_pos) = buffer.find('\n') {
|
||||||
|
let line = buffer[..newline_pos].to_string();
|
||||||
|
buffer = buffer[newline_pos + 1..].to_string();
|
||||||
|
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() || line.starts_with(':') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE 格式: data: {...}
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
if data == "[DONE]" {
|
||||||
|
// 流结束
|
||||||
|
done_received = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 JSON
|
||||||
|
match serde_json::from_str::<Value>(data) {
|
||||||
|
Ok(json) => {
|
||||||
|
// 提取响应 ID
|
||||||
|
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
|
||||||
|
accumulator.set_response_id(id.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 choices
|
||||||
|
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
||||||
|
for choice in choices {
|
||||||
|
if let Some(delta) = choice.get("delta") {
|
||||||
|
// 提取内容增量
|
||||||
|
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||||
|
accumulator.add_content(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取推理内容增量
|
||||||
|
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||||
|
accumulator.add_reasoning_content(reasoning);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取工具调用增量
|
||||||
|
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||||
|
tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta");
|
||||||
|
for tool_call in tool_calls {
|
||||||
|
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
||||||
|
|
||||||
|
let id = tool_call.get("id").and_then(|v| v.as_str());
|
||||||
|
let name = tool_call.get("function")
|
||||||
|
.and_then(|f| f.get("name"))
|
||||||
|
.and_then(|n| n.as_str());
|
||||||
|
let arguments = tool_call.get("function")
|
||||||
|
.and_then(|f| f.get("arguments"))
|
||||||
|
.and_then(|a| a.as_str());
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
index = index,
|
||||||
|
id = ?id,
|
||||||
|
name = ?name,
|
||||||
|
arguments = ?arguments,
|
||||||
|
"Tool call delta received"
|
||||||
|
);
|
||||||
|
|
||||||
|
accumulator.add_tool_call(index, id, name, arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
error = %e,
|
||||||
|
data = %data,
|
||||||
|
"Failed to parse SSE data"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if done_received {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理缓冲区中剩余的内容
|
||||||
|
for line in buffer.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() || line.starts_with(':') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
if data == "[DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
||||||
|
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
|
||||||
|
accumulator.set_response_id(id.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
||||||
|
for choice in choices {
|
||||||
|
if let Some(delta) = choice.get("delta") {
|
||||||
|
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
||||||
|
accumulator.add_content(content);
|
||||||
|
}
|
||||||
|
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
||||||
|
accumulator.add_reasoning_content(reasoning);
|
||||||
|
}
|
||||||
|
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
||||||
|
for tool_call in tool_calls {
|
||||||
|
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
||||||
|
let id = tool_call.get("id").and_then(|v| v.as_str());
|
||||||
|
let name = tool_call.get("function")
|
||||||
|
.and_then(|f| f.get("name"))
|
||||||
|
.and_then(|n| n.as_str());
|
||||||
|
let arguments = tool_call.get("function")
|
||||||
|
.and_then(|f| f.get("arguments"))
|
||||||
|
.and_then(|a| a.as_str());
|
||||||
|
accumulator.add_tool_call(index, id, name, arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = accumulator.build_response(self.model_id.clone());
|
||||||
|
tracing::debug!(
|
||||||
|
content_len = response.content.len(),
|
||||||
|
tool_calls_count = response.tool_calls.len(),
|
||||||
|
has_reasoning = response.reasoning_content.is_some(),
|
||||||
|
"Streaming response built"
|
||||||
|
);
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
@ -281,6 +575,26 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
// 检查是否启用流式输出
|
||||||
|
if self.is_streaming_enabled() {
|
||||||
|
// 优先尝试流式输出
|
||||||
|
match self.chat_streaming(&request).await {
|
||||||
|
Ok(response) => return Ok(response),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(
|
||||||
|
provider = %self.name,
|
||||||
|
model = %self.model_id,
|
||||||
|
error = %e,
|
||||||
|
"Streaming failed, falling back to non-streaming"
|
||||||
|
);
|
||||||
|
// 流式失败,回退到非流式
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非流式回退实现
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
let body = self.build_request_body(&request);
|
let body = self.build_request_body(&request);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user