feat: 添加推理内容支持到聊天消息,增强消息处理能力
This commit is contained in:
parent
4b74fabb98
commit
95c53fa830
@ -267,6 +267,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
||||
Message {
|
||||
role: m.role.clone(),
|
||||
content,
|
||||
reasoning_content: m.reasoning_content.clone(),
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
name: m.tool_name.clone(),
|
||||
tool_calls: m.tool_calls.clone(),
|
||||
@ -452,7 +453,11 @@ impl AgentLoop {
|
||||
|
||||
// If no tool calls, this is the final response
|
||||
if response.tool_calls.is_empty() {
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||
} else {
|
||||
ChatMessage::assistant(response.content)
|
||||
};
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
return Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
@ -464,10 +469,18 @@ impl AgentLoop {
|
||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
||||
|
||||
// Add assistant message with tool calls
|
||||
let assistant_message = ChatMessage::assistant_with_tool_calls(
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() {
|
||||
ChatMessage::assistant_with_tool_calls_and_reasoning(
|
||||
response.content.clone(),
|
||||
response.tool_calls.clone(),
|
||||
);
|
||||
reasoning_content,
|
||||
)
|
||||
} else {
|
||||
ChatMessage::assistant_with_tool_calls(
|
||||
response.content.clone(),
|
||||
response.tool_calls.clone(),
|
||||
)
|
||||
};
|
||||
messages.push(assistant_message.clone());
|
||||
emitted_messages.push(assistant_message);
|
||||
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
|
||||
@ -584,7 +597,11 @@ impl AgentLoop {
|
||||
|
||||
match (*self.provider).chat(request).await {
|
||||
Ok(response) => {
|
||||
let assistant_message = ChatMessage::assistant(response.content);
|
||||
let assistant_message = if let Some(reasoning_content) = response.reasoning_content {
|
||||
ChatMessage::assistant_with_reasoning(response.content, reasoning_content)
|
||||
} else {
|
||||
ChatMessage::assistant(response.content)
|
||||
};
|
||||
emitted_messages.push(assistant_message.clone());
|
||||
Ok(AgentProcessResult {
|
||||
final_response: assistant_message,
|
||||
@ -879,6 +896,19 @@ mod tests {
|
||||
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_to_llm_message_preserves_reasoning_content() {
|
||||
let chat_message = ChatMessage::assistant_with_reasoning(
|
||||
"final answer",
|
||||
"hidden chain of thought",
|
||||
);
|
||||
|
||||
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||
|
||||
assert_eq!(provider_message.role, "assistant");
|
||||
assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_prompt_requires_proactive_memory_search() {
|
||||
assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时"));
|
||||
|
||||
@ -75,6 +75,8 @@ pub struct ChatMessage {
|
||||
pub media_refs: Vec<String>, // Paths to media files for context
|
||||
pub timestamp: i64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_name: Option<String>,
|
||||
@ -92,6 +94,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_state: None,
|
||||
@ -106,6 +109,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs,
|
||||
timestamp: current_timestamp(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_state: None,
|
||||
@ -120,6 +124,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_state: None,
|
||||
@ -127,6 +132,15 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_with_reasoning(
|
||||
content: impl Into<String>,
|
||||
reasoning_content: impl Into<String>,
|
||||
) -> Self {
|
||||
let mut message = Self::assistant(content);
|
||||
message.reasoning_content = Some(reasoning_content.into());
|
||||
message
|
||||
}
|
||||
|
||||
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
@ -134,6 +148,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_state: None,
|
||||
@ -141,6 +156,16 @@ impl ChatMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_with_tool_calls_and_reasoning(
|
||||
content: impl Into<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
reasoning_content: impl Into<String>,
|
||||
) -> Self {
|
||||
let mut message = Self::assistant_with_tool_calls(content, tool_calls);
|
||||
message.reasoning_content = Some(reasoning_content.into());
|
||||
message
|
||||
}
|
||||
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
@ -148,6 +173,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_state: None,
|
||||
@ -171,6 +197,7 @@ impl ChatMessage {
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
reasoning_content: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
tool_name: Some(tool_name.into()),
|
||||
tool_state: Some(tool_state),
|
||||
|
||||
@ -243,6 +243,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
id: anthropic_resp.id,
|
||||
model: anthropic_resp.model,
|
||||
content,
|
||||
reasoning_content: None,
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: anthropic_resp.usage.input_tokens,
|
||||
|
||||
@ -79,7 +79,7 @@ impl OpenAIProvider {
|
||||
"name": m.name,
|
||||
})
|
||||
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||
json!({
|
||||
let mut message = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content),
|
||||
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||
@ -92,12 +92,26 @@ impl OpenAIProvider {
|
||||
}
|
||||
})).collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(reasoning_content) = &m.reasoning_content {
|
||||
message["reasoning_content"] = Value::String(reasoning_content.clone());
|
||||
}
|
||||
|
||||
message
|
||||
} else {
|
||||
json!({
|
||||
let mut message = json!({
|
||||
"role": m.role,
|
||||
"content": convert_content_blocks(&m.content)
|
||||
})
|
||||
});
|
||||
|
||||
if m.role == "assistant" {
|
||||
if let Some(reasoning_content) = &m.reasoning_content {
|
||||
message["reasoning_content"] = Value::String(reasoning_content.clone());
|
||||
}
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
}).collect::<Vec<_>>(),
|
||||
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
||||
@ -135,6 +149,8 @@ struct OpenAIMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<OpenAIToolCall>,
|
||||
@ -250,6 +266,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
id: openai_resp.id,
|
||||
model: openai_resp.model,
|
||||
content,
|
||||
reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(),
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: openai_resp.usage.prompt_tokens,
|
||||
@ -295,6 +312,7 @@ mod tests {
|
||||
messages: vec![Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::text("calling tool")],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
@ -318,4 +336,61 @@ mod tests {
|
||||
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_request_body_includes_assistant_reasoning_content() {
|
||||
let provider = OpenAIProvider::new(
|
||||
"test".to_string(),
|
||||
"key".to_string(),
|
||||
"https://example.com/v1".to_string(),
|
||||
HashMap::new(),
|
||||
120,
|
||||
"gpt-test".to_string(),
|
||||
None,
|
||||
None,
|
||||
HashMap::new(),
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
messages: vec![Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::text("final answer")],
|
||||
reasoning_content: Some("step by step".to_string()),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
}],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let body = provider.build_request_body(&request);
|
||||
let messages = body["messages"].as_array().unwrap();
|
||||
|
||||
assert_eq!(messages[0]["reasoning_content"], "step by step");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_parses_reasoning_content() {
|
||||
let response: OpenAIResponse = serde_json::from_value(json!({
|
||||
"id": "resp_1",
|
||||
"model": "gpt-test",
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "final answer",
|
||||
"reasoning_content": "hidden reasoning",
|
||||
"tool_calls": []
|
||||
}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,6 +7,8 @@ pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Vec<ContentBlock>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
@ -19,6 +21,7 @@ impl Message {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -29,6 +32,7 @@ impl Message {
|
||||
Self {
|
||||
role: "user".to_string(),
|
||||
content,
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -39,6 +43,7 @@ impl Message {
|
||||
Self {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -49,6 +54,7 @@ impl Message {
|
||||
Self {
|
||||
role: "system".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
@ -59,6 +65,7 @@ impl Message {
|
||||
Self {
|
||||
role: "tool".to_string(),
|
||||
content: vec![ContentBlock::text(content)],
|
||||
reasoning_content: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
name: Some(tool_name.into()),
|
||||
tool_calls: None,
|
||||
@ -100,6 +107,8 @@ pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub model: String,
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
@ -238,6 +238,7 @@ impl SessionStore {
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
reasoning_content TEXT,
|
||||
media_refs_json TEXT NOT NULL,
|
||||
tool_call_id TEXT,
|
||||
tool_name TEXT,
|
||||
@ -345,6 +346,7 @@ impl SessionStore {
|
||||
)?;
|
||||
|
||||
ensure_sessions_schema(&conn)?;
|
||||
ensure_messages_schema(&conn)?;
|
||||
ensure_scheduler_schema(&conn)?;
|
||||
|
||||
Ok(Self {
|
||||
@ -553,8 +555,8 @@ impl SessionStore {
|
||||
"
|
||||
INSERT INTO messages (
|
||||
id, session_id, seq, role, content,
|
||||
media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
|
||||
reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)
|
||||
",
|
||||
params![
|
||||
message.id,
|
||||
@ -562,6 +564,7 @@ impl SessionStore {
|
||||
seq,
|
||||
message.role,
|
||||
message.content,
|
||||
message.reasoning_content,
|
||||
media_refs_json,
|
||||
message.tool_call_id,
|
||||
message.tool_name,
|
||||
@ -1349,6 +1352,17 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
if !has_column(conn, "messages", "reasoning_content")? {
|
||||
conn.execute(
|
||||
"ALTER TABLE messages ADD COLUMN reasoning_content TEXT",
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> {
|
||||
if !has_column(conn, "scheduler_jobs", "schedule_json")? {
|
||||
conn.execute(
|
||||
@ -1443,7 +1457,7 @@ fn load_messages_after(
|
||||
) -> Result<Vec<ChatMessage>, StorageError> {
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||
SELECT id, role, content, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||
FROM messages
|
||||
WHERE session_id = ?1 AND seq > ?2
|
||||
ORDER BY seq ASC
|
||||
@ -1451,7 +1465,7 @@ fn load_messages_after(
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map(params![session_id, cutoff_seq], |row| {
|
||||
let media_refs_json: String = row.get(3)?;
|
||||
let media_refs_json: String = row.get(4)?;
|
||||
let media_refs: Vec<String> = serde_json::from_str(&media_refs_json).map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
media_refs_json.len(),
|
||||
@ -1460,14 +1474,14 @@ fn load_messages_after(
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls_json: Option<String> = row.get(7)?;
|
||||
let tool_calls_json: Option<String> = row.get(8)?;
|
||||
let tool_calls = tool_calls_json
|
||||
.as_deref()
|
||||
.map(serde_json::from_str)
|
||||
.transpose()
|
||||
.map_err(|err| {
|
||||
rusqlite::Error::FromSqlConversionFailure(
|
||||
7,
|
||||
8,
|
||||
rusqlite::types::Type::Text,
|
||||
Box::new(err),
|
||||
)
|
||||
@ -1477,10 +1491,11 @@ fn load_messages_after(
|
||||
id: row.get(0)?,
|
||||
role: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
reasoning_content: row.get(3)?,
|
||||
media_refs,
|
||||
timestamp: row.get(4)?,
|
||||
tool_call_id: row.get(5)?,
|
||||
tool_name: row.get(6)?,
|
||||
timestamp: row.get(5)?,
|
||||
tool_call_id: row.get(6)?,
|
||||
tool_name: row.get(7)?,
|
||||
tool_state: None,
|
||||
tool_calls,
|
||||
})
|
||||
@ -1619,6 +1634,24 @@ mod tests {
|
||||
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_assistant_reasoning_content_roundtrip() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
let session = store.create_cli_session(Some("reasoning")).unwrap();
|
||||
|
||||
let assistant = ChatMessage::assistant_with_reasoning(
|
||||
"final answer",
|
||||
"hidden reasoning",
|
||||
);
|
||||
|
||||
store.append_message(&session.id, &assistant).unwrap();
|
||||
|
||||
let messages = store.load_messages(&session.id).unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].content, "final answer");
|
||||
assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reset_session_preserves_full_history_and_hides_active_history() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
@ -1694,6 +1727,49 @@ mod tests {
|
||||
assert_eq!(session.agent_prompt_reinjection_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_schema_migration_adds_reasoning_content_column_to_messages() {
|
||||
let conn = Connection::open_in_memory().unwrap();
|
||||
conn.execute_batch(
|
||||
"
|
||||
CREATE TABLE sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
channel_name TEXT NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
summary TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
last_active_at INTEGER NOT NULL,
|
||||
archived_at INTEGER,
|
||||
deleted_at INTEGER,
|
||||
message_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
media_refs_json TEXT NOT NULL,
|
||||
tool_call_id TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls_json TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
UNIQUE(session_id, seq)
|
||||
);
|
||||
",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let _store = SessionStore::from_connection(conn).unwrap();
|
||||
let conn = _store.conn.lock().unwrap();
|
||||
|
||||
assert!(has_column(&conn, "messages", "reasoning_content").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_count_active_user_messages_respects_reset_cutoff_seq() {
|
||||
let store = SessionStore::in_memory().unwrap();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user