增加deepseek的支持。
This commit is contained in:
parent
f0879f8d13
commit
a77c026826
@ -383,6 +383,7 @@ impl AgentLoop {
|
||||
Message {
|
||||
role: m.role.clone(),
|
||||
content,
|
||||
reasoning_content: m.reasoning_content.clone(),
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
tool_calls: m.tool_calls.clone(),
|
||||
@ -475,7 +476,8 @@ impl AgentLoop {
|
||||
|
||||
// If no tool calls, this is the final response
|
||||
if response.tool_calls.is_empty() {
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
let mut assistant_message = ChatMessage::assistant(response.content);
|
||||
assistant_message.reasoning_content = response.reasoning_content;
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
return Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
@ -499,10 +501,11 @@ impl AgentLoop {
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls
|
||||
let assistant_message = ChatMessage::assistant_with_tool_calls(
|
||||
let mut assistant_message = ChatMessage::assistant_with_tool_calls(
|
||||
response.content.clone(),
|
||||
response.tool_calls.clone(),
|
||||
);
|
||||
assistant_message.reasoning_content = response.reasoning_content;
|
||||
messages.push(assistant_message.clone());
|
||||
emitted_messages.push(assistant_message);
|
||||
|
||||
@ -581,7 +584,8 @@ impl AgentLoop {
|
||||
|
||||
match (*self.provider).chat(request).await {
|
||||
Ok(response) => {
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
let mut assistant_message = ChatMessage::assistant(response.content);
|
||||
assistant_message.reasoning_content = response.reasoning_content;
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
@ -801,6 +805,7 @@ mod tests {
|
||||
let provider_message = Message {
|
||||
role: chat_message.role.clone(),
|
||||
content,
|
||||
reasoning_content: None,
|
||||
tool_call_id: chat_message.tool_call_id.clone(),
|
||||
name: chat_message.tool_name.clone(),
|
||||
tool_calls: chat_message.tool_calls.clone(),
|
||||
|
||||
@ -536,6 +536,7 @@ mod tests {
|
||||
id: "mock".into(),
|
||||
model: "mock".into(),
|
||||
content: "[summarized]".into(),
|
||||
reasoning_content: None,
|
||||
tool_calls: vec![],
|
||||
usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
|
||||
})
|
||||
|
||||
@ -82,6 +82,7 @@ pub struct ChatMessage {
|
||||
pub id: String,
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub media_refs: Vec<MediaRef>,
|
||||
pub timestamp: i64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@ -120,6 +121,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
@ -134,6 +136,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs,
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
@ -148,6 +151,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
@ -162,6 +166,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
@ -176,6 +181,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
@ -190,6 +196,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "system".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
@ -204,6 +211,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "tool".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
@ -218,6 +226,7 @@ impl ChatMessage {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.into(),
|
||||
reasoning_content: None,
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
|
||||
@ -309,6 +309,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
})?;
|
||||
|
||||
let mut content = String::new();
|
||||
let mut reasoning = None;
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for c in &anthropic_resp.content {
|
||||
@ -321,7 +322,9 @@ impl LLMProvider for AnthropicProvider {
|
||||
content.push_str(text);
|
||||
}
|
||||
}
|
||||
AnthropicContent::Thinking { .. } => {}
|
||||
AnthropicContent::Thinking { thinking } => {
|
||||
reasoning = Some(thinking.clone());
|
||||
}
|
||||
AnthropicContent::Unknown => {}
|
||||
AnthropicContent::ToolUse { id, name, input } => {
|
||||
tool_calls.push(ToolCall {
|
||||
@ -337,6 +340,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
id: anthropic_resp.id.unwrap_or_default(),
|
||||
model: anthropic_resp.model.unwrap_or_default(),
|
||||
content,
|
||||
reasoning_content: reasoning,
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||
|
||||
@ -83,7 +83,7 @@ impl OpenAIProvider {
|
||||
"name": m.name,
|
||||
})
|
||||
} else if m.role == "assistant" && m.tool_calls.as_ref().is_some_and(|c| !c.is_empty()) {
|
||||
json!({
|
||||
let mut msg = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||
@ -96,12 +96,22 @@ impl OpenAIProvider {
|
||||
}
|
||||
})).collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
});
|
||||
if let Some(ref rc) = m.reasoning_content {
|
||||
msg["reasoning_content"] = json!(rc);
|
||||
}
|
||||
msg
|
||||
} else {
|
||||
json!({
|
||||
let mut msg = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content)
|
||||
})
|
||||
});
|
||||
if m.role == "assistant" {
|
||||
if let Some(ref rc) = m.reasoning_content {
|
||||
msg["reasoning_content"] = json!(rc);
|
||||
}
|
||||
}
|
||||
msg
|
||||
}
|
||||
}).collect::<Vec<_>>(),
|
||||
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
||||
@ -139,6 +149,8 @@ struct OpenAIMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<OpenAIToolCall>,
|
||||
}
|
||||
|
||||
@ -296,6 +308,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
id: openai_resp.id,
|
||||
model: openai_resp.model,
|
||||
content,
|
||||
reasoning_content: first_choice.message.reasoning_content,
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: openai_resp.usage.prompt_tokens,
|
||||
@ -351,6 +364,7 @@ mod tests {
|
||||
messages: vec![Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::text("calling tool")],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
|
||||
@ -7,6 +7,8 @@ pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Vec<ContentBlock>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
@ -19,6 +21,7 @@ impl Message {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -29,6 +32,7 @@ impl Message {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content,
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -39,6 +43,7 @@ impl Message {
|
||||
Self {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -49,6 +54,7 @@ impl Message {
|
||||
Self {
|
||||
role: "system".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -59,6 +65,7 @@ impl Message {
|
||||
Self {
|
||||
role: "tool".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
name: Some(tool_name.into()),
|
||||
tool_calls: None,
|
||||
@ -100,6 +107,7 @@ pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub model: String,
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
@ -208,6 +208,7 @@ impl Session {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
reasoning_content: m.reasoning_content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
@ -231,6 +232,7 @@ impl Session {
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
reasoning_content: m.reasoning_content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
@ -301,6 +303,7 @@ impl Session {
|
||||
seq,
|
||||
role: message.role.clone(),
|
||||
content: message.content.clone(),
|
||||
reasoning_content: message.reasoning_content.clone(),
|
||||
media_refs: if message.media_refs.is_empty() {
|
||||
None
|
||||
} else {
|
||||
|
||||
@ -7,6 +7,7 @@ pub struct MessageMeta {
|
||||
pub seq: i64,
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub media_refs: Option<String>,
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_name: Option<String>,
|
||||
|
||||
@ -97,6 +97,14 @@ impl Storage {
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Migration: add reasoning_content column if upgrading from older schema
|
||||
sqlx::query(
|
||||
r#"ALTER TABLE messages ADD COLUMN reasoning_content TEXT"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
@ -498,8 +506,8 @@ impl Storage {
|
||||
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO messages (id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&msg.id)
|
||||
@ -507,6 +515,7 @@ impl Storage {
|
||||
.bind(msg.seq)
|
||||
.bind(&msg.role)
|
||||
.bind(&msg.content)
|
||||
.bind(&msg.reasoning_content)
|
||||
.bind(&msg.media_refs)
|
||||
.bind(&msg.tool_call_id)
|
||||
.bind(&msg.tool_name)
|
||||
@ -539,7 +548,7 @@ impl Storage {
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND seq >= ?
|
||||
ORDER BY seq ASC
|
||||
@ -558,6 +567,7 @@ impl Storage {
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
reasoning_content: row.get("reasoning_content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
@ -585,7 +595,7 @@ impl Storage {
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY seq ASC
|
||||
@ -604,6 +614,7 @@ impl Storage {
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
reasoning_content: row.get("reasoning_content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
@ -668,7 +679,7 @@ impl Storage {
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq DESC
|
||||
@ -688,6 +699,7 @@ impl Storage {
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
reasoning_content: row.get("reasoning_content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
@ -719,7 +731,7 @@ impl Storage {
|
||||
let count_sql = format!("SELECT COUNT(*) as total FROM messages WHERE session_id = ?{}", where_extra);
|
||||
let select_sql = format!(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
SELECT id, session_id, seq, role, content, reasoning_content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?{}
|
||||
ORDER BY seq ASC
|
||||
@ -759,6 +771,7 @@ impl Storage {
|
||||
seq: row.get("seq"),
|
||||
role: row.get("role"),
|
||||
content: row.get("content"),
|
||||
reasoning_content: row.get("reasoning_content"),
|
||||
media_refs: row.get("media_refs"),
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
@ -933,6 +946,7 @@ mod tests {
|
||||
seq: 1,
|
||||
role: "user".to_string(),
|
||||
content: "你好".to_string(),
|
||||
reasoning_content: None,
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
|
||||
@ -348,6 +348,7 @@ mod tests {
|
||||
seq: i as i64 + 1,
|
||||
role: if i == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||
content: format!("消息内容 {}", i),
|
||||
reasoning_content: None,
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -404,6 +405,7 @@ mod tests {
|
||||
seq: i as i64 + 1,
|
||||
role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
|
||||
content: format!("消息内容 {}", i),
|
||||
reasoning_content: None,
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
@ -458,6 +460,7 @@ mod tests {
|
||||
seq: i as i64 + 1,
|
||||
role: "user".to_string(),
|
||||
content: format!("消息内容 {}", i),
|
||||
reasoning_content: None,
|
||||
media_refs: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user