feat: 添加推理内容支持到聊天消息,增强消息处理能力

This commit is contained in:
ooodc 2026-04-24 17:42:19 +08:00
parent 4b74fabb98
commit 95c53fa830
6 changed files with 237 additions and 19 deletions

View File

@ -267,6 +267,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
Message { Message {
role: m.role.clone(), role: m.role.clone(),
content, content,
reasoning_content: m.reasoning_content.clone(),
tool_call_id: m.tool_call_id.clone(), tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(), name: m.tool_name.clone(),
tool_calls: m.tool_calls.clone(), tool_calls: m.tool_calls.clone(),
@ -452,7 +453,11 @@ impl AgentLoop {
// If no tool calls, this is the final response // If no tool calls, this is the final response
if response.tool_calls.is_empty() { if response.tool_calls.is_empty() {
let assistant_message = ChatMessage::assistant(response.content); let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone()); emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult { return Ok(AgentProcessResult {
final_response: assistant_message, final_response: assistant_message,
@ -464,10 +469,18 @@ impl AgentLoop {
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
// Add assistant message with tool calls // Add assistant message with tool calls
let assistant_message = ChatMessage::assistant_with_tool_calls( let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() {
response.content.clone(), ChatMessage::assistant_with_tool_calls_and_reasoning(
response.tool_calls.clone(), response.content.clone(),
); response.tool_calls.clone(),
reasoning_content,
)
} else {
ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
)
};
messages.push(assistant_message.clone()); messages.push(assistant_message.clone());
emitted_messages.push(assistant_message); emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await; self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
@ -584,7 +597,11 @@ impl AgentLoop {
match (*self.provider).chat(request).await { match (*self.provider).chat(request).await {
Ok(response) => { Ok(response) => {
let assistant_message = ChatMessage::assistant(response.content); let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
} else {
ChatMessage::assistant(response.content)
};
emitted_messages.push(assistant_message.clone()); emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult { Ok(AgentProcessResult {
final_response: assistant_message, final_response: assistant_message,
@ -879,6 +896,19 @@ mod tests {
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
} }
#[test]
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
let chat_message = ChatMessage::assistant_with_reasoning(
"final answer",
"hidden chain of thought",
);
let provider_message = chat_message_to_llm_message(&chat_message);
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought"));
}
#[test] #[test]
fn test_memory_prompt_requires_proactive_memory_search() { fn test_memory_prompt_requires_proactive_memory_search() {
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时")); assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));

View File

@ -75,6 +75,8 @@ pub struct ChatMessage {
pub media_refs: Vec<String>, // Paths to media files for context pub media_refs: Vec<String>, // Paths to media files for context
pub timestamp: i64, pub timestamp: i64,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>, pub tool_name: Option<String>,
@ -92,6 +94,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None, tool_state: None,
@ -106,6 +109,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs, media_refs,
timestamp: current_timestamp(), timestamp: current_timestamp(),
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None, tool_state: None,
@ -120,6 +124,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None, tool_state: None,
@ -127,6 +132,15 @@ impl ChatMessage {
} }
} }
pub fn assistant_with_reasoning(
content: impl Into<String>,
reasoning_content: impl Into<String>,
) -> Self {
let mut message = Self::assistant(content);
message.reasoning_content = Some(reasoning_content.into());
message
}
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self { pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
@ -134,6 +148,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None, tool_state: None,
@ -141,6 +156,16 @@ impl ChatMessage {
} }
} }
pub fn assistant_with_tool_calls_and_reasoning(
content: impl Into<String>,
tool_calls: Vec<ToolCall>,
reasoning_content: impl Into<String>,
) -> Self {
let mut message = Self::assistant_with_tool_calls(content, tool_calls);
message.reasoning_content = Some(reasoning_content.into());
message
}
pub fn system(content: impl Into<String>) -> Self { pub fn system(content: impl Into<String>) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
@ -148,6 +173,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None, tool_state: None,
@ -171,6 +197,7 @@ impl ChatMessage {
content: content.into(), content: content.into(),
media_refs: Vec::new(), media_refs: Vec::new(),
timestamp: current_timestamp(), timestamp: current_timestamp(),
reasoning_content: None,
tool_call_id: Some(tool_call_id.into()), tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()), tool_name: Some(tool_name.into()),
tool_state: Some(tool_state), tool_state: Some(tool_state),

View File

@ -243,6 +243,7 @@ impl LLMProvider for AnthropicProvider {
id: anthropic_resp.id, id: anthropic_resp.id,
model: anthropic_resp.model, model: anthropic_resp.model,
content, content,
reasoning_content: None,
tool_calls, tool_calls,
usage: Usage { usage: Usage {
prompt_tokens: anthropic_resp.usage.input_tokens, prompt_tokens: anthropic_resp.usage.input_tokens,

View File

@ -79,7 +79,7 @@ impl OpenAIProvider {
"name": m.name, "name": m.name,
}) })
} else if m.role == "assistant" && m.tool_calls.is_some() { } else if m.role == "assistant" && m.tool_calls.is_some() {
json!({ let mut message = json!({
"role": m.role, "role": m.role,
"content": convert_content_blocks(&m.content), "content": convert_content_blocks(&m.content),
"tool_calls": m.tool_calls.as_ref().map(|calls| { "tool_calls": m.tool_calls.as_ref().map(|calls| {
@ -92,12 +92,26 @@ impl OpenAIProvider {
} }
})).collect::<Vec<_>>() })).collect::<Vec<_>>()
}) })
}) });
if let Some(reasoning_content) = &m.reasoning_content {
message["reasoning_content"] = Value::String(reasoning_content.clone());
}
message
} else { } else {
json!({ let mut message = json!({
"role": m.role, "role": m.role,
"content": convert_content_blocks(&m.content) "content": convert_content_blocks(&m.content)
}) });
if m.role == "assistant" {
if let Some(reasoning_content) = &m.reasoning_content {
message["reasoning_content"] = Value::String(reasoning_content.clone());
}
}
message
} }
}).collect::<Vec<_>>(), }).collect::<Vec<_>>(),
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7), "temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
@ -135,6 +149,8 @@ struct OpenAIMessage {
#[serde(default)] #[serde(default)]
content: Option<String>, content: Option<String>,
#[serde(default)] #[serde(default)]
reasoning_content: Option<String>,
#[serde(default)]
name: Option<String>, name: Option<String>,
#[serde(default)] #[serde(default)]
tool_calls: Vec<OpenAIToolCall>, tool_calls: Vec<OpenAIToolCall>,
@ -250,6 +266,7 @@ impl LLMProvider for OpenAIProvider {
id: openai_resp.id, id: openai_resp.id,
model: openai_resp.model, model: openai_resp.model,
content, content,
reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(),
tool_calls, tool_calls,
usage: Usage { usage: Usage {
prompt_tokens: openai_resp.usage.prompt_tokens, prompt_tokens: openai_resp.usage.prompt_tokens,
@ -295,6 +312,7 @@ mod tests {
messages: vec![Message { messages: vec![Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: vec![ContentBlock::text("calling tool")], content: vec![ContentBlock::text("calling tool")],
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: Some(vec![ToolCall { tool_calls: Some(vec![ToolCall {
@ -318,4 +336,61 @@ mod tests {
assert_eq!(tool_calls[0]["function"]["name"], "calculator"); assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
} }
#[test]
fn test_build_request_body_includes_assistant_reasoning_content() {
let provider = OpenAIProvider::new(
"test".to_string(),
"key".to_string(),
"https://example.com/v1".to_string(),
HashMap::new(),
120,
"gpt-test".to_string(),
None,
None,
HashMap::new(),
);
let request = ChatCompletionRequest {
messages: vec![Message {
role: "assistant".to_string(),
content: vec![ContentBlock::text("final answer")],
reasoning_content: Some("step by step".to_string()),
tool_call_id: None,
name: None,
tool_calls: None,
}],
temperature: None,
max_tokens: None,
tools: None,
};
let body = provider.build_request_body(&request);
let messages = body["messages"].as_array().unwrap();
assert_eq!(messages[0]["reasoning_content"], "step by step");
}
#[test]
fn test_openai_response_parses_reasoning_content() {
let response: OpenAIResponse = serde_json::from_value(json!({
"id": "resp_1",
"model": "gpt-test",
"choices": [{
"message": {
"content": "final answer",
"reasoning_content": "hidden reasoning",
"tool_calls": []
}
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
}))
.unwrap();
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
}
} }

View File

@ -7,6 +7,8 @@ pub struct Message {
pub role: String, pub role: String,
pub content: Vec<ContentBlock>, pub content: Vec<ContentBlock>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>, pub name: Option<String>,
@ -19,6 +21,7 @@ impl Message {
Self { Self {
role: "user".to_string(), role: "user".to_string(),
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None, tool_calls: None,
@ -29,6 +32,7 @@ impl Message {
Self { Self {
role: "user".to_string(), role: "user".to_string(),
content, content,
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None, tool_calls: None,
@ -39,6 +43,7 @@ impl Message {
Self { Self {
role: "assistant".to_string(), role: "assistant".to_string(),
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None, tool_calls: None,
@ -49,6 +54,7 @@ impl Message {
Self { Self {
role: "system".to_string(), role: "system".to_string(),
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
reasoning_content: None,
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None, tool_calls: None,
@ -59,6 +65,7 @@ impl Message {
Self { Self {
role: "tool".to_string(), role: "tool".to_string(),
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
reasoning_content: None,
tool_call_id: Some(tool_call_id.into()), tool_call_id: Some(tool_call_id.into()),
name: Some(tool_name.into()), name: Some(tool_name.into()),
tool_calls: None, tool_calls: None,
@ -100,6 +107,8 @@ pub struct ChatCompletionResponse {
pub id: String, pub id: String,
pub model: String, pub model: String,
pub content: String, pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
pub tool_calls: Vec<ToolCall>, pub tool_calls: Vec<ToolCall>,
pub usage: Usage, pub usage: Usage,
} }

View File

@ -238,6 +238,7 @@ impl SessionStore {
seq INTEGER NOT NULL, seq INTEGER NOT NULL,
role TEXT NOT NULL, role TEXT NOT NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
reasoning_content TEXT,
media_refs_json TEXT NOT NULL, media_refs_json TEXT NOT NULL,
tool_call_id TEXT, tool_call_id TEXT,
tool_name TEXT, tool_name TEXT,
@ -345,6 +346,7 @@ impl SessionStore {
)?; )?;
ensure_sessions_schema(&conn)?; ensure_sessions_schema(&conn)?;
ensure_messages_schema(&conn)?;
ensure_scheduler_schema(&conn)?; ensure_scheduler_schema(&conn)?;
Ok(Self { Ok(Self {
@ -553,8 +555,8 @@ impl SessionStore {
" "
INSERT INTO messages ( INSERT INTO messages (
id, session_id, seq, role, content, id, session_id, seq, role, content,
media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
", ",
params![ params![
message.id, message.id,
@ -562,6 +564,7 @@ impl SessionStore {
seq, seq,
message.role, message.role,
message.content, message.content,
message.reasoning_content,
media_refs_json, media_refs_json,
message.tool_call_id, message.tool_call_id,
message.tool_name, message.tool_name,
@ -1349,6 +1352,17 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
Ok(()) Ok(())
} }
fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
if !has_column(conn, "messages", "reasoning_content")? {
conn.execute(
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
[],
)?;
}
Ok(())
}
fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> { fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
if !has_column(conn, "scheduler_jobs", "schedule_json")? { if !has_column(conn, "scheduler_jobs", "schedule_json")? {
conn.execute( conn.execute(
@ -1443,7 +1457,7 @@ fn load_messages_after(
) -> Result<Vec<ChatMessage>, StorageError> { ) -> Result<Vec<ChatMessage>, StorageError> {
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
" "
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json SELECT id, role, content, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
FROM messages FROM messages
WHERE session_id = ?1 AND seq > ?2 WHERE session_id = ?1 AND seq > ?2
ORDER BY seq ASC ORDER BY seq ASC
@ -1451,7 +1465,7 @@ fn load_messages_after(
)?; )?;
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| { let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
let media_refs_json: String = row.get(3)?; let media_refs_json: String = row.get(4)?;
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| { let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
rusqlite::Error::FromSqlConversionFailure( rusqlite::Error::FromSqlConversionFailure(
media_refs_json.len(), media_refs_json.len(),
@ -1460,14 +1474,14 @@ fn load_messages_after(
) )
})?; })?;
let tool_calls_json: Option<String> = row.get(7)?; let tool_calls_json: Option<String> = row.get(8)?;
let tool_calls = tool_calls_json let tool_calls = tool_calls_json
.as_deref() .as_deref()
.map(serde_json::from_str) .map(serde_json::from_str)
.transpose() .transpose()
.map_err(|err| { .map_err(|err| {
rusqlite::Error::FromSqlConversionFailure( rusqlite::Error::FromSqlConversionFailure(
7, 8,
rusqlite::types::Type::Text, rusqlite::types::Type::Text,
Box::new(err), Box::new(err),
) )
@ -1477,10 +1491,11 @@ fn load_messages_after(
id: row.get(0)?, id: row.get(0)?,
role: row.get(1)?, role: row.get(1)?,
content: row.get(2)?, content: row.get(2)?,
reasoning_content: row.get(3)?,
media_refs, media_refs,
timestamp: row.get(4)?, timestamp: row.get(5)?,
tool_call_id: row.get(5)?, tool_call_id: row.get(6)?,
tool_name: row.get(6)?, tool_name: row.get(7)?,
tool_state: None, tool_state: None,
tool_calls, tool_calls,
}) })
@ -1619,6 +1634,24 @@ mod tests {
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator"); assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
} }
#[test]
fn test_assistant_reasoning_content_roundtrip() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("reasoning")).unwrap();
let assistant = ChatMessage::assistant_with_reasoning(
"final answer",
"hidden reasoning",
);
store.append_message(&session.id, &assistant).unwrap();
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].content, "final answer");
assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning"));
}
#[test] #[test]
fn test_reset_session_preserves_full_history_and_hides_active_history() { fn test_reset_session_preserves_full_history_and_hides_active_history() {
let store = SessionStore::in_memory().unwrap(); let store = SessionStore::in_memory().unwrap();
@ -1694,6 +1727,49 @@ mod tests {
assert_eq!(session.agent_prompt_reinjection_count, 0); assert_eq!(session.agent_prompt_reinjection_count, 0);
} }
#[test]
fn test_schema_migration_adds_reasoning_content_column_to_messages() {
let conn = Connection::open_in_memory().unwrap();
conn.execute_batch(
"
CREATE TABLE sessions (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
channel_name TEXT NOT NULL,
chat_id TEXT NOT NULL,
summary TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
last_active_at INTEGER NOT NULL,
archived_at INTEGER,
deleted_at INTEGER,
message_count INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
media_refs_json TEXT NOT NULL,
tool_call_id TEXT,
tool_name TEXT,
tool_calls_json TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
UNIQUE(session_id, seq)
);
",
)
.unwrap();
let _store = SessionStore::from_connection(conn).unwrap();
let conn = _store.conn.lock().unwrap();
assert!(has_column(&conn, "messages", "reasoning_content").unwrap());
}
#[test] #[test]
fn test_count_active_user_messages_respects_reset_cutoff_seq() { fn test_count_active_user_messages_respects_reset_cutoff_seq() {
let store = SessionStore::in_memory().unwrap(); let store = SessionStore::in_memory().unwrap();