diff --git a/Cargo.toml b/Cargo.toml index c82700e..d9e97c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] } +reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] } dotenv = "0.15" serde = { version = "1.0", features = ["derive"] } regex = "1.0" diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 2b7ab18..ab82f87 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -6,7 +6,7 @@ use crate::domain::messages::{ContentBlock, ToolCall}; use crate::observability::{ Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, truncate_args, }; -use crate::providers::{ChatCompletionRequest, LLMProvider, Message, StreamingResponse, create_provider}; +use crate::providers::{ChatCompletionRequest, LLMProvider, Message, create_provider}; use crate::text::{char_count, take_prefix_chars, take_suffix_chars}; use crate::tools::{ToolContext, ToolRegistry}; use async_trait::async_trait; @@ -360,91 +360,6 @@ fn recoverable_llm_message(error: &str) -> String { } } -/// 使用流式 API 调用 LLM,在超时/错误时返回已收到的内容 -async fn call_llm_streaming( - provider: &dyn LLMProvider, - request: ChatCompletionRequest, - provider_name: &str, - model_id: &str, -) -> Result> { - let mut streaming_response = StreamingResponse::new(model_id.to_string(), String::new()); - let tool_calls: Vec = Vec::new(); - - match provider.chat_streaming(request.clone()).await { - Ok(mut stream) => { - use futures_util::StreamExt; - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - if chunk.is_done { - break; - } - streaming_response.add_chunk(chunk); - } - Err(e) => { - // 流式传输错误,记录并返回已累积的内容 - let error_msg = e.to_string(); - tracing::error!( - provider = %provider_name, - model = %model_id, - error = %error_msg, - accumulated_len = %streaming_response.accumulated_content.len(), - "Stream error during LLM request, returning partial content" - ); - - // 如果有累积的内容,返回部分结果 - if !streaming_response.accumulated_content.is_empty() { - let partial_content = streaming_response.get_partial_content(); - return Ok(crate::providers::ChatCompletionResponse { - id: format!("partial-{}", uuid::Uuid::new_v4()), - model: model_id.to_string(), - content: partial_content, - reasoning_content: streaming_response.accumulated_reasoning_content.clone(), - tool_calls: Vec::new(), // 超时时不返回 tool_calls - usage: crate::providers::Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - }); - } - - // 没有累积内容,返回原始错误 - return Err(e); - } - } - } - - // 流完成,构建完整响应 - Ok(crate::providers::ChatCompletionResponse { - id: streaming_response.id.clone().is_empty() - .then(|| format!("stream-{}", uuid::Uuid::new_v4())) - .unwrap_or_else(|| streaming_response.id.clone()), - model: streaming_response.model.clone(), - content: streaming_response.accumulated_content.clone(), - reasoning_content: streaming_response.accumulated_reasoning_content.clone(), - tool_calls, // 注意:流式输出通常不包含 tool_calls - usage: crate::providers::Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - }) - } - Err(e) => { - // 流式请求初始化失败,可能是 provider 不支持流式 - tracing::debug!( - provider = %provider_name, - model = %model_id, - error = %e, - "Streaming not supported or failed, falling back to non-streaming" - ); - // 回退到非流式调用 - provider.chat(request).await - } - } -} - /// Loop detection result. #[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { @@ -823,12 +738,7 @@ impl AgentLoop { tools, }; - let response = match call_llm_streaming( - &*self.provider, - request, - self.provider.name(), - self.provider.model_id(), - ).await { + let response = match (*self.provider).chat(request).await { Ok(response) => response, Err(e) => { tracing::error!( @@ -1037,12 +947,7 @@ impl AgentLoop { tools: None, // No tools in final summary call }; - match call_llm_streaming( - &*self.provider, - request, - self.provider.name(), - self.provider.model_id(), - ).await { + match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = if let Some(reasoning_content) = response.reasoning_content { diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index f7772c2..a2632fe 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,12 +1,11 @@ use async_trait::async_trait; -use futures_util::Stream; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; use super::traits::Usage; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, StreamChunk, Tool, ToolCall}; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use crate::domain::messages::ContentBlock; fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String { @@ -295,17 +294,6 @@ impl LLMProvider for AnthropicProvider { }) } - async fn chat_streaming( - &self, - _request: ChatCompletionRequest, - ) -> Result< - std::pin::Pin>> + Send>>, - Box, - > { - // Anthropic 流式暂未实现,返回错误 - Err("Anthropic streaming not yet implemented".into()) - } - fn ptype(&self) -> &str { "anthropic" } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 9bca9ff..8436e7e 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -9,7 +9,7 @@ pub use crate::domain::messages::ToolCall; pub use crate::domain::tools::{Tool, ToolFunction}; pub use traits::{ ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, ProviderRuntimeConfig, - StreamChunk, StreamingResponse, Usage, + Usage, }; pub fn create_provider( diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 7d3e684..5f256bb 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::Deserialize; use serde_json::{Value, json}; @@ -7,7 +6,7 @@ use std::collections::HashMap; use std::time::Duration; use super::traits::Usage; -use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, StreamChunk, ToolCall}; +use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use crate::domain::messages::ContentBlock; const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"]; @@ -417,179 +416,6 @@ impl LLMProvider for OpenAIProvider { }) } - async fn chat_streaming( - &self, - request: ChatCompletionRequest, - ) -> Result< - std::pin::Pin>> + Send>>, - Box, - > { - let url = format!("{}/chat/completions", self.base_url); - - let mut body = self.build_request_body(&request); - // 启用流式输出 - body["stream"] = Value::Bool(true); - - let mut req_builder = self - .client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .header("Accept", "text/event-stream"); - - for (key, value) in &self.extra_headers { - req_builder = req_builder.header(key.as_str(), value.as_str()); - } - - let resp = req_builder.json(&body).send().await.map_err(|err| { - let error_context = format_transport_error_context( - &self.name, - &self.model_id, - &url, - self.llm_timeout_secs, - &err, - ); - tracing::error!( - provider = %self.name, - model = %self.model_id, - url = %url, - base_url = %self.base_url, - timeout_secs = self.llm_timeout_secs, - error = %error_context, - "OpenAI-compatible API streaming transport request failed" - ); - error_context - })?; - - let status = resp.status(); - if !status.is_success() { - let text = resp.text().await.unwrap_or_default(); - tracing::error!( - provider = %self.name, - model = %self.model_id, - url = %url, - status = %status, - response_len = text.len(), - response_body = %text, - "OpenAI-compatible API streaming request failed" - ); - return Err(format!("API error {}: {}", status, text).into()); - } - - let model_id = self.model_id.clone(); - let provider_name = self.name.clone(); - - // 检查是否是 SSE 响应 - let is_sse = resp - .headers() - .get("content-type") - .and_then(|ct| ct.to_str().ok()) - .map(|ct| ct.contains("text/event-stream")) - .unwrap_or(false); - - // 创建 SSE 流 - let stream = resp.bytes_stream().flat_map(move |chunk_result| { - match chunk_result { - Ok(chunk) => { - let text = String::from_utf8_lossy(&chunk); - let mut chunks = Vec::new(); - - // 如果不是 SSE,尝试解析为普通 JSON 响应 - if !is_sse && !text.contains("data: ") { - // 尝试解析为非流式 JSON 响应 - if let Ok(json) = serde_json::from_str::(&text) { - if let Some(choice) = json.choices.first() { - if let Some(content) = choice.message.content.as_ref() { - if !content.is_empty() { - chunks.push(Ok(StreamChunk { - content_delta: content.clone(), - reasoning_content_delta: choice.message.reasoning_content.clone(), - is_done: false, - })); - } - } - } - // 标记完成 - chunks.push(Ok(StreamChunk { - content_delta: String::new(), - reasoning_content_delta: None, - is_done: true, - })); - return futures_util::stream::iter(chunks); - } - } - - for line in text.lines() { - let line = line.trim(); - if line.is_empty() || line.starts_with(":") { - continue; - } - - // SSE 格式: data: {...} - if let Some(data) = line.strip_prefix("data: ") { - if data == "[DONE]" { - chunks.push(Ok(StreamChunk { - content_delta: String::new(), - reasoning_content_delta: None, - is_done: true, - })); - continue; - } - - // 解析 JSON - match serde_json::from_str::(data) { - Ok(json) => { - if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) { - if let Some(choice) = choices.first() { - // 获取 content delta - let content_delta = choice - .get("delta") - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()) - .unwrap_or("") - .to_string(); - - // 获取 reasoning_content delta(某些模型支持) - let reasoning_delta = choice - .get("delta") - .and_then(|d| d.get("reasoning_content")) - .and_then(|r| r.as_str()) - .map(|s| s.to_string()); - - if !content_delta.is_empty() || reasoning_delta.is_some() { - chunks.push(Ok(StreamChunk { - content_delta, - reasoning_content_delta: reasoning_delta, - is_done: false, - })); - } - } - } - } - Err(e) => { - tracing::debug!( - provider = %provider_name, - model = %model_id, - error = %e, - data = %data, - "Failed to parse SSE data" - ); - } - } - } - } - - futures_util::stream::iter(chunks) - } - Err(e) => { - futures_util::stream::iter(vec![Err(Box::new(e) as Box)]) - } - } - }); - - Ok(Box::pin(stream)) - } - fn ptype(&self) -> &str { "openai" } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 3a307b2..28b757c 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,70 +1,9 @@ use crate::domain::messages::{ContentBlock, ToolCall}; use crate::domain::tools::Tool; use async_trait::async_trait; -use futures_util::Stream; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -/// 流式输出 chunk -#[derive(Debug, Clone)] -pub struct StreamChunk { - /// 内容增量(delta) - pub content_delta: String, - /// 推理内容增量 - pub reasoning_content_delta: Option, - /// 是否为完成信号 - pub is_done: bool, -} - -/// 流式响应收集器 -pub struct StreamingResponse { - /// 已累积的完整内容 - pub accumulated_content: String, - /// 已累积的推理内容 - pub accumulated_reasoning_content: Option, - /// 模型ID - pub model: String, - /// 响应ID - pub id: String, -} - -impl StreamingResponse { - pub fn new(model: String, id: String) -> Self { - Self { - accumulated_content: String::new(), - accumulated_reasoning_content: None, - model, - id, - } - } - - /// 添加 chunk 并累积内容 - pub fn add_chunk(&mut self, chunk: StreamChunk) { - self.accumulated_content.push_str(&chunk.content_delta); - if let Some(reasoning_delta) = chunk.reasoning_content_delta { - if self.accumulated_reasoning_content.is_none() { - self.accumulated_reasoning_content = Some(String::new()); - } - self.accumulated_reasoning_content - .as_mut() - .unwrap() - .push_str(&reasoning_delta); - } - } - - /// 获取当前累积的内容(用于错误时保存) - pub fn get_partial_content(&self) -> String { - if self.accumulated_content.is_empty() { - "[模型响应超时,无部分内容可用]".to_string() - } else { - format!( - "{}\n\n[警告:模型响应超时,以上内容可能不完整]", - self.accumulated_content - ) - } - } -} - #[derive(Debug, Clone)] pub struct ProviderRuntimeConfig { pub provider_type: String, @@ -187,16 +126,6 @@ pub trait LLMProvider: Send + Sync { request: ChatCompletionRequest, ) -> Result>; - /// 流式聊天完成,返回一个 pinned chunk 流 - /// 当连接断开或超时时,调用者可以使用累积的部分内容 - async fn chat_streaming( - &self, - request: ChatCompletionRequest, - ) -> Result< - std::pin::Pin>> + Send>>, - Box, - >; - fn ptype(&self) -> &str; fn name(&self) -> &str;