PicoBot/src/providers/traits.rs
ooodc fc7df67474 feat(streaming): 支持流式文本增量与结束信号功能
- 新增 StreamDelta 和 StreamEnd 类型,支持流式数据增量传输
- 扩展 LLMProvider trait,添加带回调的 chat_with_streaming 接口
- 修改 OpenAI Provider 实现,支持流式聊天回调传输增量数据
- Agent 流处理改为异步消费增量消息并传递给前端
- 保证流式增量和最终消息使用相同消息 ID 以便前端替换
- 修改消息总线和协议层,支持携带和识别流式消息的消息 ID
- 客户端 CLI 通过增量输出实现交互式流式响应显示
- Web 前端接收流式增量消息,追加到对应消息,实现实时显示
- 各通道(飞书、微信)支持转发流式增量和结束消息
- 任务工具运行时添加消息 ID 支持,保持消息一致性
- 统一消息构造函数新增流式增量和结束信号的构建方法
2026-06-14 10:24:52 +08:00

177 lines
5.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::domain::messages::{ContentBlock, ToolCall};
use crate::domain::tools::Tool;
use crate::config::LLMProviderConfig;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Resolved settings a provider instance uses at runtime: endpoint,
/// credentials, request timeout and model parameters.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting this
/// value, for example in logs, never reveals the API key or the values of
/// `extra_headers`, which may carry credentials such as `Authorization`.
#[derive(Clone)]
pub struct ProviderRuntimeConfig {
    /// Provider kind identifier.
    pub provider_type: String,
    /// Configured provider name.
    pub name: String,
    /// Base URL of the provider's API.
    pub base_url: String,
    /// Secret API key; shown only as `<redacted>` / `<empty>` in `Debug` output.
    pub api_key: String,
    /// Additional HTTP headers; only the header names appear in `Debug` output.
    pub extra_headers: HashMap<String, String>,
    /// LLM request timeout, in seconds.
    pub llm_timeout_secs: u64,
    /// Model identifier.
    pub model_id: String,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Additional model options as raw JSON values.
    pub model_extra: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for ProviderRuntimeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the diagnostic value ("is a key configured at all?") without leaking it.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        // Header values can be secrets, so list names only. Sort them so the
        // output does not depend on HashMap iteration order.
        let mut header_names: Vec<&str> =
            self.extra_headers.keys().map(String::as_str).collect();
        header_names.sort_unstable();

        f.debug_struct("ProviderRuntimeConfig")
            .field("provider_type", &self.provider_type)
            .field("name", &self.name)
            .field("base_url", &self.base_url)
            .field("api_key", &api_key)
            .field("extra_headers", &header_names)
            .field("llm_timeout_secs", &self.llm_timeout_secs)
            .field("model_id", &self.model_id)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("model_extra", &self.model_extra)
            .finish()
    }
}
impl From<LLMProviderConfig> for ProviderRuntimeConfig {
fn from(config: LLMProviderConfig) -> Self {
Self {
provider_type: config.provider_type,
name: config.name,
base_url: config.base_url,
api_key: config.api_key,
extra_headers: config.extra_headers,
llm_timeout_secs: config.llm_timeout_secs,
model_id: config.model_id,
temperature: config.temperature,
max_tokens: config.max_tokens,
model_extra: config.model_extra,
}
}
}
/// One chat message in a conversation sent to / received from a provider.
///
/// Every optional field is omitted from the serialized form when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role. `Message`'s constructors use `"user"`, `"assistant"`,
    /// `"system"` and `"tool"`.
    pub role: String,
    /// Message body as a list of content blocks.
    pub content: Vec<ContentBlock>,
    /// Reasoning text attached to the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// For `"tool"` messages: the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// For `"tool"` messages: the name of the tool that produced the result.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls carried by this message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
impl Message {
    /// Common constructor used by the public helpers below. It sets `role`
    /// and `content` and leaves every optional field as `None`.
    fn from_role_and_blocks(role: &str, content: Vec<ContentBlock>) -> Self {
        Self {
            role: role.to_string(),
            content,
            reasoning_content: None,
            tool_call_id: None,
            name: None,
            tool_calls: None,
        }
    }

    /// A `"user"` message holding a single text block.
    pub fn user(content: impl Into<String>) -> Self {
        Self::from_role_and_blocks("user", vec![ContentBlock::text(content)])
    }

    /// A `"user"` message built from caller-supplied content blocks.
    pub fn user_with_blocks(content: Vec<ContentBlock>) -> Self {
        Self::from_role_and_blocks("user", content)
    }

    /// An `"assistant"` message holding a single text block.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::from_role_and_blocks("assistant", vec![ContentBlock::text(content)])
    }

    /// A `"system"` message holding a single text block.
    pub fn system(content: impl Into<String>) -> Self {
        Self::from_role_and_blocks("system", vec![ContentBlock::text(content)])
    }

    /// A `"tool"` result message. It records which call it answers
    /// (`tool_call_id`) and which tool produced it (`name`).
    pub fn tool(
        tool_call_id: impl Into<String>,
        tool_name: impl Into<String>,
        content: impl Into<String>,
    ) -> Self {
        Self {
            tool_call_id: Some(tool_call_id.into()),
            name: Some(tool_name.into()),
            ..Self::from_role_and_blocks("tool", vec![ContentBlock::text(content)])
        }
    }
}
/// Input to [`LLMProvider::chat`]: the conversation plus optional generation
/// settings and tool definitions.
///
/// Unlike `Message`, the `Option` fields here carry no `skip_serializing_if`,
/// so they serialize as `null` when unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    /// Conversation messages, in order.
    pub messages: Vec<Message>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Tools offered to the model, if any.
    pub tools: Option<Vec<Tool>>,
}
/// Complete result of a chat completion, as returned by [`LLMProvider::chat`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Model identifier reported for this response.
    pub model: String,
    /// Final answer text.
    pub content: String,
    /// Reasoning text, if any; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Tool calls requested in the response.
    pub tool_calls: Vec<ToolCall>,
    /// Token usage for the request.
    pub usage: Usage,
}
/// Token accounting for one chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the generated completion.
    pub completion_tokens: u32,
    /// Total token count (presumably prompt + completion as reported by the
    /// provider — not computed here).
    pub total_tokens: u32,
}
/// One incremental event received while a response is being streamed.
#[derive(Debug, Clone)]
pub struct StreamDelta {
    /// Newly received fragment of the answer text.
    pub content: String,
    /// Newly received fragment of reasoning text, when the chunk carried one.
    pub reasoning_content: Option<String>,
}

/// Streaming callback, invoked once for every delta received.
///
/// It is wrapped in `Arc` so clones are cheap. It is bounded by `Send + Sync`
/// so it can be shared across threads.
pub type StreamCallback = Arc<dyn Fn(StreamDelta) + Sync + Send>;
/// A chat-completion backend.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Sends `request` and waits for the complete response.
    ///
    /// # Errors
    /// Returns a boxed, thread-safe error when the provider call fails.
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>>;

    /// Chat with a streaming callback. A streaming implementation should invoke
    /// `callback` once per received SSE delta.
    ///
    /// The return value is the same as [`LLMProvider::chat`]: the complete
    /// [`ChatCompletionResponse`].
    ///
    /// The default implementation ignores the callback and simply delegates to
    /// `chat()`. Providers that do not override it therefore emit no deltas.
    ///
    /// # Errors
    /// Propagates any error returned by `chat()`.
    async fn chat_with_streaming(
        &self,
        request: ChatCompletionRequest,
        _callback: StreamCallback,
    ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
        // Non-streaming fallback: the callback is intentionally left unused.
        let full_response = self.chat(request).await?;
        Ok(full_response)
    }

    /// Provider type identifier.
    fn ptype(&self) -> &str;

    /// Provider name.
    fn name(&self) -> &str;

    /// Model identifier used by this provider.
    fn model_id(&self) -> &str;
}