diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index f4f8877..a6da683 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -267,6 +267,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message { Message { role: m.role.clone(), content, + reasoning_content: m.reasoning_content.clone(), tool_call_id: m.tool_call_id.clone(), name: m.tool_name.clone(), tool_calls: m.tool_calls.clone(), @@ -452,7 +453,11 @@ impl AgentLoop { // If no tool calls, this is the final response if response.tool_calls.is_empty() { - let assistant_message = ChatMessage::assistant(response.content); + let assistant_message = if let Some(reasoning_content) = response.reasoning_content { + ChatMessage::assistant_with_reasoning(response.content, reasoning_content) + } else { + ChatMessage::assistant(response.content) + }; emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, @@ -464,10 +469,18 @@ impl AgentLoop { tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); // Add assistant message with tool calls - let assistant_message = ChatMessage::assistant_with_tool_calls( - response.content.clone(), - response.tool_calls.clone(), - ); + let assistant_message = if let Some(reasoning_content) = response.reasoning_content.clone() { + ChatMessage::assistant_with_tool_calls_and_reasoning( + response.content.clone(), + response.tool_calls.clone(), + reasoning_content, + ) + } else { + ChatMessage::assistant_with_tool_calls( + response.content.clone(), + response.tool_calls.clone(), + ) + }; messages.push(assistant_message.clone()); emitted_messages.push(assistant_message); self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await; @@ -584,7 +597,11 @@ impl AgentLoop { match (*self.provider).chat(request).await { Ok(response) => { - let assistant_message = ChatMessage::assistant(response.content); + let assistant_message = if let Some(reasoning_content) = response.reasoning_content { + ChatMessage::assistant_with_reasoning(response.content, reasoning_content) + } else { + ChatMessage::assistant(response.content) + }; emitted_messages.push(assistant_message.clone()); Ok(AgentProcessResult { final_response: assistant_message, @@ -879,6 +896,19 @@ mod tests { assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); } + #[test] + fn test_chat_message_to_llm_message_preserves_reasoning_content() { + let chat_message = ChatMessage::assistant_with_reasoning( + "final answer", + "hidden chain of thought", + ); + + let provider_message = chat_message_to_llm_message(&chat_message); + + assert_eq!(provider_message.role, "assistant"); + assert_eq!(provider_message.reasoning_content.as_deref(), Some("hidden chain of thought")); + } + #[test] fn test_memory_prompt_requires_proactive_memory_search() { assert!(MEMORY_TOOL_USAGE_SYSTEM_PROMPT.contains("在绝大多数请求开始时")); diff --git a/src/bus/message.rs b/src/bus/message.rs index 2c8a373..a4210f4 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -75,6 +75,8 @@ pub struct ChatMessage { pub media_refs: Vec, // Paths to media files for context pub timestamp: i64, #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option, @@ -92,6 +94,7 @@ impl ChatMessage { content: content.into(), media_refs: Vec::new(), timestamp: current_timestamp(), + reasoning_content: None, tool_call_id: None, tool_name: None, tool_state: None, @@ -106,6 +109,7 @@ impl ChatMessage { content: content.into(), media_refs, timestamp: current_timestamp(), + reasoning_content: None, tool_call_id: None, tool_name: None, tool_state: None, @@ -120,6 +124,7 @@ impl ChatMessage { content: content.into(), media_refs: Vec::new(), timestamp: current_timestamp(), + reasoning_content: None, tool_call_id: None, tool_name: None, tool_state: None, @@ -127,6 +132,15 @@ impl ChatMessage { } } + pub fn assistant_with_reasoning( + content: impl Into, + reasoning_content: impl Into, + ) -> Self { + let mut message = Self::assistant(content); + message.reasoning_content = Some(reasoning_content.into()); + message + } + pub fn assistant_with_tool_calls(content: impl Into, tool_calls: Vec) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), @@ -134,6 +148,7 @@ impl ChatMessage { content: content.into(), media_refs: Vec::new(), timestamp: current_timestamp(), + reasoning_content: None, tool_call_id: None, tool_name: None, tool_state: None, @@ -141,6 +156,16 @@ impl ChatMessage { } } + pub fn assistant_with_tool_calls_and_reasoning( + content: impl Into, + tool_calls: Vec, + reasoning_content: impl Into, + ) -> Self { + let mut message = Self::assistant_with_tool_calls(content, tool_calls); + message.reasoning_content = Some(reasoning_content.into()); + message + } + pub fn system(content: impl Into) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), @@ -148,6 +173,7 @@ impl ChatMessage { content: content.into(), media_refs: Vec::new(), timestamp: current_timestamp(), + reasoning_content: None, tool_call_id: None, tool_name: None, tool_state: None, @@ -171,6 +197,7 @@ impl ChatMessage { content: content.into(), media_refs: Vec::new(), timestamp: current_timestamp(), + reasoning_content: None, tool_call_id: Some(tool_call_id.into()), tool_name: Some(tool_name.into()), tool_state: Some(tool_state), diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index bcbd50a..dafee77 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -243,6 +243,7 @@ impl LLMProvider for AnthropicProvider { id: anthropic_resp.id, model: anthropic_resp.model, content, + reasoning_content: None, tool_calls, usage: Usage { prompt_tokens: anthropic_resp.usage.input_tokens, diff --git a/src/providers/openai.rs b/src/providers/openai.rs index e13bd22..19ce4b3 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -79,7 +79,7 @@ impl OpenAIProvider { "name": m.name, }) } else if m.role == "assistant" && m.tool_calls.is_some() { - json!({ + let mut message = json!({ "role": m.role, "content": convert_content_blocks(&m.content), "tool_calls": m.tool_calls.as_ref().map(|calls| { @@ -92,12 +92,26 @@ impl OpenAIProvider { } })).collect::>() }) - }) + }); + + if let Some(reasoning_content) = &m.reasoning_content { + message["reasoning_content"] = Value::String(reasoning_content.clone()); + } + + message } else { - json!({ + let mut message = json!({ "role": m.role, "content": convert_content_blocks(&m.content) - }) + }); + + if m.role == "assistant" { + if let Some(reasoning_content) = &m.reasoning_content { + message["reasoning_content"] = Value::String(reasoning_content.clone()); + } + } + + message } }).collect::>(), "temperature": request.temperature.or(self.temperature).unwrap_or(0.7), @@ -135,6 +149,8 @@ struct OpenAIMessage { #[serde(default)] content: Option, #[serde(default)] + reasoning_content: Option, + #[serde(default)] name: Option, #[serde(default)] tool_calls: Vec, @@ -250,6 +266,7 @@ impl LLMProvider for OpenAIProvider { id: openai_resp.id, model: openai_resp.model, content, + reasoning_content: openai_resp.choices[0].message.reasoning_content.clone(), tool_calls, usage: Usage { prompt_tokens: openai_resp.usage.prompt_tokens, @@ -295,6 +312,7 @@ mod tests { messages: vec![Message { role: "assistant".to_string(), content: vec![ContentBlock::text("calling tool")], + reasoning_content: None, tool_call_id: None, name: None, tool_calls: Some(vec![ToolCall { @@ -318,4 +336,61 @@ mod tests { assert_eq!(tool_calls[0]["function"]["name"], "calculator"); assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); } + + #[test] + fn test_build_request_body_includes_assistant_reasoning_content() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::new(), + ); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "assistant".to_string(), + content: vec![ContentBlock::text("final answer")], + reasoning_content: Some("step by step".to_string()), + tool_call_id: None, + name: None, + tool_calls: None, + }], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + let messages = body["messages"].as_array().unwrap(); + + assert_eq!(messages[0]["reasoning_content"], "step by step"); + } + + #[test] + fn test_openai_response_parses_reasoning_content() { + let response: OpenAIResponse = serde_json::from_value(json!({ + "id": "resp_1", + "model": "gpt-test", + "choices": [{ + "message": { + "content": "final answer", + "reasoning_content": "hidden reasoning", + "tool_calls": [] + } + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + })) + .unwrap(); + + assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning")); + } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 63a46ef..280911c 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -7,6 +7,8 @@ pub struct Message { pub role: String, pub content: Vec, #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, @@ -19,6 +21,7 @@ impl Message { Self { role: "user".to_string(), content: vec![ContentBlock::text(content)], + reasoning_content: None, tool_call_id: None, name: None, tool_calls: None, @@ -29,6 +32,7 @@ impl Message { Self { role: "user".to_string(), content, + reasoning_content: None, tool_call_id: None, name: None, tool_calls: None, @@ -39,6 +43,7 @@ impl Message { Self { role: "assistant".to_string(), content: vec![ContentBlock::text(content)], + reasoning_content: None, tool_call_id: None, name: None, tool_calls: None, @@ -49,6 +54,7 @@ impl Message { Self { role: "system".to_string(), content: vec![ContentBlock::text(content)], + reasoning_content: None, tool_call_id: None, name: None, tool_calls: None, @@ -59,6 +65,7 @@ impl Message { Self { role: "tool".to_string(), content: vec![ContentBlock::text(content)], + reasoning_content: None, tool_call_id: Some(tool_call_id.into()), name: Some(tool_name.into()), tool_calls: None, @@ -100,6 +107,8 @@ pub struct ChatCompletionResponse { pub id: String, pub model: String, pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, pub tool_calls: Vec, pub usage: Usage, } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 56776a0..0853980 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -238,6 +238,7 @@ impl SessionStore { seq INTEGER NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, + reasoning_content TEXT, media_refs_json TEXT NOT NULL, tool_call_id TEXT, tool_name TEXT, @@ -345,6 +346,7 @@ impl SessionStore { )?; ensure_sessions_schema(&conn)?; + ensure_messages_schema(&conn)?; ensure_scheduler_schema(&conn)?; Ok(Self { @@ -553,8 +555,8 @@ impl SessionStore { " INSERT INTO messages ( id, session_id, seq, role, content, - media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) + reasoning_content, media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) ", params![ message.id, @@ -562,6 +564,7 @@ impl SessionStore { seq, message.role, message.content, + message.reasoning_content, media_refs_json, message.tool_call_id, message.tool_name, @@ -1349,6 +1352,17 @@ fn ensure_sessions_schema(conn: &Connection) -> Result<(), StorageError> { Ok(()) } +fn ensure_messages_schema(conn: &Connection) -> Result<(), StorageError> { + if !has_column(conn, "messages", "reasoning_content")? { + conn.execute( + "ALTER TABLE messages ADD COLUMN reasoning_content TEXT", + [], + )?; + } + + Ok(()) +} + fn ensure_scheduler_schema(conn: &Connection) -> Result<(), StorageError> { if !has_column(conn, "scheduler_jobs", "schedule_json")? { conn.execute( @@ -1443,7 +1457,7 @@ fn load_messages_after( ) -> Result, StorageError> { let mut stmt = conn.prepare( " - SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json + SELECT id, role, content, reasoning_content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json FROM messages WHERE session_id = ?1 AND seq > ?2 ORDER BY seq ASC @@ -1451,7 +1465,7 @@ fn load_messages_after( )?; let rows = stmt.query_map(params![session_id, cutoff_seq], |row| { - let media_refs_json: String = row.get(3)?; + let media_refs_json: String = row.get(4)?; let media_refs: Vec = serde_json::from_str(&media_refs_json).map_err(|err| { rusqlite::Error::FromSqlConversionFailure( media_refs_json.len(), @@ -1460,14 +1474,14 @@ fn load_messages_after( ) })?; - let tool_calls_json: Option = row.get(7)?; + let tool_calls_json: Option = row.get(8)?; let tool_calls = tool_calls_json .as_deref() .map(serde_json::from_str) .transpose() .map_err(|err| { rusqlite::Error::FromSqlConversionFailure( - 7, + 8, rusqlite::types::Type::Text, Box::new(err), ) @@ -1477,10 +1491,11 @@ fn load_messages_after( id: row.get(0)?, role: row.get(1)?, content: row.get(2)?, + reasoning_content: row.get(3)?, media_refs, - timestamp: row.get(4)?, - tool_call_id: row.get(5)?, - tool_name: row.get(6)?, + timestamp: row.get(5)?, + tool_call_id: row.get(6)?, + tool_name: row.get(7)?, tool_state: None, tool_calls, }) @@ -1619,6 +1634,24 @@ mod tests { assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator"); } + #[test] + fn test_assistant_reasoning_content_roundtrip() { + let store = SessionStore::in_memory().unwrap(); + let session = store.create_cli_session(Some("reasoning")).unwrap(); + + let assistant = ChatMessage::assistant_with_reasoning( + "final answer", + "hidden reasoning", + ); + + store.append_message(&session.id, &assistant).unwrap(); + + let messages = store.load_messages(&session.id).unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].content, "final answer"); + assert_eq!(messages[0].reasoning_content.as_deref(), Some("hidden reasoning")); + } + #[test] fn test_reset_session_preserves_full_history_and_hides_active_history() { let store = SessionStore::in_memory().unwrap(); @@ -1694,6 +1727,49 @@ mod tests { assert_eq!(session.agent_prompt_reinjection_count, 0); } + #[test] + fn test_schema_migration_adds_reasoning_content_column_to_messages() { + let conn = Connection::open_in_memory().unwrap(); + conn.execute_batch( + " + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + channel_name TEXT NOT NULL, + chat_id TEXT NOT NULL, + summary TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + last_active_at INTEGER NOT NULL, + archived_at INTEGER, + deleted_at INTEGER, + message_count INTEGER NOT NULL DEFAULT 0 + ); + + CREATE TABLE messages ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + media_refs_json TEXT NOT NULL, + tool_call_id TEXT, + tool_name TEXT, + tool_calls_json TEXT, + created_at INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, + UNIQUE(session_id, seq) + ); + ", + ) + .unwrap(); + + let _store = SessionStore::from_connection(conn).unwrap(); + let conn = _store.conn.lock().unwrap(); + + assert!(has_column(&conn, "messages", "reasoning_content").unwrap()); + } + #[test] fn test_count_active_user_messages_respects_reset_cutoff_seq() { let store = SessionStore::in_memory().unwrap();