From 018c10459207b7fbe938efb195a3c6f0f309321f Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Wed, 20 May 2026 11:32:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=20reqwest=20?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E4=BB=A5=E6=94=AF=E6=8C=81=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=93=8D=E5=BA=94=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=A4=84=E7=90=86=E7=9B=B8=E5=85=B3=E7=BB=93=E6=9E=84=E5=92=8C?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 2 +- src/providers/openai.rs | 314 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 315 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d9e97c3..c82700e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] } +reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] } dotenv = "0.15" serde = { version = "1.0", features = ["derive"] } regex = "1.0" diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 5f256bb..04baac6 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use futures_util::StreamExt; use reqwest::Client; use serde::Deserialize; use serde_json::{Value, json}; @@ -11,6 +12,98 @@ use crate::domain::messages::ContentBlock; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; +/// 流式响应中的工具调用增量 +#[derive(Debug, Default)] +struct StreamingToolCall { + id: String, + name: String, + arguments: String, +} + +/// 流式响应累积器 +#[derive(Debug, Default)] +struct StreamingAccumulator { + content: String, + reasoning_content: Option, + tool_calls: HashMap, + response_id: String, +} + +impl StreamingAccumulator { + fn new() -> Self { + Self::default() + } + + /// 添加内容增量 + fn add_content(&mut self, delta: &str) { + self.content.push_str(delta); + } + + /// 添加推理内容增量 + fn add_reasoning_content(&mut self, delta: &str) { + if self.reasoning_content.is_none() { + self.reasoning_content = Some(String::new()); + } + self.reasoning_content.as_mut().unwrap().push_str(delta); + } + + /// 添加工具调用增量 + fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) { + let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default); + + if let Some(id) = id { + entry.id = id.to_string(); + } + if let Some(name) = name { + entry.name = name.to_string(); + } + if let Some(args) = arguments { + entry.arguments.push_str(args); + } + } + + /// 设置响应 ID + fn set_response_id(&mut self, id: String) { + if self.response_id.is_empty() { + self.response_id = id; + } + } + + /// 构建最终的 ChatCompletionResponse + fn build_response(self, model: String) -> ChatCompletionResponse { + let tool_calls: Vec = self.tool_calls + .into_iter() + .filter(|(_, call)| !call.id.is_empty() && !call.name.is_empty()) + .map(|(_, call)| { + let arguments = serde_json::from_str(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Null); + ToolCall { + id: call.id, + name: call.name, + arguments, + } + }) + .collect(); + + ChatCompletionResponse { + id: if self.response_id.is_empty() { + format!("stream-{}", uuid::Uuid::new_v4()) + } else { + self.response_id + }, + model, + content: self.content, + reasoning_content: self.reasoning_content, + tool_calls, + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + } + } +} + fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { let mut details = vec![error.to_string()]; let mut current = error.source(); @@ -117,6 +210,14 @@ impl OpenAIProvider { .unwrap_or(false) } + /// 检查是否启用流式输出,默认启用 + fn is_streaming_enabled(&self) -> bool { + self.model_extra + .get("enable_streaming") + .and_then(|value| value.as_bool()) + .unwrap_or(true) + } + fn normalize_tool_arguments(&self, arguments: &Value) -> Value { match arguments { Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()), @@ -159,6 +260,199 @@ impl OpenAIProvider { }) } + /// 内部流式聊天实现 + async fn chat_streaming( + &self, + request: &ChatCompletionRequest, + ) -> Result> { + tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat"); + + let url = format!("{}/chat/completions", self.base_url); + + let mut body = self.build_request_body(request); + // 启用流式输出 + body["stream"] = json!(true); + + let mut req_builder = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream"); + + for (key, value) in &self.extra_headers { + req_builder = req_builder.header(key.as_str(), value.as_str()); + } + + let resp = req_builder.json(&body).send().await.map_err(|err| { + format_transport_error_context( + &self.name, + &self.model_id, + &url, + self.llm_timeout_secs, + &err, + ) + })?; + + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + return Err(format!("API error {}: {}", status, text).into()); + } + + let mut accumulator = StreamingAccumulator::new(); + + // 读取 SSE 流 + let mut stream = resp.bytes_stream(); + let mut buffer = String::new(); + let mut done_received = false; + + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result?; + let text = String::from_utf8_lossy(&chunk); + buffer.push_str(&text); + + // 处理缓冲区中的完整行 + while let Some(newline_pos) = buffer.find('\n') { + let line = buffer[..newline_pos].to_string(); + buffer = buffer[newline_pos + 1..].to_string(); + + let line = line.trim(); + if line.is_empty() || line.starts_with(':') { + continue; + } + + // SSE 格式: data: {...} + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + // 流结束 + done_received = true; + break; + } + + // 解析 JSON + match serde_json::from_str::(data) { + Ok(json) => { + // 提取响应 ID + if let Some(id) = json.get("id").and_then(|v| v.as_str()) { + accumulator.set_response_id(id.to_string()); + } + + // 提取 choices + if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { + for choice in choices { + if let Some(delta) = choice.get("delta") { + // 提取内容增量 + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + accumulator.add_content(content); + } + + // 提取推理内容增量 + if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) { + accumulator.add_reasoning_content(reasoning); + } + + // 提取工具调用增量 + if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { + tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta"); + for tool_call in tool_calls { + let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize; + + let id = tool_call.get("id").and_then(|v| v.as_str()); + let name = tool_call.get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()); + let arguments = tool_call.get("function") + .and_then(|f| f.get("arguments")) + .and_then(|a| a.as_str()); + + tracing::debug!( + index = index, + id = ?id, + name = ?name, + arguments = ?arguments, + "Tool call delta received" + ); + + accumulator.add_tool_call(index, id, name, arguments); + } + } + } + } + } + } + Err(e) => { + tracing::debug!( + error = %e, + data = %data, + "Failed to parse SSE data" + ); + } + } + } + } + + if done_received { + break; + } + } + + // 处理缓冲区中剩余的内容 + for line in buffer.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with(':') { + continue; + } + + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + + if let Ok(json) = serde_json::from_str::(data) { + if let Some(id) = json.get("id").and_then(|v| v.as_str()) { + accumulator.set_response_id(id.to_string()); + } + + if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { + for choice in choices { + if let Some(delta) = choice.get("delta") { + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + accumulator.add_content(content); + } + if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) { + accumulator.add_reasoning_content(reasoning); + } + if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { + for tool_call in tool_calls { + let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize; + let id = tool_call.get("id").and_then(|v| v.as_str()); + let name = tool_call.get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()); + let arguments = tool_call.get("function") + .and_then(|f| f.get("arguments")) + .and_then(|a| a.as_str()); + accumulator.add_tool_call(index, id, name, arguments); + } + } + } + } + } + } + } + } + + let response = accumulator.build_response(self.model_id.clone()); + tracing::debug!( + content_len = response.content.len(), + tool_calls_count = response.tool_calls.len(), + has_reasoning = response.reasoning_content.is_some(), + "Streaming response built" + ); + Ok(response) + } + fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let mut body = json!({ "model": self.model_id, @@ -281,6 +575,26 @@ impl LLMProvider for OpenAIProvider { &self, request: ChatCompletionRequest, ) -> Result> { + // 检查是否启用流式输出 + if self.is_streaming_enabled() { + // 优先尝试流式输出 + match self.chat_streaming(&request).await { + Ok(response) => return Ok(response), + Err(e) => { + tracing::debug!( + provider = %self.name, + model = %self.model_id, + error = %e, + "Streaming failed, falling back to non-streaming" + ); + // 流式失败,回退到非流式 + } + } + } else { + tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming"); + } + + // 非流式回退实现 let url = format!("{}/chat/completions", self.base_url); let body = self.build_request_body(&request);