From cb48ef09b22ef72798a15ff3485c8e21f77e59f4 Mon Sep 17 00:00:00 2001 From: oudecheng <13802883547@139.com> Date: Wed, 20 May 2026 09:10:47 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E8=81=8A=E5=A4=A9=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E4=BE=9D=E8=B5=96=E5=92=8C=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E4=BB=A5=E5=AE=9E=E7=8E=B0=E6=B5=81=E5=BC=8F=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 2 +- src/agent/agent_loop.rs | 101 ++++++++++++++++++++- src/providers/anthropic.rs | 14 ++- src/providers/mod.rs | 2 +- src/providers/openai.rs | 176 ++++++++++++++++++++++++++++++++++++- src/providers/traits.rs | 71 +++++++++++++++ 6 files changed, 359 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d9e97c3..c82700e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] } +reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] } dotenv = "0.15" serde = { version = "1.0", features = ["derive"] } regex = "1.0" diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index ab82f87..2b7ab18 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -6,7 +6,7 @@ use crate::domain::messages::{ContentBlock, ToolCall}; use crate::observability::{ Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args, }; -use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message, StreamingResponse, create_provider}; use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; use crate::tools::{ToolContext, ToolRegistry}; use async_trait::async_trait; @@ -360,6 +360,91 @@ fn recoverable_llm_message(error: &str) -> String { } } +/// 使用流式 API 调用 LLM,在超时/错误时返回已收到的内容 +async fn call_llm_streaming( + provider: &dyn LLMProvider, + request: ChatCompletionRequest, + provider_name: &str, + model_id: &str, +) -> Result> { + let mut streaming_response = StreamingResponse::new(model_id.to_string(), String::new()); + let tool_calls: Vec = Vec::new(); + + match provider.chat_streaming(request.clone()).await { + Ok(mut stream) => { + use futures_util::StreamExt; + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if chunk.is_done { + break; + } + streaming_response.add_chunk(chunk); + } + Err(e) => { + // 流式传输错误,记录并返回已累积的内容 + let error_msg = e.to_string(); + tracing::error!( + provider = %provider_name, + model = %model_id, + error = %error_msg, + accumulated_len = %streaming_response.accumulated_content.len(), + "Stream error during LLM request, returning partial content" + ); + + // 如果有累积的内容,返回部分结果 + if !streaming_response.accumulated_content.is_empty() { + let partial_content = streaming_response.get_partial_content(); + return Ok(crate::providers::ChatCompletionResponse { + id: format!("partial-{}", uuid::Uuid::new_v4()), + model: model_id.to_string(), + content: partial_content, + reasoning_content: streaming_response.accumulated_reasoning_content.clone(), + tool_calls: Vec::new(), // 超时时不返回 tool_calls + usage: crate::providers::Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + }); + } + + // 没有累积内容,返回原始错误 + return Err(e); + } + } + } + + // 流完成,构建完整响应 + Ok(crate::providers::ChatCompletionResponse { + id: streaming_response.id.clone().is_empty() + .then(|| format!("stream-{}", uuid::Uuid::new_v4())) + .unwrap_or_else(|| streaming_response.id.clone()), + model: streaming_response.model.clone(), + content: streaming_response.accumulated_content.clone(), + reasoning_content: streaming_response.accumulated_reasoning_content.clone(), + tool_calls, // 注意:流式输出通常不包含 tool_calls + usage: crate::providers::Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + }) + } + Err(e) => { + // 流式请求初始化失败,可能是 provider 不支持流式 + tracing::debug!( + provider = %provider_name, + model = %model_id, + error = %e, + "Streaming not supported or failed, falling back to non-streaming" + ); + // 回退到非流式调用 + provider.chat(request).await + } + } +} + /// Loop detection result. #[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { @@ -738,7 +823,12 @@ impl AgentLoop { tools, }; - let response = match (*self.provider).chat(request).await { + let response = match call_llm_streaming( + &*self.provider, + request, + self.provider.name(), + self.provider.model_id(), + ).await { Ok(response) => response, Err(e) => { tracing::error!( @@ -947,7 +1037,12 @@ impl AgentLoop { tools: None, // No tools in final summary call }; - match (*self.provider).chat(request).await { + match call_llm_streaming( + &*self.provider, + request, + self.provider.name(), + self.provider.model_id(), + ).await { Ok(response) => { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index a2632fe..f7772c2 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,11 +1,12 @@ use async_trait::async_trait; +use futures_util::Stream; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; use super::traits::Usage; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, StreamChunk, Tool, ToolCall}; use crate::domain::messages::ContentBlock; fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { @@ -294,6 +295,17 @@ impl LLMProvider for AnthropicProvider { }) } + async fn chat_streaming( + &self, + _request: ChatCompletionRequest, + ) -> Result< + std::pin::Pin>> + Send>>, + Box, + > { + // Anthropic 流式暂未实现,返回错误 + Err("Anthropic streaming not yet implemented".into()) + } + fn ptype(&self) -> &str { "anthropic" } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 8436e7e..9bca9ff 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -9,7 +9,7 @@ pub use crate::domain::messages::ToolCall; pub use crate::domain::tools::{Tool, ToolFunction}; pub use traits::{ ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, ProviderRuntimeConfig, - Usage, + StreamChunk, StreamingResponse, Usage, }; pub fn create_provider( diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 5f256bb..7d3e684 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::Deserialize; use serde_json::{Value, json}; @@ -6,7 +7,7 @@ use std::collections::HashMap; use std::time::Duration; use super::traits::Usage; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, StreamChunk, ToolCall}; use crate::domain::messages::ContentBlock; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; @@ -416,6 +417,179 @@ impl LLMProvider for OpenAIProvider { }) } + async fn chat_streaming( + &self, + request: ChatCompletionRequest, + ) -> Result< + std::pin::Pin>> + Send>>, + Box, + > { + let url = format!("{}/chat/completions", self.base_url); + + let mut body = self.build_request_body(&request); + // 启用流式输出 + body["stream"] = Value::Bool(true); + + let mut req_builder = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream"); + + for (key, value) in &self.extra_headers { + req_builder = req_builder.header(key.as_str(), value.as_str()); + } + + let resp = req_builder.json(&body).send().await.map_err(|err| { + let error_context = format_transport_error_context( + &self.name, + &self.model_id, + &url, + self.llm_timeout_secs, + &err, + ); + tracing::error!( + provider = %self.name, + model = %self.model_id, + url = %url, + base_url = %self.base_url, + timeout_secs = self.llm_timeout_secs, + error = %error_context, + "OpenAI-compatible API streaming transport request failed" + ); + error_context + })?; + + let status = resp.status(); + if !status.is_success() { + let text = resp.text().await.unwrap_or_default(); + tracing::error!( + provider = %self.name, + model = %self.model_id, + url = %url, + status = %status, + response_len = text.len(), + response_body = %text, + "OpenAI-compatible API streaming request failed" + ); + return Err(format!("API error {}: {}", status, text).into()); + } + + let model_id = self.model_id.clone(); + let provider_name = self.name.clone(); + + // 检查是否是 SSE 响应 + let is_sse = resp + .headers() + .get("content-type") + .and_then(|ct| ct.to_str().ok()) + .map(|ct| ct.contains("text/event-stream")) + .unwrap_or(false); + + // 创建 SSE 流 + let stream = resp.bytes_stream().flat_map(move |chunk_result| { + match chunk_result { + Ok(chunk) => { + let text = String::from_utf8_lossy(&chunk); + let mut chunks = Vec::new(); + + // 如果不是 SSE,尝试解析为普通 JSON 响应 + if !is_sse && !text.contains("data: ") { + // 尝试解析为非流式 JSON 响应 + if let Ok(json) = serde_json::from_str::(&text) { + if let Some(choice) = json.choices.first() { + if let Some(content) = choice.message.content.as_ref() { + if !content.is_empty() { + chunks.push(Ok(StreamChunk { + content_delta: content.clone(), + reasoning_content_delta: choice.message.reasoning_content.clone(), + is_done: false, + })); + } + } + } + // 标记完成 + chunks.push(Ok(StreamChunk { + content_delta: String::new(), + reasoning_content_delta: None, + is_done: true, + })); + return futures_util::stream::iter(chunks); + } + } + + for line in text.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with(":") { + continue; + } + + // SSE 格式: data: {...} + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + chunks.push(Ok(StreamChunk { + content_delta: String::new(), + reasoning_content_delta: None, + is_done: true, + })); + continue; + } + + // 解析 JSON + match serde_json::from_str::(data) { + Ok(json) => { + if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + // 获取 content delta + let content_delta = choice + .get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) + .unwrap_or("") + .to_string(); + + // 获取 reasoning_content delta(某些模型支持) + let reasoning_delta = choice + .get("delta") + .and_then(|d| d.get("reasoning_content")) + .and_then(|r| r.as_str()) + .map(|s| s.to_string()); + + if !content_delta.is_empty() || reasoning_delta.is_some() { + chunks.push(Ok(StreamChunk { + content_delta, + reasoning_content_delta: reasoning_delta, + is_done: false, + })); + } + } + } + } + Err(e) => { + tracing::debug!( + provider = %provider_name, + model = %model_id, + error = %e, + data = %data, + "Failed to parse SSE data" + ); + } + } + } + } + + futures_util::stream::iter(chunks) + } + Err(e) => { + futures_util::stream::iter(vec![Err(Box::new(e) as Box)]) + } + } + }); + + Ok(Box::pin(stream)) + } + fn ptype(&self) -> &str { "openai" } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 28b757c..3a307b2 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,9 +1,70 @@ use crate::domain::messages::{ContentBlock, ToolCall}; use crate::domain::tools::Tool; use async_trait::async_trait; +use futures_util::Stream; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +/// 流式输出 chunk +#[derive(Debug, Clone)] +pub struct StreamChunk { + /// 内容增量(delta) + pub content_delta: String, + /// 推理内容增量 + pub reasoning_content_delta: Option, + /// 是否为完成信号 + pub is_done: bool, +} + +/// 流式响应收集器 +pub struct StreamingResponse { + /// 已累积的完整内容 + pub accumulated_content: String, + /// 已累积的推理内容 + pub accumulated_reasoning_content: Option, + /// 模型ID + pub model: String, + /// 响应ID + pub id: String, +} + +impl StreamingResponse { + pub fn new(model: String, id: String) -> Self { + Self { + accumulated_content: String::new(), + accumulated_reasoning_content: None, + model, + id, + } + } + + /// 添加 chunk 并累积内容 + pub fn add_chunk(&mut self, chunk: StreamChunk) { + self.accumulated_content.push_str(&chunk.content_delta); + if let Some(reasoning_delta) = chunk.reasoning_content_delta { + if self.accumulated_reasoning_content.is_none() { + self.accumulated_reasoning_content = Some(String::new()); + } + self.accumulated_reasoning_content + .as_mut() + .unwrap() + .push_str(&reasoning_delta); + } + } + + /// 获取当前累积的内容(用于错误时保存) + pub fn get_partial_content(&self) -> String { + if self.accumulated_content.is_empty() { + "[模型响应超时,无部分内容可用]".to_string() + } else { + format!( + "{}\n\n[警告:模型响应超时,以上内容可能不完整]", + self.accumulated_content + ) + } + } +} + #[derive(Debug, Clone)] pub struct ProviderRuntimeConfig { pub provider_type: String, @@ -126,6 +187,16 @@ pub trait LLMProvider: Send + Sync { request: ChatCompletionRequest, ) -> Result>; + /// 流式聊天完成,返回一个 pinned chunk 流 + /// 当连接断开或超时时,调用者可以使用累积的部分内容 + async fn chat_streaming( + &self, + request: ChatCompletionRequest, + ) -> Result< + std::pin::Pin>> + Send>>, + Box, + >; + fn ptype(&self) -> &str; fn name(&self) -> &str;