- 新增 StreamDelta 和 StreamEnd 类型,支持流式数据增量传输
- 扩展 LLMProvider trait,添加带回调的 chat_with_streaming 接口
- 修改 OpenAI Provider 实现,支持流式聊天回调传输增量数据
- Agent 流处理改为异步消费增量消息并传递给前端
- 保证流式增量和最终消息使用相同消息 ID 以便前端替换
- 修改消息总线和协议层,支持携带和识别流式消息的消息 ID
- 客户端 CLI 通过增量输出实现交互式流式响应显示
- Web 前端接收流式增量消息,追加到对应消息,实现实时显示
- 各通道(飞书、微信)支持转发流式增量和结束消息
- 任务工具运行时添加消息 ID 支持,保持消息一致性
- 统一消息构造函数新增流式增量和结束信号的构建方法
177 lines
5.0 KiB
Rust
177 lines
5.0 KiB
Rust
use crate::domain::messages::{ContentBlock, ToolCall};
|
||
use crate::domain::tools::Tool;
|
||
use crate::config::LLMProviderConfig;
|
||
use async_trait::async_trait;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub struct ProviderRuntimeConfig {
|
||
pub provider_type: String,
|
||
pub name: String,
|
||
pub base_url: String,
|
||
pub api_key: String,
|
||
pub extra_headers: HashMap<String, String>,
|
||
pub llm_timeout_secs: u64,
|
||
pub model_id: String,
|
||
pub temperature: Option<f32>,
|
||
pub max_tokens: Option<u32>,
|
||
pub model_extra: HashMap<String, serde_json::Value>,
|
||
}
|
||
|
||
impl From<LLMProviderConfig> for ProviderRuntimeConfig {
|
||
fn from(config: LLMProviderConfig) -> Self {
|
||
Self {
|
||
provider_type: config.provider_type,
|
||
name: config.name,
|
||
base_url: config.base_url,
|
||
api_key: config.api_key,
|
||
extra_headers: config.extra_headers,
|
||
llm_timeout_secs: config.llm_timeout_secs,
|
||
model_id: config.model_id,
|
||
temperature: config.temperature,
|
||
max_tokens: config.max_tokens,
|
||
model_extra: config.model_extra,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct Message {
|
||
pub role: String,
|
||
pub content: Vec<ContentBlock>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub reasoning_content: Option<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub tool_call_id: Option<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub name: Option<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub tool_calls: Option<Vec<ToolCall>>,
|
||
}
|
||
|
||
impl Message {
|
||
pub fn user(content: impl Into<String>) -> Self {
|
||
Self {
|
||
role: "user".to_string(),
|
||
content: vec![ContentBlock::text(content)],
|
||
reasoning_content: None,
|
||
tool_call_id: None,
|
||
name: None,
|
||
tool_calls: None,
|
||
}
|
||
}
|
||
|
||
pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
|
||
Self {
|
||
role: "user".to_string(),
|
||
content,
|
||
reasoning_content: None,
|
||
tool_call_id: None,
|
||
name: None,
|
||
tool_calls: None,
|
||
}
|
||
}
|
||
|
||
pub fn assistant(content: impl Into<String>) -> Self {
|
||
Self {
|
||
role: "assistant".to_string(),
|
||
content: vec![ContentBlock::text(content)],
|
||
reasoning_content: None,
|
||
tool_call_id: None,
|
||
name: None,
|
||
tool_calls: None,
|
||
}
|
||
}
|
||
|
||
pub fn system(content: impl Into<String>) -> Self {
|
||
Self {
|
||
role: "system".to_string(),
|
||
content: vec![ContentBlock::text(content)],
|
||
reasoning_content: None,
|
||
tool_call_id: None,
|
||
name: None,
|
||
tool_calls: None,
|
||
}
|
||
}
|
||
|
||
pub fn tool(
|
||
tool_call_id: impl Into<String>,
|
||
tool_name: impl Into<String>,
|
||
content: impl Into<String>,
|
||
) -> Self {
|
||
Self {
|
||
role: "tool".to_string(),
|
||
content: vec![ContentBlock::text(content)],
|
||
reasoning_content: None,
|
||
tool_call_id: Some(tool_call_id.into()),
|
||
name: Some(tool_name.into()),
|
||
tool_calls: None,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ChatCompletionRequest {
|
||
pub messages: Vec<Message>,
|
||
pub temperature: Option<f32>,
|
||
pub max_tokens: Option<u32>,
|
||
pub tools: Option<Vec<Tool>>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ChatCompletionResponse {
|
||
pub id: String,
|
||
pub model: String,
|
||
pub content: String,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub reasoning_content: Option<String>,
|
||
pub tool_calls: Vec<ToolCall>,
|
||
pub usage: Usage,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct Usage {
|
||
pub prompt_tokens: u32,
|
||
pub completion_tokens: u32,
|
||
pub total_tokens: u32,
|
||
}
|
||
|
||
/// One incremental event of a streamed response.
///
/// `Default` yields an empty delta (empty `content`, no reasoning text), and
/// `PartialEq`/`Eq` let deltas be compared by value, for example in tests of
/// a [`StreamCallback`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamDelta {
    /// Incremental text content.
    pub content: String,

    /// Incremental reasoning content, if any.
    pub reasoning_content: Option<String>,
}
|
||
|
||
/// 流式回调类型:每收到一个 delta 就调用一次
|
||
pub type StreamCallback = Arc<dyn Fn(StreamDelta) + Send + Sync>;
|
||
|
||
#[async_trait]
|
||
pub trait LLMProvider: Send + Sync {
|
||
async fn chat(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;
|
||
|
||
/// 带流式回调的 chat:每收到一个 SSE delta 就调用 callback。
|
||
/// 返回值与 `chat()` 相同(完整的 ChatCompletionResponse)。
|
||
/// 默认实现忽略 callback,直接调用 chat()。
|
||
async fn chat_with_streaming(
|
||
&self,
|
||
request: ChatCompletionRequest,
|
||
_callback: StreamCallback,
|
||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||
self.chat(request).await
|
||
}
|
||
|
||
fn ptype(&self) -> &str;
|
||
|
||
fn name(&self) -> &str;
|
||
|
||
fn model_id(&self) -> &str;
|
||
}
|