feat: 添加工具调用支持,优化消息处理和持久化

This commit is contained in:
ooodc 2026-04-18 14:17:23 +08:00
parent 8bb32fa066
commit ef601107ac
8 changed files with 278 additions and 49 deletions

View File

@ -211,6 +211,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
content, content,
tool_call_id: m.tool_call_id.clone(), tool_call_id: m.tool_call_id.clone(),
name: m.tool_name.clone(), name: m.tool_name.clone(),
tool_calls: m.tool_calls.clone(),
} }
} }
@ -223,6 +224,12 @@ pub struct AgentLoop {
max_iterations: usize, max_iterations: usize,
} }
#[derive(Debug, Clone)]
pub struct AgentProcessResult {
pub final_response: ChatMessage,
pub emitted_messages: Vec<ChatMessage>,
}
impl AgentLoop { impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> { pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
@ -267,12 +274,13 @@ impl AgentLoop {
/// it loops back to the LLM with the tool results until either: /// it loops back to the LLM with the tool results until either:
/// - The LLM returns no more tool calls (final response) /// - The LLM returns no more tool calls (final response)
/// - Maximum iterations are reached /// - Maximum iterations are reached
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> { pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
// Track tool calls for loop detection // Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new();
for iteration in 0..self.max_iterations { for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -316,15 +324,23 @@ impl AgentLoop {
// If no tool calls, this is the final response // If no tool calls, this is the final response
if response.tool_calls.is_empty() { if response.tool_calls.is_empty() {
let assistant_message = ChatMessage::assistant(response.content); let assistant_message = ChatMessage::assistant(response.content);
return Ok(assistant_message); emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
} }
// Execute tool calls // Execute tool calls
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
// Add assistant message with tool calls // Add assistant message with tool calls
let assistant_message = ChatMessage::assistant(response.content.clone()); let assistant_message = ChatMessage::assistant_with_tool_calls(
response.content.clone(),
response.tool_calls.clone(),
);
messages.push(assistant_message.clone()); messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
// Execute tools and add results to messages // Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await; let tool_results = self.execute_tools(&response.tool_calls).await;
@ -356,7 +372,8 @@ impl AgentLoop {
tool_call.name.clone(), tool_call.name.clone(),
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
); );
messages.push(tool_message); messages.push(tool_message.clone());
emitted_messages.push(tool_message);
} }
LoopDetectionResult::Ok => { LoopDetectionResult::Ok => {
let tool_message = ChatMessage::tool( let tool_message = ChatMessage::tool(
@ -364,7 +381,8 @@ impl AgentLoop {
tool_call.name.clone(), tool_call.name.clone(),
truncated_output, truncated_output,
); );
messages.push(tool_message); messages.push(tool_message.clone());
emitted_messages.push(tool_message);
} }
} }
} }
@ -400,7 +418,11 @@ impl AgentLoop {
match (*self.provider).chat(request).await { match (*self.provider).chat(request).await {
Ok(response) => { Ok(response) => {
let assistant_message = ChatMessage::assistant(response.content); let assistant_message = ChatMessage::assistant(response.content);
Ok(assistant_message) emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
})
} }
Err(e) => { Err(e) => {
// Fallback if summary call fails // Fallback if summary call fails
@ -408,7 +430,11 @@ impl AgentLoop {
let final_message = ChatMessage::assistant( let final_message = ChatMessage::assistant(
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations) format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
); );
Ok(final_message) emitted_messages.push(final_message.clone());
Ok(AgentProcessResult {
final_response: final_message,
emitted_messages,
})
} }
} }
} }
@ -593,6 +619,25 @@ mod tests {
// If there's only 1 tool, should return false regardless // If there's only 1 tool, should return false regardless
assert_eq!(calls.len() <= 1, true); assert_eq!(calls.len() <= 1, true);
} }
#[test]
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
let chat_message = ChatMessage::assistant_with_tool_calls(
"calling tool",
vec![ToolCall {
id: "call_1".to_string(),
name: "calculator".to_string(),
arguments: serde_json::json!({ "expression": "2+2" }),
}],
);
let provider_message = chat_message_to_llm_message(&chat_message);
assert_eq!(provider_message.role, "assistant");
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
}
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -1,5 +1,5 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor; pub mod context_compressor;
pub use agent_loop::{AgentLoop, AgentError}; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use context_compressor::ContextCompressor; pub use context_compressor::ContextCompressor;

View File

@ -1,6 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::providers::ToolCall;
// ============================================================================ // ============================================================================
// ContentBlock - Multimodal content representation (OpenAI-style) // ContentBlock - Multimodal content representation (OpenAI-style)
// ============================================================================ // ============================================================================
@ -69,6 +71,8 @@ pub struct ChatMessage {
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>, pub tool_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
} }
impl ChatMessage { impl ChatMessage {
@ -81,6 +85,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_calls: None,
} }
} }
@ -93,6 +98,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_calls: None,
} }
} }
@ -105,6 +111,20 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_calls: None,
}
}
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
tool_calls: Some(tool_calls),
} }
} }
@ -117,6 +137,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_calls: None,
} }
} }
@ -129,6 +150,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: Some(tool_call_id.into()), tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()), tool_name: Some(tool_name.into()),
tool_calls: None,
} }
} }
} }

View File

@ -115,6 +115,16 @@ impl Session {
Ok(()) Ok(())
} }
pub fn append_persisted_messages<I>(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError>
where
I: IntoIterator<Item = ChatMessage>,
{
for message in messages {
self.append_persisted_message(chat_id, message)?;
}
Ok(())
}
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage { pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
if media_refs.is_empty() { if media_refs.is_empty() {
ChatMessage::user(content) ChatMessage::user(content)
@ -381,12 +391,12 @@ impl SessionManager {
// 创建 agent 并处理 // 创建 agent 并处理
let agent = session_guard.create_agent()?; let agent = session_guard.create_agent()?;
let response = agent.process(history).await?; let result = agent.process(history).await?;
// 添加助手响应到历史 // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
session_guard.append_persisted_message(chat_id, response.clone())?; session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
response result.final_response
}; };
#[cfg(debug_assertions)] #[cfg(debug_assertions)]

View File

@ -171,13 +171,13 @@ async fn handle_inbound(
let agent = session_guard.create_agent()?; let agent = session_guard.create_agent()?;
match agent.process(history).await { match agent.process(history).await {
Ok(response) => { Ok(result) => {
session_guard.append_persisted_message(&chat_id, response.clone())?; session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
let _ = session_guard let _ = session_guard
.send(WsOutbound::AssistantResponse { .send(WsOutbound::AssistantResponse {
id: response.id, id: result.final_response.id,
content: response.content, content: result.final_response.content,
role: response.role, role: result.final_response.role,
}) })
.await; .await;
} }

View File

@ -57,6 +57,54 @@ impl OpenAIProvider {
model_extra, model_extra,
} }
} }
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let mut body = json!({
"model": self.model_id,
"messages": request.messages.iter().map(|m| {
if m.role == "tool" {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content),
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else if m.role == "assistant" && m.tool_calls.is_some() {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content),
"tool_calls": m.tool_calls.as_ref().map(|calls| {
calls.iter().map(|call| json!({
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string())
}
})).collect::<Vec<_>>()
})
})
} else {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content)
})
}
}).collect::<Vec<_>>(),
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
"max_tokens": request.max_tokens.or(self.max_tokens),
});
for (key, value) in &self.model_extra {
body[key] = value.clone();
}
if let Some(tools) = &request.tools {
body["tools"] = json!(tools);
}
body
}
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -116,35 +164,7 @@ impl LLMProvider for OpenAIProvider {
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let url = format!("{}/chat/completions", self.base_url); let url = format!("{}/chat/completions", self.base_url);
let mut body = json!({ let body = self.build_request_body(&request);
"model": self.model_id,
"messages": request.messages.iter().map(|m| {
if m.role == "tool" {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content),
"tool_call_id": m.tool_call_id,
"name": m.name,
})
} else {
json!({
"role": m.role,
"content": convert_content_blocks(&m.content)
})
}
}).collect::<Vec<_>>(),
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
"max_tokens": request.max_tokens.or(self.max_tokens),
});
// Add model extra fields
for (key, value) in &self.model_extra {
body[key] = value.clone();
}
if let Some(tools) = &request.tools {
body["tools"] = json!(tools);
}
// Debug: Log LLM request summary (only in debug builds) // Debug: Log LLM request summary (only in debug builds)
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -242,3 +262,50 @@ impl LLMProvider for OpenAIProvider {
&self.model_id &self.model_id
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::Message;
#[test]
fn test_build_request_body_includes_assistant_tool_calls() {
let provider = OpenAIProvider::new(
"test".to_string(),
"key".to_string(),
"https://example.com/v1".to_string(),
HashMap::new(),
"gpt-test".to_string(),
None,
None,
HashMap::new(),
);
let request = ChatCompletionRequest {
messages: vec![Message {
role: "assistant".to_string(),
content: vec![ContentBlock::text("calling tool")],
tool_call_id: None,
name: None,
tool_calls: Some(vec![ToolCall {
id: "call_1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1+1"}),
}]),
}],
temperature: None,
max_tokens: None,
tools: None,
};
let body = provider.build_request_body(&request);
let messages = body["messages"].as_array().unwrap();
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0]["id"], "call_1");
assert_eq!(tool_calls[0]["type"], "function");
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
}
}

View File

@ -10,6 +10,8 @@ pub struct Message {
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>, pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
} }
impl Message { impl Message {
@ -19,6 +21,7 @@ impl Message {
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None,
} }
} }
@ -28,6 +31,7 @@ impl Message {
content, content,
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None,
} }
} }
@ -37,6 +41,7 @@ impl Message {
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None,
} }
} }
@ -46,6 +51,7 @@ impl Message {
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
tool_call_id: None, tool_call_id: None,
name: None, name: None,
tool_calls: None,
} }
} }
@ -55,6 +61,7 @@ impl Message {
content: vec![ContentBlock::text(content)], content: vec![ContentBlock::text(content)],
tool_call_id: Some(tool_call_id.into()), tool_call_id: Some(tool_call_id.into()),
name: Some(tool_name.into()), name: Some(tool_name.into()),
tool_calls: None,
} }
} }
} }

View File

@ -85,6 +85,7 @@ impl SessionStore {
media_refs_json TEXT NOT NULL, media_refs_json TEXT NOT NULL,
tool_call_id TEXT, tool_call_id TEXT,
tool_name TEXT, tool_name TEXT,
tool_calls_json TEXT,
created_at INTEGER NOT NULL, created_at INTEGER NOT NULL,
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE, FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
UNIQUE(session_id, seq) UNIQUE(session_id, seq)
@ -97,6 +98,10 @@ impl SessionStore {
", ",
)?; )?;
if !table_has_column(&conn, "messages", "tool_calls_json")? {
conn.execute("ALTER TABLE messages ADD COLUMN tool_calls_json TEXT", [])?;
}
Ok(Self { Ok(Self {
conn: Arc::new(Mutex::new(conn)), conn: Arc::new(Mutex::new(conn)),
}) })
@ -260,12 +265,13 @@ impl SessionStore {
)?; )?;
let media_refs_json = serde_json::to_string(&message.media_refs)?; let media_refs_json = serde_json::to_string(&message.media_refs)?;
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
tx.execute( tx.execute(
" "
INSERT INTO messages ( INSERT INTO messages (
id, session_id, seq, role, content, id, session_id, seq, role, content,
media_refs_json, tool_call_id, tool_name, created_at media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
", ",
params![ params![
message.id, message.id,
@ -276,6 +282,7 @@ impl SessionStore {
media_refs_json, media_refs_json,
message.tool_call_id, message.tool_call_id,
message.tool_name, message.tool_name,
tool_calls_json,
message.timestamp, message.timestamp,
], ],
)?; )?;
@ -301,7 +308,7 @@ impl SessionStore {
let conn = self.conn.lock().expect("session db mutex poisoned"); let conn = self.conn.lock().expect("session db mutex poisoned");
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
" "
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
FROM messages FROM messages
WHERE session_id = ?1 WHERE session_id = ?1
ORDER BY seq ASC ORDER BY seq ASC
@ -318,6 +325,19 @@ impl SessionStore {
) )
})?; })?;
let tool_calls_json: Option<String> = row.get(7)?;
let tool_calls = tool_calls_json
.as_deref()
.map(serde_json::from_str)
.transpose()
.map_err(|err| {
rusqlite::Error::FromSqlConversionFailure(
7,
rusqlite::types::Type::Text,
Box::new(err),
)
})?;
Ok(ChatMessage { Ok(ChatMessage {
id: row.get(0)?, id: row.get(0)?,
role: row.get(1)?, role: row.get(1)?,
@ -326,6 +346,7 @@ impl SessionStore {
timestamp: row.get(4)?, timestamp: row.get(4)?,
tool_call_id: row.get(5)?, tool_call_id: row.get(5)?,
tool_name: row.get(6)?, tool_name: row.get(6)?,
tool_calls,
}) })
})?; })?;
@ -376,6 +397,7 @@ fn current_timestamp() -> i64 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::providers::ToolCall;
#[test] #[test]
fn test_persistent_session_id_for_cli_and_channel() { fn test_persistent_session_id_for_cli_and_channel() {
@ -444,4 +466,60 @@ mod tests {
assert_eq!(first.chat_id, "chat-1"); assert_eq!(first.chat_id, "chat-1");
assert_eq!(second.channel_name, "feishu"); assert_eq!(second.channel_name, "feishu");
} }
#[test]
fn test_assistant_tool_calls_roundtrip() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("tools")).unwrap();
let assistant = ChatMessage::assistant_with_tool_calls(
"calling tool",
vec![ToolCall {
id: "call_1".to_string(),
name: "calculator".to_string(),
arguments: serde_json::json!({ "expression": "3*7" }),
}],
);
store.append_message(&session.id, &assistant).unwrap();
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, "assistant");
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
}
#[test]
fn test_tool_result_roundtrip() {
let store = SessionStore::in_memory().unwrap();
let session = store.create_cli_session(Some("tool-result")).unwrap();
let tool_message = ChatMessage::tool("call_9", "file_write", "saved to /tmp/output.txt");
store.append_message(&session.id, &tool_message).unwrap();
let messages = store.load_messages(&session.id).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, "tool");
assert_eq!(messages[0].content, "saved to /tmp/output.txt");
assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_9"));
assert_eq!(messages[0].tool_name.as_deref(), Some("file_write"));
assert!(messages[0].tool_calls.is_none());
}
}
fn table_has_column(conn: &Connection, table: &str, column: &str) -> Result<bool, rusqlite::Error> {
let pragma = format!("PRAGMA table_info({})", table);
let mut stmt = conn.prepare(&pragma)?;
let mut rows = stmt.query([])?;
while let Some(row) = rows.next()? {
let name: String = row.get(1)?;
if name == column {
return Ok(true);
}
}
Ok(false)
} }