From ef601107ac53658bea8c23b77a0f1540ca535ac6 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sat, 18 Apr 2026 14:17:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E6=94=AF=E6=8C=81=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E5=92=8C=E6=8C=81=E4=B9=85?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agent/agent_loop.rs | 59 ++++++++++++++++--- src/agent/mod.rs | 2 +- src/bus/message.rs | 22 +++++++ src/gateway/session.rs | 18 ++++-- src/gateway/ws.rs | 10 ++-- src/providers/openai.rs | 125 ++++++++++++++++++++++++++++++---------- src/providers/traits.rs | 7 +++ src/storage/mod.rs | 84 ++++++++++++++++++++++++++- 8 files changed, 278 insertions(+), 49 deletions(-) diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 8207a71..9db10f9 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -211,6 +211,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message { content, tool_call_id: m.tool_call_id.clone(), name: m.tool_name.clone(), + tool_calls: m.tool_calls.clone(), } } @@ -223,6 +224,12 @@ pub struct AgentLoop { max_iterations: usize, } +#[derive(Debug, Clone)] +pub struct AgentProcessResult { + pub final_response: ChatMessage, + pub emitted_messages: Vec<ChatMessage>, +} + impl AgentLoop { pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> { let max_iterations = provider_config.max_tool_iterations; @@ -267,12 +274,13 @@ impl AgentLoop { /// it loops back to the LLM with the tool results until either: /// - The LLM returns no more tool calls (final response) /// - Maximum iterations are reached - pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> { + pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> { #[cfg(debug_assertions)] tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); // 
Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); + let mut emitted_messages = Vec::new(); for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] @@ -316,15 +324,23 @@ impl AgentLoop { // If no tool calls, this is the final response if response.tool_calls.is_empty() { let assistant_message = ChatMessage::assistant(response.content); - return Ok(assistant_message); + emitted_messages.push(assistant_message.clone()); + return Ok(AgentProcessResult { + final_response: assistant_message, + emitted_messages, + }); } // Execute tool calls tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); // Add assistant message with tool calls - let assistant_message = ChatMessage::assistant(response.content.clone()); + let assistant_message = ChatMessage::assistant_with_tool_calls( + response.content.clone(), + response.tool_calls.clone(), + ); messages.push(assistant_message.clone()); + emitted_messages.push(assistant_message); // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; @@ -356,7 +372,8 @@ impl AgentLoop { tool_call.name.clone(), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), ); - messages.push(tool_message); + messages.push(tool_message.clone()); + emitted_messages.push(tool_message); } LoopDetectionResult::Ok => { let tool_message = ChatMessage::tool( @@ -364,7 +381,8 @@ impl AgentLoop { tool_call.name.clone(), truncated_output, ); - messages.push(tool_message); + messages.push(tool_message.clone()); + emitted_messages.push(tool_message); } } } @@ -400,7 +418,11 @@ impl AgentLoop { match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = ChatMessage::assistant(response.content); - Ok(assistant_message) + emitted_messages.push(assistant_message.clone()); + Ok(AgentProcessResult { + final_response: assistant_message, + emitted_messages, + }) } Err(e) => { // 
Fallback if summary call fails @@ -408,7 +430,11 @@ impl AgentLoop { let final_message = ChatMessage::assistant( format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) ); - Ok(final_message) + emitted_messages.push(final_message.clone()); + Ok(AgentProcessResult { + final_response: final_message, + emitted_messages, + }) } } } @@ -593,6 +619,25 @@ mod tests { // If there's only 1 tool, should return false regardless assert_eq!(calls.len() <= 1, true); } + + #[test] + fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() { + let chat_message = ChatMessage::assistant_with_tool_calls( + "calling tool", + vec![ToolCall { + id: "call_1".to_string(), + name: "calculator".to_string(), + arguments: serde_json::json!({ "expression": "2+2" }), + }], + ); + + let provider_message = chat_message_to_llm_message(&chat_message); + + assert_eq!(provider_message.role, "assistant"); + assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1); + assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1"); + assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator"); + } } #[derive(Debug)] diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 3b2c508..4dd5762 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,5 @@ pub mod agent_loop; pub mod context_compressor; -pub use agent_loop::{AgentLoop, AgentError}; +pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use context_compressor::ContextCompressor; diff --git a/src/bus/message.rs b/src/bus/message.rs index 3a55053..1b2386c 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; +use crate::providers::ToolCall; + // ============================================================================ // 
ContentBlock - Multimodal content representation (OpenAI-style) // ============================================================================ @@ -69,6 +71,8 @@ pub struct ChatMessage { pub tool_call_id: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Vec<ToolCall>>, } impl ChatMessage { @@ -81,6 +85,7 @@ timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_calls: None, } } @@ -93,6 +98,7 @@ timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_calls: None, } } @@ -105,6 +111,20 @@ timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_calls: None, + } + } + + pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + role: "assistant".to_string(), + content: content.into(), + media_refs: Vec::new(), + timestamp: current_timestamp(), + tool_call_id: None, + tool_name: None, + tool_calls: Some(tool_calls), } } @@ -117,6 +137,7 @@ timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_calls: None, } } @@ -129,6 +150,7 @@ timestamp: current_timestamp(), tool_call_id: Some(tool_call_id.into()), tool_name: Some(tool_name.into()), + tool_calls: None, } } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 6b9876a..c1b2a26 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -115,6 +115,16 @@ impl Session { Ok(()) } + pub fn append_persisted_messages<I>(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError> + where + I: IntoIterator<Item = ChatMessage>, + { + for message in messages { + self.append_persisted_message(chat_id, message)?; + } + Ok(()) + } + pub fn create_user_message(&self, content: &str, media_refs: Vec<MediaRef>) -> ChatMessage { if media_refs.is_empty() { ChatMessage::user(content) } @@ -381,12 +391,12 
@@ impl SessionManager { // 创建 agent 并处理 let agent = session_guard.create_agent()?; - let response = agent.process(history).await?; + let result = agent.process(history).await?; - // 添加助手响应到历史 - session_guard.append_persisted_message(chat_id, response.clone())?; + // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 + session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?; - response + result.final_response }; #[cfg(debug_assertions)] diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 18ba69f..18c979d 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -171,13 +171,13 @@ async fn handle_inbound( let agent = session_guard.create_agent()?; match agent.process(history).await { - Ok(response) => { - session_guard.append_persisted_message(&chat_id, response.clone())?; + Ok(result) => { + session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?; let _ = session_guard .send(WsOutbound::AssistantResponse { - id: response.id, - content: response.content, - role: response.role, + id: result.final_response.id, + content: result.final_response.content, + role: result.final_response.role, }) .await; } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 073bd35..b55ed8e 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -57,6 +57,54 @@ impl OpenAIProvider { model_extra, } } + + fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { + let mut body = json!({ + "model": self.model_id, + "messages": request.messages.iter().map(|m| { + if m.role == "tool" { + json!({ + "role": m.role, + "content": convert_content_blocks(&m.content), + "tool_call_id": m.tool_call_id, + "name": m.name, + }) + } else if m.role == "assistant" && m.tool_calls.is_some() { + json!({ + "role": m.role, + "content": convert_content_blocks(&m.content), + "tool_calls": m.tool_calls.as_ref().map(|calls| { + calls.iter().map(|call| json!({ + "id": call.id, + "type": "function", + "function": 
{ + "name": call.name, + "arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string()) + } + })).collect::<Vec<_>>() + }) + }) + } else { + json!({ + "role": m.role, + "content": convert_content_blocks(&m.content) + }) + } + }).collect::<Vec<_>>(), + "temperature": request.temperature.or(self.temperature).unwrap_or(0.7), + "max_tokens": request.max_tokens.or(self.max_tokens), + }); + + for (key, value) in &self.model_extra { + body[key] = value.clone(); + } + + if let Some(tools) = &request.tools { + body["tools"] = json!(tools); + } + + body + } } #[derive(Deserialize)] @@ -116,35 +164,7 @@ impl LLMProvider for OpenAIProvider { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { let url = format!("{}/chat/completions", self.base_url); - let mut body = json!({ - "model": self.model_id, - "messages": request.messages.iter().map(|m| { - if m.role == "tool" { - json!({ - "role": m.role, - "content": convert_content_blocks(&m.content), - "tool_call_id": m.tool_call_id, - "name": m.name, - }) - } else { - json!({ - "role": m.role, - "content": convert_content_blocks(&m.content) - }) - } - }).collect::<Vec<_>>(), - "temperature": request.temperature.or(self.temperature).unwrap_or(0.7), - "max_tokens": request.max_tokens.or(self.max_tokens), - }); - - // Add model extra fields - for (key, value) in &self.model_extra { - body[key] = value.clone(); - } - - if let Some(tools) = &request.tools { - body["tools"] = json!(tools); - } + let body = self.build_request_body(&request); // Debug: Log LLM request summary (only in debug builds) #[cfg(debug_assertions)] @@ -242,3 +262,50 @@ impl LLMProvider for OpenAIProvider { &self.model_id } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::Message; + + #[test] + fn test_build_request_body_includes_assistant_tool_calls() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + "gpt-test".to_string(), + None, + None, + HashMap::new(), + ); + + let request = 
ChatCompletionRequest { + messages: vec![Message { + role: "assistant".to_string(), + content: vec![ContentBlock::text("calling tool")], + tool_call_id: None, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "calculator".to_string(), + arguments: json!({"expression": "1+1"}), + }]), + }], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + let messages = body["messages"].as_array().unwrap(); + let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); + + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0]["id"], "call_1"); + assert_eq!(tool_calls[0]["type"], "function"); + assert_eq!(tool_calls[0]["function"]["name"], "calculator"); + assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); + } +} diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 8e8520d..63a46ef 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -10,6 +10,8 @@ pub struct Message { pub tool_call_id: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Vec<ToolCall>>, } impl Message { @@ -19,6 +21,7 @@ content: vec![ContentBlock::text(content)], tool_call_id: None, name: None, + tool_calls: None, } } @@ -28,6 +31,7 @@ content, tool_call_id: None, name: None, + tool_calls: None, } } @@ -37,6 +41,7 @@ content: vec![ContentBlock::text(content)], tool_call_id: None, name: None, + tool_calls: None, } } @@ -46,6 +51,7 @@ content: vec![ContentBlock::text(content)], tool_call_id: None, name: None, + tool_calls: None, } } @@ -55,6 +61,7 @@ content: vec![ContentBlock::text(content)], tool_call_id: Some(tool_call_id.into()), name: Some(tool_name.into()), + tool_calls: None, } } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index f341f80..f1d53c0 100644 --- 
a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -85,6 +85,7 @@ impl SessionStore { media_refs_json TEXT NOT NULL, tool_call_id TEXT, tool_name TEXT, + tool_calls_json TEXT, created_at INTEGER NOT NULL, FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, UNIQUE(session_id, seq) @@ -97,6 +98,10 @@ impl SessionStore { ", )?; + if !table_has_column(&conn, "messages", "tool_calls_json")? { + conn.execute("ALTER TABLE messages ADD COLUMN tool_calls_json TEXT", [])?; + } + Ok(Self { conn: Arc::new(Mutex::new(conn)), }) @@ -260,12 +265,13 @@ impl SessionStore { )?; let media_refs_json = serde_json::to_string(&message.media_refs)?; + let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?; tx.execute( " INSERT INTO messages ( id, session_id, seq, role, content, - media_refs_json, tool_call_id, tool_name, created_at - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) + media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) ", params![ message.id, @@ -276,6 +282,7 @@ message.tool_call_id, message.tool_name, + tool_calls_json, message.timestamp, ], )?; @@ -301,7 +308,7 @@ impl SessionStore { let conn = self.conn.lock().expect("session db mutex poisoned"); let mut stmt = conn.prepare( " - SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name + SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json FROM messages WHERE session_id = ?1 ORDER BY seq ASC @@ -318,6 +325,19 @@ ) })?; + let tool_calls_json: Option<String> = row.get(7)?; + let tool_calls = tool_calls_json + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + 7, + rusqlite::types::Type::Text, + Box::new(err), + ) + })?; + Ok(ChatMessage { id: row.get(0)?, role: row.get(1)?, @@ -326,6 +346,7 @@ timestamp: 
row.get(4)?, tool_call_id: row.get(5)?, tool_name: row.get(6)?, + tool_calls, }) })?; @@ -376,6 +397,7 @@ fn current_timestamp() -> i64 { #[cfg(test)] mod tests { use super::*; + use crate::providers::ToolCall; #[test] fn test_persistent_session_id_for_cli_and_channel() { @@ -444,4 +466,60 @@ mod tests { assert_eq!(first.chat_id, "chat-1"); assert_eq!(second.channel_name, "feishu"); } + + #[test] + fn test_assistant_tool_calls_roundtrip() { + let store = SessionStore::in_memory().unwrap(); + let session = store.create_cli_session(Some("tools")).unwrap(); + + let assistant = ChatMessage::assistant_with_tool_calls( + "calling tool", + vec![ToolCall { + id: "call_1".to_string(), + name: "calculator".to_string(), + arguments: serde_json::json!({ "expression": "3*7" }), + }], + ); + + store.append_message(&session.id, &assistant).unwrap(); + + let messages = store.load_messages(&session.id).unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, "assistant"); + assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1); + assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1"); + assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator"); + } + + #[test] + fn test_tool_result_roundtrip() { + let store = SessionStore::in_memory().unwrap(); + let session = store.create_cli_session(Some("tool-result")).unwrap(); + + let tool_message = ChatMessage::tool("call_9", "file_write", "saved to /tmp/output.txt"); + store.append_message(&session.id, &tool_message).unwrap(); + + let messages = store.load_messages(&session.id).unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, "tool"); + assert_eq!(messages[0].content, "saved to /tmp/output.txt"); + assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_9")); + assert_eq!(messages[0].tool_name.as_deref(), Some("file_write")); + assert!(messages[0].tool_calls.is_none()); + } +} + +fn table_has_column(conn: &Connection, table: &str, column: &str) -> Result<bool, rusqlite::Error> { + 
let pragma = format!("PRAGMA table_info({})", table); + let mut stmt = conn.prepare(&pragma)?; + let mut rows = stmt.query([])?; + + while let Some(row) = rows.next()? { + let name: String = row.get(1)?; + if name == column { + return Ok(true); + } + } + + Ok(false) } \ No newline at end of file